Skip to content

Commit 3f17c60

Browse files
committed
review comments
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
1 parent cc79c1e commit 3f17c60

4 files changed

Lines changed: 60 additions & 137 deletions

File tree

docs/examples/citation_requirement_example.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from mellea.backends.huggingface import LocalHFBackend
1414
from mellea.stdlib.components import Document, Message
1515
from mellea.stdlib.context import ChatContext
16-
from mellea.stdlib.requirements.rag import CitationRequirement, citation_check
16+
from mellea.stdlib.requirements.rag import CitationRequirement
1717

1818

1919
async def main():
@@ -51,8 +51,8 @@ async def main():
5151
ctx = ChatContext().add(Message("user", "What colors are the sky and grass?"))
5252
ctx = ctx.add(Message("assistant", response, documents=docs))
5353

54-
# Example 1: Using CitationRequirement directly
55-
print("\n--- Example 1: CitationRequirement with 70% coverage ---")
54+
# Example 1: Documents in constructor
55+
print("\n--- Example 1: CitationRequirement with documents in constructor ---")
5656
req = CitationRequirement(min_citation_coverage=0.7, documents=docs)
5757
result = await req.validate(backend, ctx)
5858

@@ -64,9 +64,9 @@ async def main():
6464
)
6565
print(f"Reason: {reason_preview}")
6666

67-
# Example 2: Using citation_check factory
68-
print("\n--- Example 2: Using citation_check factory ---")
69-
req2 = citation_check(docs, min_citation_coverage=0.8)
67+
# Example 2: Higher coverage threshold
68+
print("\n--- Example 2: Higher coverage threshold (80%) ---")
69+
req2 = CitationRequirement(min_citation_coverage=0.8, documents=docs)
7070
result2 = await req2.validate(backend, ctx)
7171

7272
print(f"Validation passed: {result2.as_bool()}")

