Skip to content

Commit ea9b16b

Browse files
committed
test: add 199 unit tests for granite formatters (#812)
Unit tests for Granite 3.2 and 3.3 input/output processors, shared utilities, IntrinsicsResultProcessor canned data regression, and OpenAI SDK compatibility. No GPU, network, or model downloads required.

- test_granite3_shared.py: find_substring_in_text, create_dict, parse_hallucinations_text, hallucination/citation span helpers
- test_granite32_output.py: citation parsing, model output splitting, validation, transform pipeline
- test_granite33_output.py: JSON-delimited citations/hallucinations, think/response extraction, controls cleanup
- test_granite32_input.py: system message matrix, sanitize, transform
- test_granite33_input.py: available_tools role, per-document roles
- test_intrinsics_canned_output.py: canned model outputs through IntrinsicsResultProcessor, Pydantic schema validation of fixtures
- test_openai_compat.py: ChatCompletion round-trip through OpenAI SDK

Closes #812
1 parent 417b7c8 commit ea9b16b

7 files changed

Lines changed: 1738 additions & 0 deletions
Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
"""Unit tests for the Granite 3.2 input processor."""
4+
5+
import pytest
6+
7+
from mellea.formatters.granite.base.types import (
8+
AssistantMessage,
9+
Document,
10+
SystemMessage,
11+
ToolDefinition,
12+
UserMessage,
13+
VLLMExtraBody,
14+
)
15+
from mellea.formatters.granite.granite3.constants import (
16+
NO_TOOLS_AND_DOCS_SYSTEM_MESSAGE_PART,
17+
NO_TOOLS_NO_DOCS_NO_THINKING_SYSTEM_MESSAGE_PART,
18+
)
19+
from mellea.formatters.granite.granite3.granite32.constants import (
20+
DOCS_AND_CITATIONS_SYSTEM_MESSAGE_PART,
21+
DOCS_AND_HALLUCINATIONS_SYSTEM_MESSAGE_PART,
22+
NO_TOOLS_AND_NO_DOCS_AND_THINKING_SYSTEM_MESSAGE_PART,
23+
TOOLS_AND_DOCS_SYSTEM_MESSAGE_PART,
24+
TOOLS_AND_NO_DOCS_SYSTEM_MESSAGE_PART,
25+
)
26+
from mellea.formatters.granite.granite3.granite32.input import Granite32InputProcessor
27+
from mellea.formatters.granite.granite3.granite32.types import Granite32ChatCompletion
28+
from mellea.formatters.granite.granite3.types import Granite3Controls, Granite3Kwargs
29+
30+
31+
def _make_completion(**kwargs) -> Granite32ChatCompletion:
    """Build a ``Granite32ChatCompletion`` for use in a test.

    All keyword arguments are forwarded to the constructor. If the caller does
    not pass ``messages``, a single user message with content ``"Hello"`` is
    supplied so tests only need to spell out the fields they care about.
    """
    if "messages" in kwargs:
        return Granite32ChatCompletion(**kwargs)
    return Granite32ChatCompletion(messages=[UserMessage(content="Hello")], **kwargs)
36+
37+
38+
# ---------------------------------------------------------------------------
39+
# _build_default_system_message
40+
# ---------------------------------------------------------------------------
41+
42+
43+
class TestBuildDefaultSystemMessage:
    """Checks the text produced by ``_build_default_system_message``.

    Each test builds a chat completion with one combination of tools,
    documents, thinking and controls, then asserts either on the returned
    string or on the ``ValueError`` raised for that combination.
    """

    def setup_method(self):
        # pytest calls this before every test method, so each test gets its
        # own processor instance.
        self.proc = Granite32InputProcessor()

    def test_no_tools_no_docs_no_thinking(self):
        """A default completion yields the system header, plain part and EOT."""
        system_msg = self.proc._build_default_system_message(_make_completion())
        assert "<|start_of_role|>system<|end_of_role|>" in system_msg
        assert NO_TOOLS_NO_DOCS_NO_THINKING_SYSTEM_MESSAGE_PART in system_msg
        assert system_msg.endswith("<|end_of_text|>\n")

    def test_tools_only(self):
        """Tools without documents select the tools-and-no-docs part."""
        tools = [ToolDefinition(name="my_tool")]
        completion = _make_completion(tools=tools)
        system_msg = self.proc._build_default_system_message(completion)
        assert TOOLS_AND_NO_DOCS_SYSTEM_MESSAGE_PART in system_msg

    def test_docs_only(self):
        """Documents without tools select the no-tools-and-docs part."""
        extra_body = VLLMExtraBody(documents=[Document(text="Some document.")])
        completion = _make_completion(extra_body=extra_body)
        system_msg = self.proc._build_default_system_message(completion)
        assert NO_TOOLS_AND_DOCS_SYSTEM_MESSAGE_PART in system_msg

    def test_tools_and_docs(self):
        """Tools together with documents select the tools-and-docs part."""
        completion = _make_completion(
            tools=[ToolDefinition(name="my_tool")],
            extra_body=VLLMExtraBody(documents=[Document(text="Some document.")]),
        )
        system_msg = self.proc._build_default_system_message(completion)
        assert TOOLS_AND_DOCS_SYSTEM_MESSAGE_PART in system_msg

    def test_thinking_only(self):
        """Thinking with neither tools nor documents selects the thinking part."""
        template_kwargs = Granite3Kwargs(thinking=True)
        completion = _make_completion(
            extra_body=VLLMExtraBody(chat_template_kwargs=template_kwargs)
        )
        system_msg = self.proc._build_default_system_message(completion)
        assert NO_TOOLS_AND_NO_DOCS_AND_THINKING_SYSTEM_MESSAGE_PART in system_msg

    def test_docs_and_citations(self):
        """Documents plus the citations control add the citations part."""
        controls = Granite3Controls(citations=True)
        extra_body = VLLMExtraBody(
            documents=[Document(text="doc")],
            chat_template_kwargs=Granite3Kwargs(controls=controls),
        )
        completion = _make_completion(extra_body=extra_body)
        system_msg = self.proc._build_default_system_message(completion)
        assert DOCS_AND_CITATIONS_SYSTEM_MESSAGE_PART in system_msg

    def test_docs_citations_and_hallucinations(self):
        """Both controls with documents add both corresponding parts."""
        controls = Granite3Controls(citations=True, hallucinations=True)
        extra_body = VLLMExtraBody(
            documents=[Document(text="doc")],
            chat_template_kwargs=Granite3Kwargs(controls=controls),
        )
        completion = _make_completion(extra_body=extra_body)
        system_msg = self.proc._build_default_system_message(completion)
        expected_parts = (
            DOCS_AND_CITATIONS_SYSTEM_MESSAGE_PART,
            DOCS_AND_HALLUCINATIONS_SYSTEM_MESSAGE_PART,
        )
        for part in expected_parts:
            assert part in system_msg

    def test_thinking_with_docs_raises(self):
        """Thinking combined with documents is rejected with a ValueError."""
        extra_body = VLLMExtraBody(
            documents=[Document(text="doc")],
            chat_template_kwargs=Granite3Kwargs(thinking=True),
        )
        # Build the completion outside the raises-block so only the call under
        # test can satisfy the expectation.
        completion = _make_completion(extra_body=extra_body)
        with pytest.raises(ValueError, match="thinking"):
            self.proc._build_default_system_message(completion)

    def test_thinking_with_tools_raises(self):
        """Thinking combined with tools is rejected with a ValueError."""
        template_kwargs = Granite3Kwargs(thinking=True)
        completion = _make_completion(
            tools=[ToolDefinition(name="tool")],
            extra_body=VLLMExtraBody(chat_template_kwargs=template_kwargs),
        )
        with pytest.raises(ValueError, match="thinking"):
            self.proc._build_default_system_message(completion)

    def test_hallucinations_without_docs_raises(self):
        """The hallucinations control without documents raises a ValueError."""
        controls = Granite3Controls(hallucinations=True)
        template_kwargs = Granite3Kwargs(controls=controls)
        completion = _make_completion(
            extra_body=VLLMExtraBody(chat_template_kwargs=template_kwargs)
        )
        with pytest.raises(ValueError, match="hallucinations"):
            self.proc._build_default_system_message(completion)

    def test_citations_without_docs_no_error(self):
        """Citations requested without docs are silently skipped (Jinja template)."""
        controls = Granite3Controls(citations=True)
        template_kwargs = Granite3Kwargs(controls=controls)
        completion = _make_completion(
            extra_body=VLLMExtraBody(chat_template_kwargs=template_kwargs)
        )
        system_msg = self.proc._build_default_system_message(completion)
        assert DOCS_AND_CITATIONS_SYSTEM_MESSAGE_PART not in system_msg
148+
149+
150+
# ---------------------------------------------------------------------------
151+
# _remove_special_tokens
152+
# ---------------------------------------------------------------------------
153+
154+
155+
class TestRemoveSpecialTokens32:
    """Exercises ``Granite32InputProcessor._remove_special_tokens`` on strings."""

    def test_removes_role_markers(self):
        """A complete role-delimited message is expected to reduce to ''."""
        raw = "<|start_of_role|>system<|end_of_role|>content<|end_of_text|>"
        assert Granite32InputProcessor._remove_special_tokens(raw) == ""

    def test_removes_tool_call_marker(self):
        """A tool-call marker and its JSON payload are expected to reduce to ''."""
        raw = '<|tool_call|>{"name":"func"}'
        assert Granite32InputProcessor._remove_special_tokens(raw) == ""

    def test_removes_stray_special_tokens(self):
        """Isolated special tokens are dropped while ordinary text is kept."""
        raw = "Hello <|end_of_text|> world <fim_prefix> end"
        cleaned = Granite32InputProcessor._remove_special_tokens(raw)
        for token in ("<|end_of_text|>", "<fim_prefix>"):
            assert token not in cleaned
        assert "Hello" in cleaned

    def test_clean_text_unchanged(self):
        """Text containing no special tokens is returned as-is."""
        plain = "Just normal text."
        cleaned = Granite32InputProcessor._remove_special_tokens(plain)
        assert cleaned == plain
176+
177+
178+
# ---------------------------------------------------------------------------
179+
# sanitize
180+
# ---------------------------------------------------------------------------
181+
182+
183+
class TestSanitize32:
    """Exercises ``Granite32InputProcessor.sanitize`` on whole chat completions."""

    def test_sanitizes_messages(self):
        """With ``parts="messages"`` the special token leaves the message text."""
        dirty_messages = [UserMessage(content="Hello <|end_of_text|> world")]
        dirty = _make_completion(messages=dirty_messages)
        cleaned = Granite32InputProcessor.sanitize(dirty, parts="messages")
        assert "<|end_of_text|>" not in cleaned.messages[0].content

    def test_sanitizes_all_by_default(self):
        """Without ``parts`` both message content and tool names are cleaned."""
        dirty = _make_completion(
            messages=[UserMessage(content="msg <|end_of_text|>")],
            tools=[ToolDefinition(name="tool <|end_of_text|>")],
        )
        cleaned = Granite32InputProcessor.sanitize(dirty)
        first_message = cleaned.messages[0]
        first_tool = cleaned.tools[0]
        assert "<|end_of_text|>" not in first_message.content
        assert "<|end_of_text|>" not in first_tool.name
199+
200+
201+
# ---------------------------------------------------------------------------
202+
# transform
203+
# ---------------------------------------------------------------------------
204+
205+
206+
class TestTransform32:
    """Checks the prompt string produced by ``Granite32InputProcessor.transform``.

    Covers the default and custom system message paths, the tools and
    documents sections, controls in the generation prompt, and multi-turn
    conversations.
    """

    def setup_method(self):
        # pytest calls this before every test method, so each test gets its
        # own processor instance.
        self.proc = Granite32InputProcessor()

    @staticmethod
    def _custom_system_conversation():
        """Return the system + user message pair shared by the ``raises`` tests."""
        return [SystemMessage(content="Custom system."), UserMessage(content="Q")]

    def test_basic_user_message(self):
        """A default completion renders system, user and generation prompt."""
        prompt = self.proc.transform(_make_completion())
        assert "<|start_of_role|>system<|end_of_role|>" in prompt
        assert "<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>" in prompt
        assert prompt.endswith("<|start_of_role|>assistant<|end_of_role|>")

    def test_custom_system_message(self):
        """A caller-supplied system message appears in the rendered prompt."""
        conversation = [
            SystemMessage(content="You are a pirate."),
            UserMessage(content="Ahoy!"),
        ]
        prompt = self.proc.transform(_make_completion(messages=conversation))
        for expected in ("You are a pirate.", "Ahoy!"):
            assert expected in prompt

    def test_custom_system_with_thinking_raises(self):
        """A custom system message combined with thinking raises ValueError."""
        template_kwargs = Granite3Kwargs(thinking=True)
        completion = _make_completion(
            messages=self._custom_system_conversation(),
            extra_body=VLLMExtraBody(chat_template_kwargs=template_kwargs),
        )
        with pytest.raises(ValueError, match="thinking"):
            self.proc.transform(completion)

    def test_custom_system_with_docs_raises(self):
        """A custom system message combined with documents raises ValueError."""
        completion = _make_completion(
            messages=self._custom_system_conversation(),
            extra_body=VLLMExtraBody(documents=[Document(text="doc")]),
        )
        with pytest.raises(ValueError, match="documents"):
            self.proc.transform(completion)

    def test_custom_system_with_citations_raises(self):
        """A custom system message combined with citations raises ValueError."""
        controls = Granite3Controls(citations=True)
        template_kwargs = Granite3Kwargs(controls=controls)
        completion = _make_completion(
            messages=self._custom_system_conversation(),
            extra_body=VLLMExtraBody(chat_template_kwargs=template_kwargs),
        )
        with pytest.raises(ValueError, match="citations"):
            self.proc.transform(completion)

    def test_custom_system_with_hallucinations_raises(self):
        """A custom system message plus hallucinations raises ValueError."""
        controls = Granite3Controls(hallucinations=True)
        template_kwargs = Granite3Kwargs(controls=controls)
        completion = _make_completion(
            messages=self._custom_system_conversation(),
            extra_body=VLLMExtraBody(chat_template_kwargs=template_kwargs),
        )
        with pytest.raises(ValueError, match="hallucinations"):
            self.proc.transform(completion)

    def test_tools_section(self):
        """Supplying a tool renders a tools role containing the tool's name."""
        weather_tool = ToolDefinition(
            name="get_weather", description="Get weather info"
        )
        prompt = self.proc.transform(_make_completion(tools=[weather_tool]))
        assert "<|start_of_role|>tools<|end_of_role|>" in prompt
        assert "get_weather" in prompt

    def test_documents_section(self):
        """Documents render under a documents role, numbered from zero."""
        docs = [Document(text="First doc."), Document(text="Second doc.")]
        completion = _make_completion(extra_body=VLLMExtraBody(documents=docs))
        prompt = self.proc.transform(completion)
        assert "<|start_of_role|>documents<|end_of_role|>" in prompt
        assert "Document 0\nFirst doc." in prompt
        assert "Document 1\nSecond doc." in prompt

    def test_controls_in_assistant_role(self):
        """Active controls are serialized as JSON inside the assistant role tag."""
        controls = Granite3Controls(citations=True)
        extra_body = VLLMExtraBody(
            documents=[Document(text="doc")],
            chat_template_kwargs=Granite3Kwargs(controls=controls),
        )
        prompt = self.proc.transform(_make_completion(extra_body=extra_body))
        assert '<|start_of_role|>assistant {"citations": true}<|end_of_role|>' in prompt

    def test_no_generation_prompt(self):
        """With ``add_generation_prompt=False`` the prompt has no open role tag."""
        prompt = self.proc.transform(_make_completion(), add_generation_prompt=False)
        assert not prompt.endswith("<|end_of_role|>")

    def test_multi_turn_conversation(self):
        """Every turn of a user/assistant/user exchange is rendered."""
        turns = [
            UserMessage(content="Hi"),
            AssistantMessage(content="Hello!"),
            UserMessage(content="How are you?"),
        ]
        prompt = self.proc.transform(_make_completion(messages=turns))
        assert prompt.count("<|start_of_role|>user<|end_of_role|>") == 2
        assert "<|start_of_role|>assistant<|end_of_role|>Hello!" in prompt

0 commit comments

Comments
 (0)