Skip to content

Commit 3fc5fbd

Browse files
committed
feat: review comments
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
1 parent 3f17c60 commit 3fc5fbd

3 files changed

Lines changed: 196 additions & 22 deletions

File tree

mellea/stdlib/requirements/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from ...core import Requirement, ValidationResult, default_output_to_bool
55
from .md import as_markdown_list, is_markdown_list, is_markdown_table
66
from .python_reqs import PythonExecutionReq
7-
from .rag import CitationRequirement
7+
from .rag import CitationMode, CitationRequirement
88
from .requirement import (
99
ALoraRequirement,
1010
LLMaJRequirement,
@@ -18,6 +18,7 @@
1818

1919
__all__ = [
2020
"ALoraRequirement",
21+
"CitationMode",
2122
"CitationRequirement",
2223
"LLMaJRequirement",
2324
"PythonExecutionReq",

mellea/stdlib/requirements/rag.py

Lines changed: 70 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,27 @@
11
"""Requirements for RAG (Retrieval-Augmented Generation) workflows."""
22

33
from collections.abc import Iterable
4+
from enum import Enum
45

56
from ...backends.adapters import AdapterMixin
67
from ...core import Backend, Context, Requirement, ValidationResult
78
from ..components import Document, Message
89

910

11+
class CitationMode(str, Enum):
    """Mode for calculating citation coverage.

    Members also subclass ``str``, so a member compares equal to its plain
    string value (``CitationMode.CLAIMS == "claims"``). Code that selects a
    branch with ``mode == CitationMode.CLAIMS`` therefore treats the string
    ``"claims"`` the same as the member instead of silently falling through
    to the other branch.

    Attributes:
        CLAIMS: Count the fraction of factual claims that have citations.
            Each citation record from find_citations represents one claim.
        CHARACTERS: Calculate the ratio of cited characters to total characters.
            Sums character ranges covered by citations.
    """

    # NOTE: prefer ``.value`` when rendering a member into text; the result of
    # ``format()`` / f-strings on str-mixin enums differs between Python versions.
    CLAIMS = "claims"
    CHARACTERS = "characters"
23+
24+
1025
class CitationRequirement(Requirement):
1126
"""Requirement that validates RAG responses have adequate citation coverage.
1227
@@ -30,14 +45,19 @@ class CitationRequirement(Requirement):
3045
3146
Args:
3247
min_citation_coverage: Minimum ratio of cited content (0.0-1.0).
33-
The ratio of characters with citations to total response length
34-
must meet or exceed this threshold. Default is 0.8 (80% coverage).
48+
Interpretation depends on mode:
49+
- CLAIMS mode: fraction of factual claims with citations
50+
- CHARACTERS mode: ratio of cited characters to total characters
51+
Default is 0.8 (80% coverage).
3552
documents: Optional documents to validate against. Can be Document
3653
objects or strings (will be converted to Documents). If provided,
3754
these documents will be used instead of documents attached to
3855
messages in the context. Default is None (use context documents).
56+
mode: Citation coverage calculation mode. Default is CitationMode.CLAIMS
57+
(count fraction of claims with citations). Use CitationMode.CHARACTERS
58+
to calculate character-based coverage ratio instead.
3959
description: Custom description for the requirement. If None,
40-
generates a description based on coverage threshold.
60+
generates a description based on coverage threshold and mode.
4161
4262
Example:
4363
```python
@@ -64,6 +84,7 @@ def __init__(
6484
self,
6585
min_citation_coverage: float = 0.8,
6686
documents: Iterable[Document] | Iterable[str] | None = None,
87+
mode: CitationMode = CitationMode.CLAIMS,
6788
description: str | None = None,
6889
):
6990
"""Initialize citation coverage requirement."""
@@ -73,6 +94,7 @@ def __init__(
7394
)
7495

7596
self.min_citation_coverage = min_citation_coverage
97+
self.mode = mode
7698

7799
# Convert documents to Document objects if provided
78100
if documents is not None:
@@ -87,10 +109,16 @@ def __init__(
87109

88110
# Generate description if not provided
89111
if description is None:
90-
description = (
91-
f"Response must have adequate citation coverage "
92-
f"(minimum {min_citation_coverage * 100:.0f}% of content cited)"
93-
)
112+
if mode == CitationMode.CLAIMS:
113+
description = (
114+
f"Response must have adequate citation coverage "
115+
f"(minimum {min_citation_coverage * 100:.0f}% of factual claims cited)"
116+
)
117+
else: # CitationMode.CHARACTERS
118+
description = (
119+
f"Response must have adequate citation coverage "
120+
f"(minimum {min_citation_coverage * 100:.0f}% of characters cited)"
121+
)
94122

95123
# Initialize parent without validation function - we override validate() instead
96124
super().__init__(description=description, validation_fn=None)
@@ -198,13 +226,34 @@ async def validate(
198226
False, reason=f"Citation detection intrinsic failed: {e!s}"
199227
)
200228

201-
# Calculate citation coverage
202-
203-
cited_chars = sum(
204-
citation["response_end"] - citation["response_begin"]
205-
for citation in citations
206-
)
207-
coverage_ratio = cited_chars / total_chars
229+
# Calculate citation coverage based on mode
230+
if self.mode == CitationMode.CLAIMS:
231+
# Count fraction of claims (citation records) that exist
232+
# Each citation record represents a factual claim that has a citation
233+
# We need to estimate total claims in the response
234+
# For now, use a simple heuristic: split by sentence-ending punctuation
235+
import re
236+
237+
# Split response into sentences (simple heuristic)
238+
sentences = re.split(r"[.!?]+", response)
239+
# Filter out empty strings and whitespace-only strings
240+
sentences = [s.strip() for s in sentences if s.strip()]
241+
total_claims = len(sentences)
242+
243+
if total_claims == 0:
244+
# Edge case: no sentences detected
245+
coverage_ratio = 1.0 if len(citations) == 0 else 0.0
246+
else:
247+
# Number of claims with citations = number of citation records
248+
cited_claims = len(citations)
249+
coverage_ratio = cited_claims / total_claims
250+
else: # CitationMode.CHARACTERS
251+
# Calculate character-based coverage
252+
cited_chars = sum(
253+
citation["response_end"] - citation["response_begin"]
254+
for citation in citations
255+
)
256+
coverage_ratio = cited_chars / total_chars
208257

209258
# Check against min_citation_coverage
210259
passed = coverage_ratio >= self.min_citation_coverage
@@ -231,15 +280,20 @@ def _build_reason(
231280
coverage_pct = coverage_ratio * 100
232281
threshold_pct = self.min_citation_coverage * 100
233282

283+
if self.mode == CitationMode.CLAIMS:
284+
metric_name = "claims"
285+
else:
286+
metric_name = "characters"
287+
234288
if passed:
235289
reason = (
236290
f"Response has adequate citation coverage "
237-
f"({coverage_pct:.1f}% cited, threshold: {threshold_pct:.1f}%)"
291+
f"({coverage_pct:.1f}% of {metric_name} cited, threshold: {threshold_pct:.1f}%)"
238292
)
239293
else:
240294
reason = (
241295
f"Response has insufficient citation coverage "
242-
f"({coverage_pct:.1f}% cited, threshold: {threshold_pct:.1f}%)"
296+
f"({coverage_pct:.1f}% of {metric_name} cited, threshold: {threshold_pct:.1f}%)"
243297
)
244298

245299
# Add details about citations

test/stdlib/requirements/test_rag_requirements.py

Lines changed: 124 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from mellea.backends.huggingface import LocalHFBackend
77
from mellea.stdlib.components import Document, Message
88
from mellea.stdlib.context import ChatContext
9-
from mellea.stdlib.requirements.rag import CitationRequirement
9+
from mellea.stdlib.requirements.rag import CitationMode, CitationRequirement
1010

1111

1212
@pytest.mark.huggingface
@@ -269,18 +269,137 @@ async def test_citation_requirement_threshold_boundary():
269269
"mellea.stdlib.components.intrinsic.rag.find_citations",
270270
return_value=mock_citations,
271271
):
272-
# Test at exact threshold (0.8)
273-
req = CitationRequirement(min_citation_coverage=0.8)
272+
# Test at exact threshold (0.8) with CHARACTERS mode
273+
req = CitationRequirement(
274+
min_citation_coverage=0.8, mode=CitationMode.CHARACTERS
275+
)
274276
result = await req.validate(backend, ctx)
275277

276278
# At exact threshold, should pass (>= comparison)
277279
assert result.as_bool()
278280
assert result.score == 0.8
279281

280-
# Test just below threshold (0.81)
281-
req_above = CitationRequirement(min_citation_coverage=0.81)
282+
# Test just below threshold (0.81) with CHARACTERS mode
283+
req_above = CitationRequirement(
284+
min_citation_coverage=0.81, mode=CitationMode.CHARACTERS
285+
)
282286
result_above = await req_above.validate(backend, ctx)
283287

284288
# Just below threshold, should fail
285289
assert not result_above.as_bool()
286290
assert result_above.score == 0.8
291+
292+
293+
async def test_citation_requirement_claims_mode():
    """Test citation requirement with CLAIMS mode.

    The response has three sentences and the mocked citation finder reports
    citations for two of them, so the expected coverage is 2/3 (~66.7%).
    """
    from unittest.mock import Mock, patch

    backend = Mock(spec=LocalHFBackend)

    # Sentences that are both present in the documents and cited in the response
    sky = "The sky is blue."
    grass = "Grass is green."

    # Create documents
    docs = [Document(doc_id="doc1", text=sky), Document(doc_id="doc2", text=grass)]

    # Create a response with 3 sentences (the last one has no supporting document)
    response = f"{sky} {grass} Water is wet."

    # Create context
    ctx = ChatContext().add(Message("user", "Tell me some facts."))
    ctx = ctx.add(Message("assistant", response, documents=docs))

    # Mock find_citations to return 2 citations (2 out of 3 claims = 66.7%).
    # Offsets are derived from the response itself so that
    # response[response_begin:response_end] == response_text for every record.
    mock_citations = []
    for doc_id, sentence in (("doc1", sky), ("doc2", grass)):
        begin = response.index(sentence)
        mock_citations.append(
            {
                "response_begin": begin,
                "response_end": begin + len(sentence),
                "response_text": sentence,
                "citation_doc_id": doc_id,
                "citation_text": sentence,
            }
        )

    with patch(
        "mellea.stdlib.components.intrinsic.rag.find_citations",
        return_value=mock_citations,
    ):
        # Test with CLAIMS mode (default) - 2 citations out of 3 sentences = 66.7%
        req = CitationRequirement(min_citation_coverage=0.6)
        result = await req.validate(backend, ctx)

        # Should pass (66.7% >= 60%)
        assert result.as_bool()
        assert result.score is not None
        assert abs(result.score - 0.667) < 0.01  # Allow small floating point error
        assert result.reason is not None
        assert "claims" in result.reason

        # Test with higher threshold that should fail
        req_high = CitationRequirement(min_citation_coverage=0.7)
        result_high = await req_high.validate(backend, ctx)

        # Should fail (66.7% < 70%)
        assert not result_high.as_bool()
        assert result_high.score is not None
        assert abs(result_high.score - 0.667) < 0.01
354+
355+
async def test_citation_requirement_characters_vs_claims():
    """Test that CHARACTERS and CLAIMS modes produce different results.

    One short sentence is cited and one long sentence is not, so the same
    citations give 50% coverage by claims but under 10% by characters.
    """
    from unittest.mock import Mock, patch

    backend = Mock(spec=LocalHFBackend)

    # Create documents
    docs = [Document(doc_id="doc1", text="Short fact.")]

    # Create a response: 1 short sentence with citation, 1 long sentence without
    cited = "Short."
    response = f"{cited} This is a much longer sentence without any citation support."

    # Create context
    ctx = ChatContext().add(Message("user", "Tell me something."))
    ctx = ctx.add(Message("assistant", response, documents=docs))

    # Mock find_citations to return 1 citation for the short sentence.
    # The span is derived from the response so that
    # response[response_begin:response_end] == response_text.
    begin = response.index(cited)
    mock_citations = [
        {
            "response_begin": begin,
            "response_end": begin + len(cited),  # "Short." = 6 characters
            "response_text": cited,
            "citation_doc_id": "doc1",
            "citation_text": "Short fact.",
        }
    ]

    with patch(
        "mellea.stdlib.components.intrinsic.rag.find_citations",
        return_value=mock_citations,
    ):
        # CLAIMS mode: 1 citation out of 2 sentences = 50%
        req_claims = CitationRequirement(
            min_citation_coverage=0.5, mode=CitationMode.CLAIMS
        )
        result_claims = await req_claims.validate(backend, ctx)

        # CHARACTERS mode: 6 characters out of 67 total = ~9%
        req_chars = CitationRequirement(
            min_citation_coverage=0.5, mode=CitationMode.CHARACTERS
        )
        result_chars = await req_chars.validate(backend, ctx)

        # CLAIMS mode should pass (50% >= 50%)
        assert result_claims.as_bool()
        assert result_claims.score == 0.5

        # CHARACTERS mode should fail (~9% < 50%)
        assert not result_chars.as_bool()
        assert result_chars.score is not None
        assert result_chars.score < 0.1

0 commit comments

Comments
 (0)