mellea/stdlib/requirements/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from ...core import Requirement, ValidationResult, default_output_to_bool
55
from .md import as_markdown_list, is_markdown_list, is_markdown_table
66
from .python_reqs import PythonExecutionReq
7-
from .rag import CitationRequirement, citation_check
7+
from .rag import CitationRequirement
88
from .requirement import (
99
ALoraRequirement,
1010
LLMaJRequirement,
@@ -25,7 +25,6 @@
2525
"ValidationResult",
2626
"as_markdown_list",
2727
"check",
28-
"citation_check",
2928
"default_output_to_bool",
3029
"is_markdown_list",
3130
"is_markdown_table",

mellea/stdlib/requirements/rag.py

Lines changed: 5 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -159,34 +159,19 @@ async def validate(
159159
reason=f"Backend {backend.__class__.__name__} does not support adapters required for citation detection",
160160
)
161161

162-
# More specific check for HuggingFace backend
163-
try:
164-
from ...backends.huggingface import LocalHFBackend
165-
166-
if not isinstance(backend, LocalHFBackend):
167-
return ValidationResult(
168-
False,
169-
reason=f"Citation detection requires LocalHFBackend (HuggingFace), "
170-
f"but got {backend.__class__.__name__}. The find_citations intrinsic "
171-
f"only works with HuggingFace models.",
172-
)
173-
except ImportError:
174-
return ValidationResult(
175-
False,
176-
reason="HuggingFace backend not available. Please install mellea[hf] to use citation detection.",
177-
)
178-
179162
# Create context before the response by getting all but the last message
180163
all_messages = ctx.as_list()
181164
if len(all_messages) > 1:
182165
# Rebuild context without last message
166+
# Import here to avoid circular dependency
183167
from ..context import ChatContext
184168

185169
context_before_response = ChatContext()
186170
for msg in all_messages[:-1]:
187171
context_before_response = context_before_response.add(msg)
188172
else:
189173
# If only one message, use empty context
174+
# Import here to avoid circular dependency
190175
from ..context import ChatContext
191176

192177
context_before_response = ChatContext()
@@ -195,7 +180,9 @@ async def validate(
195180
total_chars = len(response)
196181
if total_chars == 0:
197182
return ValidationResult(
198-
True, reason="Empty response has 100% citation coverage", score=1.0
183+
True,
184+
reason="Empty response is considered to have adequate citation coverage",
185+
score=1.0,
199186
)
200187

201188
# Call find_citations intrinsic
@@ -283,57 +270,3 @@ def _build_reason(
283270
)
284271

285272
return reason
286-
287-
288-
def citation_check(
289-
documents: Iterable[Document] | Iterable[str],
290-
min_citation_coverage: float = 0.8,
291-
description: str | None = None,
292-
) -> CitationRequirement:
293-
"""Create a citation coverage requirement with pre-attached documents.
294-
295-
This is a convenience factory function that creates a CitationRequirement
296-
with documents already attached. This is useful when you have a fixed set of
297-
documents to validate against and want a cleaner API.
298-
299-
**Important**: This requirement requires a HuggingFace backend (LocalHFBackend).
300-
301-
Args:
302-
documents: Documents to check for citations. Can be Document objects
303-
or strings (will be converted to Documents).
304-
min_citation_coverage: Minimum ratio of cited content (0.0-1.0),
305-
defaults to 0.8 (80% coverage).
306-
description: Custom description for the requirement. If None,
307-
generates a description based on coverage threshold.
308-
309-
Returns:
310-
A CitationRequirement with documents attached
311-
312-
Example:
313-
```python
314-
from mellea.backends.huggingface import LocalHFBackend
315-
from mellea.stdlib.requirements.rag import citation_check
316-
from mellea.stdlib.components import Document
317-
318-
backend = LocalHFBackend(model_id="meta-llama/Llama-3.2-1B-Instruct")
319-
docs = [
320-
Document(doc_id="1", text="The sky is blue."),
321-
Document(doc_id="2", text="Grass is green.")
322-
]
323-
req = citation_check(docs, min_citation_coverage=0.8)
324-
325-
# Use with instruct() - no need to attach documents to messages
326-
result = m.instruct(
327-
"Answer: {{query}}",
328-
grounding_context={"query": "What color is the sky?"},
329-
requirements=[req],
330-
backend=backend,
331-
strategy=RejectionSamplingStrategy()
332-
)
333-
```
334-
"""
335-
return CitationRequirement(
336-
min_citation_coverage=min_citation_coverage,
337-
documents=documents,
338-
description=description,
339-
)

test/stdlib/requirements/test_rag_requirements.py

Lines changed: 48 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from mellea.backends.huggingface import LocalHFBackend
77
from mellea.stdlib.components import Document, Message
88
from mellea.stdlib.context import ChatContext
9-
from mellea.stdlib.requirements.rag import CitationRequirement, citation_check
9+
from mellea.stdlib.requirements.rag import CitationRequirement
1010

1111

1212
@pytest.mark.huggingface
@@ -72,34 +72,6 @@ async def test_citation_requirement_with_constructor_documents():
7272
assert result.reason is not None
7373

7474

75-
@pytest.mark.huggingface
76-
@pytest.mark.llm
77-
@pytest.mark.requires_heavy_ram
78-
async def test_citation_check_factory():
79-
"""Test citation_check factory function."""
80-
backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro")
81-
82-
# Create documents
83-
docs = [Document(doc_id="doc1", text="The sky is blue during the day.")]
84-
85-
# Create a response
86-
response = "The sky is blue."
87-
88-
# Create context
89-
ctx = ChatContext().add(Message("user", "What color is the sky?"))
90-
ctx = ctx.add(Message("assistant", response))
91-
92-
# Use factory function
93-
req = citation_check(docs, min_citation_coverage=0.5)
94-
95-
# Validate
96-
result = await req.validate(backend, ctx)
97-
98-
# Should work the same as CitationRequirement
99-
assert isinstance(result.score, float)
100-
assert result.reason is not None
101-
102-
10375
async def test_citation_requirement_empty_context():
10476
"""Test citation requirement with empty context."""
10577
# Create a mock backend - we don't need a real one for this test
@@ -169,13 +141,12 @@ async def test_citation_requirement_no_documents():
169141

170142

171143
async def test_citation_requirement_wrong_backend():
172-
"""Test citation requirement with non-HuggingFace backend."""
173-
try:
174-
from mellea.backends.ollama import OllamaBackend # type: ignore
175-
except ImportError:
176-
pytest.skip("Ollama backend not available")
144+
"""Test citation requirement with non-adapter backend."""
145+
from unittest.mock import Mock
177146

178-
backend = OllamaBackend(model_id="llama3.2") # type: ignore
147+
# Create a mock backend that doesn't support adapters
148+
backend = Mock()
149+
backend.__class__.__name__ = "MockBackend"
179150

180151
# Create documents
181152
docs = [Document(doc_id="doc1", text="The sky is blue.")]
@@ -190,10 +161,10 @@ async def test_citation_requirement_wrong_backend():
190161
# Validate
191162
result = await req.validate(backend, ctx)
192163

193-
# Should fail with clear error about backend requirement
164+
# Should fail with clear error about adapter requirement
194165
assert not result.as_bool()
195166
assert result.reason is not None
196-
assert "LocalHFBackend" in result.reason or "HuggingFace" in result.reason
167+
assert "adapter" in result.reason.lower()
197168

198169

199170
def test_citation_requirement_invalid_coverage():
@@ -256,40 +227,60 @@ async def test_citation_requirement_empty_response():
256227
# Validate
257228
result = await req.validate(backend, ctx)
258229

259-
# Empty response should pass (100% coverage of nothing)
230+
# Empty response should pass (considered to have adequate coverage)
260231
assert result.as_bool()
261232
assert result.score == 1.0
233+
assert result.reason is not None
234+
assert "adequate citation coverage" in result.reason.lower()
262235

263236

264-
@pytest.mark.huggingface
265-
@pytest.mark.llm
266-
@pytest.mark.requires_heavy_ram
267237
async def test_citation_requirement_threshold_boundary():
268-
"""Test citation requirement at exact threshold boundary."""
269-
backend = LocalHFBackend(model_id="ibm-granite/granite-4.0-micro")
238+
"""Test citation requirement at exact threshold boundary.
239+
240+
This test mocks the find_citations intrinsic to return a controlled
241+
result that produces exactly the threshold coverage (80%).
242+
"""
243+
from unittest.mock import Mock, patch
244+
245+
backend = Mock(spec=LocalHFBackend)
270246

271247
# Create documents
272248
docs = [Document(doc_id="doc1", text="The sky is blue during the day.")]
273249

274-
# Create a response
275-
response = "The sky is blue."
250+
# Create a response with 10 characters
251+
response = "1234567890"
276252

277253
# Create context
278254
ctx = ChatContext().add(Message("user", "What color is the sky?"))
279255
ctx = ctx.add(Message("assistant", response, documents=docs))
280256

281-
# Create requirement with specific threshold
282-
req = CitationRequirement(min_citation_coverage=0.8)
283-
284-
# Validate
285-
result = await req.validate(backend, ctx)
257+
# Mock find_citations to return exactly 8 characters cited (80% of 10)
258+
mock_citations = [
259+
{
260+
"response_begin": 0,
261+
"response_end": 8, # 8 characters cited
262+
"response_text": "12345678",
263+
"citation_doc_id": "doc1",
264+
"citation_text": "The sky is blue",
265+
}
266+
]
286267

287-
# Check that score is calculated
288-
assert isinstance(result.score, float)
289-
assert 0.0 <= result.score <= 1.0
268+
with patch(
269+
"mellea.stdlib.components.intrinsic.rag.find_citations",
270+
return_value=mock_citations,
271+
):
272+
# Test at exact threshold (0.8)
273+
req = CitationRequirement(min_citation_coverage=0.8)
274+
result = await req.validate(backend, ctx)
290275

291-
# Result should match threshold comparison
292-
if result.score >= 0.8:
276+
# At exact threshold, should pass (>= comparison)
293277
assert result.as_bool()
294-
else:
295-
assert not result.as_bool()
278+
assert result.score == 0.8
279+
280+
# Test with the threshold raised to 0.81, just above the mocked 0.8 coverage
281+
req_above = CitationRequirement(min_citation_coverage=0.81)
282+
result_above = await req_above.validate(backend, ctx)
283+
284+
# Coverage (0.8) is just below the 0.81 threshold, so validation should fail
285+
assert not result_above.as_bool()
286+
assert result_above.score == 0.8

0 commit comments

Comments
 (0)