|
6 | 6 | from mellea.backends.huggingface import LocalHFBackend |
7 | 7 | from mellea.stdlib.components import Document, Message |
8 | 8 | from mellea.stdlib.context import ChatContext |
9 | | -from mellea.stdlib.requirements.rag import CitationRequirement |
| 9 | +from mellea.stdlib.requirements.rag import CitationMode, CitationRequirement |
10 | 10 |
|
11 | 11 |
|
12 | 12 | @pytest.mark.huggingface |
@@ -269,18 +269,137 @@ async def test_citation_requirement_threshold_boundary(): |
269 | 269 | "mellea.stdlib.components.intrinsic.rag.find_citations", |
270 | 270 | return_value=mock_citations, |
271 | 271 | ): |
272 | | - # Test at exact threshold (0.8) |
273 | | - req = CitationRequirement(min_citation_coverage=0.8) |
| 272 | + # Test at exact threshold (0.8) with CHARACTERS mode |
| 273 | + req = CitationRequirement( |
| 274 | + min_citation_coverage=0.8, mode=CitationMode.CHARACTERS |
| 275 | + ) |
274 | 276 | result = await req.validate(backend, ctx) |
275 | 277 |
|
276 | 278 | # At exact threshold, should pass (>= comparison) |
277 | 279 | assert result.as_bool() |
278 | 280 | assert result.score == 0.8 |
279 | 281 |
|
280 | | - # Test just below threshold (0.81) |
281 | | - req_above = CitationRequirement(min_citation_coverage=0.81) |
| 282 | + # Test with the score (0.8) just below the threshold (0.81) with CHARACTERS mode |
| 283 | + req_above = CitationRequirement( |
| 284 | + min_citation_coverage=0.81, mode=CitationMode.CHARACTERS |
| 285 | + ) |
282 | 286 | result_above = await req_above.validate(backend, ctx) |
283 | 287 |
|
284 | 288 | # Just below threshold, should fail |
285 | 289 | assert not result_above.as_bool() |
286 | 290 | assert result_above.score == 0.8 |
| 291 | + |
| 292 | + |
| 293 | +async def test_citation_requirement_claims_mode(): |
| 294 | + """Test citation requirement with CLAIMS mode.""" |
| 295 | + from unittest.mock import Mock, patch |
| 296 | + |
| 297 | + backend = Mock(spec=LocalHFBackend) |
| 298 | + |
| 299 | + # Create documents |
| 300 | + docs = [ |
| 301 | + Document(doc_id="doc1", text="The sky is blue."), |
| 302 | + Document(doc_id="doc2", text="Grass is green."), |
| 303 | + ] |
| 304 | + |
| 305 | + # Create a response with 3 sentences |
| 306 | + response = "The sky is blue. Grass is green. Water is wet." |
| 307 | + |
| 308 | + # Create context |
| 309 | + ctx = ChatContext().add(Message("user", "Tell me some facts.")) |
| 310 | + ctx = ctx.add(Message("assistant", response, documents=docs)) |
| 311 | + |
| 312 | + # Mock find_citations to return 2 citations (2 out of 3 claims = 66.7%) |
| 313 | + mock_citations = [ |
| 314 | + { |
| 315 | + "response_begin": 0, |
| 316 | + "response_end": 16, |
| 317 | + "response_text": "The sky is blue", |
| 318 | + "citation_doc_id": "doc1", |
| 319 | + "citation_text": "The sky is blue.", |
| 320 | + }, |
| 321 | + { |
| 322 | + "response_begin": 18, |
| 323 | + "response_end": 34, |
| 324 | + "response_text": "Grass is green", |
| 325 | + "citation_doc_id": "doc2", |
| 326 | + "citation_text": "Grass is green.", |
| 327 | + }, |
| 328 | + ] |
| 329 | + |
| 330 | + with patch( |
| 331 | + "mellea.stdlib.components.intrinsic.rag.find_citations", |
| 332 | + return_value=mock_citations, |
| 333 | + ): |
| 334 | + # Test with CLAIMS mode (default) - 2 citations out of 3 sentences = 66.7% |
| 335 | + req = CitationRequirement(min_citation_coverage=0.6) |
| 336 | + result = await req.validate(backend, ctx) |
| 337 | + |
| 338 | + # Should pass (66.7% >= 60%) |
| 339 | + assert result.as_bool() |
| 340 | + assert result.score is not None |
| 341 | + assert abs(result.score - 0.667) < 0.01 # 2/3 compared against rounded 0.667; tolerance absorbs the rounding |
| 342 | + assert result.reason is not None |
| 343 | + assert "claims" in result.reason |
| 344 | + |
| 345 | + # Test with higher threshold that should fail |
| 346 | + req_high = CitationRequirement(min_citation_coverage=0.7) |
| 347 | + result_high = await req_high.validate(backend, ctx) |
| 348 | + |
| 349 | + # Should fail (66.7% < 70%) |
| 350 | + assert not result_high.as_bool() |
| 351 | + assert result_high.score is not None |
| 352 | + assert abs(result_high.score - 0.667) < 0.01 |
| 353 | + |
| 354 | + |
| 355 | +async def test_citation_requirement_characters_vs_claims(): |
| 356 | + """Test that CHARACTERS and CLAIMS modes produce different results.""" |
| 357 | + from unittest.mock import Mock, patch |
| 358 | + |
| 359 | + backend = Mock(spec=LocalHFBackend) |
| 360 | + |
| 361 | + # Create documents |
| 362 | + docs = [Document(doc_id="doc1", text="Short fact.")] |
| 363 | + |
| 364 | + # Create a response: 1 short sentence with citation, 1 long sentence without |
| 365 | + response = "Short. This is a much longer sentence without any citation support." |
| 366 | + |
| 367 | + # Create context |
| 368 | + ctx = ChatContext().add(Message("user", "Tell me something.")) |
| 369 | + ctx = ctx.add(Message("assistant", response, documents=docs)) |
| 370 | + |
| 371 | + # Mock find_citations to return 1 citation for the short sentence |
| 372 | + mock_citations = [ |
| 373 | + { |
| 374 | + "response_begin": 0, |
| 375 | + "response_end": 6, # "Short." = 6 characters |
| 376 | + "response_text": "Short", |
| 377 | + "citation_doc_id": "doc1", |
| 378 | + "citation_text": "Short fact.", |
| 379 | + } |
| 380 | + ] |
| 381 | + |
| 382 | + with patch( |
| 383 | + "mellea.stdlib.components.intrinsic.rag.find_citations", |
| 384 | + return_value=mock_citations, |
| 385 | + ): |
| 386 | + # CLAIMS mode: 1 citation out of 2 sentences = 50% |
| 387 | + req_claims = CitationRequirement( |
| 388 | + min_citation_coverage=0.5, mode=CitationMode.CLAIMS |
| 389 | + ) |
| 390 | + result_claims = await req_claims.validate(backend, ctx) |
| 391 | + |
| 392 | + # CHARACTERS mode: 6 characters out of 67 total = ~9% |
| 393 | + req_chars = CitationRequirement( |
| 394 | + min_citation_coverage=0.5, mode=CitationMode.CHARACTERS |
| 395 | + ) |
| 396 | + result_chars = await req_chars.validate(backend, ctx) |
| 397 | + |
| 398 | + # CLAIMS mode should pass (50% >= 50%) |
| 399 | + assert result_claims.as_bool() |
| 400 | + assert result_claims.score == 0.5 |
| 401 | + |
| 402 | + # CHARACTERS mode should fail (~9% < 50%) |
| 403 | + assert not result_chars.as_bool() |
| 404 | + assert result_chars.score is not None |
| 405 | + assert result_chars.score < 0.1 |
0 commit comments