1+ from unittest .mock import Mock
2+
13import pytest
24
3- from mellea .stdlib .requirements .tool_reqs import _name2str
5+ from mellea .core import ModelOutputThunk , ModelToolCall
6+ from mellea .stdlib .context import ChatContext
7+ from mellea .stdlib .requirements .tool_reqs import (
8+ _name2str ,
9+ tool_arg_validator ,
10+ uses_tool ,
11+ )
12+
13+
14+ def _ctx_with_tool_calls (tool_calls : dict [str , ModelToolCall ] | None ) -> ChatContext :
15+ """Helper: build a ChatContext whose last output has the given tool_calls."""
16+ ctx = ChatContext ()
17+ return ctx .add (ModelOutputThunk (value = "" , tool_calls = tool_calls ))
18+
19+
20+ def _make_tool_call (name : str , args : dict ) -> ModelToolCall :
21+ """Helper: build a ModelToolCall with a mock func."""
22+ return ModelToolCall (name = name , func = Mock (), args = args )
23+
24+
25+ # --- _name2str ---
426
527
628def test_name2str ():
@@ -11,3 +33,155 @@ def test123():
1133
1234 assert _name2str (test123 ) == "test123"
1335 assert _name2str ("test1234" ) == "test1234"
36+
37+
38+ def test_name2str_type_error ():
39+ with pytest .raises (TypeError , match = "Expected Callable or str" ):
40+ _name2str (123 ) # type: ignore[arg-type]
41+
42+
43+ # --- uses_tool ---
44+
45+
46+ def test_uses_tool_present ():
47+ ctx = _ctx_with_tool_calls ({"get_weather" : _make_tool_call ("get_weather" , {})})
48+ req = uses_tool ("get_weather" )
49+ result = req .validation_fn (ctx )
50+ assert result .as_bool () is True
51+
52+
53+ def test_uses_tool_absent ():
54+ ctx = _ctx_with_tool_calls ({"get_weather" : _make_tool_call ("get_weather" , {})})
55+ req = uses_tool ("send_email" )
56+ result = req .validation_fn (ctx )
57+ assert result .as_bool () is False
58+
59+
60+ def test_uses_tool_no_tool_calls ():
61+ ctx = _ctx_with_tool_calls (None )
62+ req = uses_tool ("get_weather" )
63+ result = req .validation_fn (ctx )
64+ assert result .as_bool () is False
65+ assert "no tool calls" in result .reason .lower ()
66+
67+
68+ def test_uses_tool_callable_input ():
69+ def my_tool ():
70+ pass
71+
72+ ctx = _ctx_with_tool_calls ({"my_tool" : _make_tool_call ("my_tool" , {})})
73+ req = uses_tool (my_tool )
74+ result = req .validation_fn (ctx )
75+ assert result .as_bool () is True
76+
77+
78+ def test_uses_tool_check_only ():
79+ req = uses_tool ("get_weather" , check_only = True )
80+ assert req .check_only is True
81+
82+
83+ # --- tool_arg_validator ---
84+
85+
86+ def test_tool_arg_validator_valid ():
87+ ctx = _ctx_with_tool_calls (
88+ {"search" : _make_tool_call ("search" , {"query" : "hello" , "limit" : 10 })}
89+ )
90+ req = tool_arg_validator (
91+ description = "limit must be positive" ,
92+ tool_name = "search" ,
93+ arg_name = "limit" ,
94+ validation_fn = lambda v : v > 0 ,
95+ )
96+ result = req .validation_fn (ctx )
97+ assert result .as_bool () is True
98+
99+
100+ def test_tool_arg_validator_failed_validation ():
101+ ctx = _ctx_with_tool_calls (
102+ {"search" : _make_tool_call ("search" , {"query" : "hello" , "limit" : - 1 })}
103+ )
104+ req = tool_arg_validator (
105+ description = "limit must be positive" ,
106+ tool_name = "search" ,
107+ arg_name = "limit" ,
108+ validation_fn = lambda v : v > 0 ,
109+ )
110+ result = req .validation_fn (ctx )
111+ assert result .as_bool () is False
112+
113+
114+ def test_tool_arg_validator_missing_tool ():
115+ ctx = _ctx_with_tool_calls (
116+ {"search" : _make_tool_call ("search" , {"query" : "hello" })}
117+ )
118+ req = tool_arg_validator (
119+ description = "check email tool" ,
120+ tool_name = "send_email" ,
121+ arg_name = "to" ,
122+ validation_fn = lambda v : True ,
123+ )
124+ result = req .validation_fn (ctx )
125+ assert result .as_bool () is False
126+ assert "send_email" in result .reason
127+
128+
129+ def test_tool_arg_validator_missing_arg ():
130+ ctx = _ctx_with_tool_calls (
131+ {"search" : _make_tool_call ("search" , {"query" : "hello" })}
132+ )
133+ req = tool_arg_validator (
134+ description = "limit must exist" ,
135+ tool_name = "search" ,
136+ arg_name = "limit" ,
137+ validation_fn = lambda v : True ,
138+ )
139+ result = req .validation_fn (ctx )
140+ assert result .as_bool () is False
141+ assert "limit" in result .reason
142+
143+
144+ def test_tool_arg_validator_no_tool_calls ():
145+ ctx = _ctx_with_tool_calls (None )
146+ req = tool_arg_validator (
147+ description = "check tool" ,
148+ tool_name = "search" ,
149+ arg_name = "query" ,
150+ validation_fn = lambda v : True ,
151+ )
152+ result = req .validation_fn (ctx )
153+ assert result .as_bool () is False
154+
155+
156+ def test_tool_arg_validator_no_tool_name_all_pass ():
157+ ctx = _ctx_with_tool_calls (
158+ {
159+ "tool_a" : _make_tool_call ("tool_a" , {"x" : 5 }),
160+ "tool_b" : _make_tool_call ("tool_b" , {"x" : 10 }),
161+ }
162+ )
163+ req = tool_arg_validator (
164+ description = "x must be positive" ,
165+ tool_name = None ,
166+ arg_name = "x" ,
167+ validation_fn = lambda v : v > 0 ,
168+ )
169+ result = req .validation_fn (ctx )
170+ assert result .as_bool () is True
171+
172+
173+ def test_tool_arg_validator_no_tool_name_one_fails ():
174+ ctx = _ctx_with_tool_calls (
175+ {
176+ "tool_a" : _make_tool_call ("tool_a" , {"x" : 5 }),
177+ "tool_b" : _make_tool_call ("tool_b" , {"x" : - 1 }),
178+ }
179+ )
180+ req = tool_arg_validator (
181+ description = "x must be positive" ,
182+ tool_name = None ,
183+ arg_name = "x" ,
184+ validation_fn = lambda v : v > 0 ,
185+ )
186+ result = req .validation_fn (ctx )
187+ assert result .as_bool () is False
0 commit comments