Skip to content

Commit 1e1f039

Browse files
committed
feat: citation requirement
Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>
1 parent d971776 commit 1e1f039

4 files changed

Lines changed: 744 additions & 0 deletions

File tree

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# pytest: huggingface, llm, requires_heavy_ram
2+
"""Example demonstrating CitationRequirement for RAG workflows.
3+
4+
This example shows how to use CitationRequirement to validate that
5+
assistant responses properly cite their sources in RAG workflows.
6+
7+
Note: This example requires HuggingFace backend and access to the
8+
meta-llama/Llama-3.2-1B-Instruct model.
9+
"""
10+
11+
import asyncio
12+
13+
from mellea.backends.huggingface import LocalHFBackend
14+
from mellea.stdlib.components import Document, Message
15+
from mellea.stdlib.context import ChatContext
16+
from mellea.stdlib.requirements.rag import CitationRequirement, citation_check
17+
18+
19+
async def main():
20+
"""Demonstrate CitationRequirement usage."""
21+
print("=" * 70)
22+
print("CitationRequirement Example")
23+
print("=" * 70)
24+
25+
# Initialize HuggingFace backend
26+
print("\nInitializing HuggingFace backend...")
27+
backend = LocalHFBackend(model_id="meta-llama/Llama-3.2-1B-Instruct")
28+
29+
# Create documents
30+
docs = [
31+
Document(
32+
doc_id="doc1",
33+
title="Sky Facts",
34+
text="The sky appears blue during the day due to Rayleigh scattering.",
35+
),
36+
Document(
37+
doc_id="doc2",
38+
title="Grass Facts",
39+
text="Grass is typically green because of chlorophyll in the leaves.",
40+
),
41+
]
42+
43+
# Create a response that should have citations
44+
response = (
45+
"The sky appears blue during the day. "
46+
"Grass is green because it contains chlorophyll."
47+
)
48+
49+
# Create context
50+
print("\nCreating context with user question and assistant response...")
51+
ctx = ChatContext().add(Message("user", "What colors are the sky and grass?"))
52+
ctx = ctx.add(Message("assistant", response, documents=docs))
53+
54+
# Example 1: Using CitationRequirement directly
55+
print("\n--- Example 1: CitationRequirement with 70% coverage ---")
56+
req = CitationRequirement(min_citation_coverage=0.7, documents=docs)
57+
result = await req.validate(backend, ctx)
58+
59+
print(f"Validation passed: {result.as_bool()}")
60+
print(f"Citation coverage score: {result.score:.2%}")
61+
if result.reason:
62+
reason_preview = (
63+
result.reason[:200] + "..." if len(result.reason) > 200 else result.reason
64+
)
65+
print(f"Reason: {reason_preview}")
66+
67+
# Example 2: Using citation_check factory
68+
print("\n--- Example 2: Using citation_check factory ---")
69+
req2 = citation_check(docs, min_citation_coverage=0.8)
70+
result2 = await req2.validate(backend, ctx)
71+
72+
print(f"Validation passed: {result2.as_bool()}")
73+
print(f"Citation coverage score: {result2.score:.2%}")
74+
75+
# Example 3: Documents attached to message
76+
print("\n--- Example 3: Documents in message (not constructor) ---")
77+
ctx2 = ChatContext().add(Message("user", "Tell me about Mars."))
78+
ctx2 = ctx2.add(
79+
Message(
80+
"assistant",
81+
"Mars is the fourth planet from the Sun.",
82+
documents=[
83+
Document(doc_id="doc1", text="Mars is the fourth planet from the Sun.")
84+
],
85+
)
86+
)
87+
88+
req3 = CitationRequirement(min_citation_coverage=0.7) # No documents in constructor
89+
result3 = await req3.validate(backend, ctx2)
90+
91+
print(f"Validation passed: {result3.as_bool()}")
92+
print(f"Citation coverage score: {result3.score:.2%}")
93+
94+
print("\n" + "=" * 70)
95+
print("Example completed successfully!")
96+
print("=" * 70)
97+
98+
99+
if __name__ == "__main__":
100+
asyncio.run(main())
101+
102+
# Made with Bob

mellea/stdlib/requirements/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ...core import Requirement, ValidationResult, default_output_to_bool
55
from .md import as_markdown_list, is_markdown_list, is_markdown_table
66
from .python_reqs import PythonExecutionReq
7+
from .rag import CitationRequirement, citation_check
78
from .requirement import (
89
ALoraRequirement,
910
LLMaJRequirement,
@@ -17,12 +18,14 @@
1718

1819
__all__ = [
1920
"ALoraRequirement",
21+
"CitationRequirement",
2022
"LLMaJRequirement",
2123
"PythonExecutionReq",
2224
"Requirement",
2325
"ValidationResult",
2426
"as_markdown_list",
2527
"check",
28+
"citation_check",
2629
"default_output_to_bool",
2730
"is_markdown_list",
2831
"is_markdown_table",

0 commit comments

Comments
 (0)