|
| 1 | +# pytest: huggingface, llm, requires_heavy_ram |
| 2 | +"""Example demonstrating CitationRequirement for RAG workflows. |
| 3 | +
|
| 4 | +This example shows how to use CitationRequirement to validate that |
| 5 | +assistant responses properly cite their sources in RAG workflows. |
| 6 | +
|
| 7 | +Note: This example requires HuggingFace backend and access to the |
| 8 | +meta-llama/Llama-3.2-1B-Instruct model. |
| 9 | +""" |
| 10 | + |
| 11 | +import asyncio |
| 12 | + |
| 13 | +from mellea.backends.huggingface import LocalHFBackend |
| 14 | +from mellea.stdlib.components import Document, Message |
| 15 | +from mellea.stdlib.context import ChatContext |
| 16 | +from mellea.stdlib.requirements.rag import CitationRequirement, citation_check |
| 17 | + |
| 18 | + |
| 19 | +async def main(): |
| 20 | + """Demonstrate CitationRequirement usage.""" |
| 21 | + print("=" * 70) |
| 22 | + print("CitationRequirement Example") |
| 23 | + print("=" * 70) |
| 24 | + |
| 25 | + # Initialize HuggingFace backend |
| 26 | + print("\nInitializing HuggingFace backend...") |
| 27 | + backend = LocalHFBackend(model_id="meta-llama/Llama-3.2-1B-Instruct") |
| 28 | + |
| 29 | + # Create documents |
| 30 | + docs = [ |
| 31 | + Document( |
| 32 | + doc_id="doc1", |
| 33 | + title="Sky Facts", |
| 34 | + text="The sky appears blue during the day due to Rayleigh scattering.", |
| 35 | + ), |
| 36 | + Document( |
| 37 | + doc_id="doc2", |
| 38 | + title="Grass Facts", |
| 39 | + text="Grass is typically green because of chlorophyll in the leaves.", |
| 40 | + ), |
| 41 | + ] |
| 42 | + |
| 43 | + # Create a response that should have citations |
| 44 | + response = ( |
| 45 | + "The sky appears blue during the day. " |
| 46 | + "Grass is green because it contains chlorophyll." |
| 47 | + ) |
| 48 | + |
| 49 | + # Create context |
| 50 | + print("\nCreating context with user question and assistant response...") |
| 51 | + ctx = ChatContext().add(Message("user", "What colors are the sky and grass?")) |
| 52 | + ctx = ctx.add(Message("assistant", response, documents=docs)) |
| 53 | + |
| 54 | + # Example 1: Using CitationRequirement directly |
| 55 | + print("\n--- Example 1: CitationRequirement with 70% coverage ---") |
| 56 | + req = CitationRequirement(min_citation_coverage=0.7, documents=docs) |
| 57 | + result = await req.validate(backend, ctx) |
| 58 | + |
| 59 | + print(f"Validation passed: {result.as_bool()}") |
| 60 | + print(f"Citation coverage score: {result.score:.2%}") |
| 61 | + if result.reason: |
| 62 | + reason_preview = ( |
| 63 | + result.reason[:200] + "..." if len(result.reason) > 200 else result.reason |
| 64 | + ) |
| 65 | + print(f"Reason: {reason_preview}") |
| 66 | + |
| 67 | + # Example 2: Using citation_check factory |
| 68 | + print("\n--- Example 2: Using citation_check factory ---") |
| 69 | + req2 = citation_check(docs, min_citation_coverage=0.8) |
| 70 | + result2 = await req2.validate(backend, ctx) |
| 71 | + |
| 72 | + print(f"Validation passed: {result2.as_bool()}") |
| 73 | + print(f"Citation coverage score: {result2.score:.2%}") |
| 74 | + |
| 75 | + # Example 3: Documents attached to message |
| 76 | + print("\n--- Example 3: Documents in message (not constructor) ---") |
| 77 | + ctx2 = ChatContext().add(Message("user", "Tell me about Mars.")) |
| 78 | + ctx2 = ctx2.add( |
| 79 | + Message( |
| 80 | + "assistant", |
| 81 | + "Mars is the fourth planet from the Sun.", |
| 82 | + documents=[ |
| 83 | + Document(doc_id="doc1", text="Mars is the fourth planet from the Sun.") |
| 84 | + ], |
| 85 | + ) |
| 86 | + ) |
| 87 | + |
| 88 | + req3 = CitationRequirement(min_citation_coverage=0.7) # No documents in constructor |
| 89 | + result3 = await req3.validate(backend, ctx2) |
| 90 | + |
| 91 | + print(f"Validation passed: {result3.as_bool()}") |
| 92 | + print(f"Citation coverage score: {result3.score:.2%}") |
| 93 | + |
| 94 | + print("\n" + "=" * 70) |
| 95 | + print("Example completed successfully!") |
| 96 | + print("=" * 70) |
| 97 | + |
| 98 | + |
| 99 | +if __name__ == "__main__": |
| 100 | + asyncio.run(main()) |
| 101 | + |
| 102 | +# Made with Bob |
0 commit comments