
Commit aebe12a

[fix]update apply_chat_template (#4249)
* [fix] Modify follow-up (continued-generation) push parameters and the thinking-length verification method (#4086)
* Rename the continued-generation parameter generated_token_ids to completion_token_ids; change the thinking-length validation method
* Rename the continued-generation parameter generated_token_ids to completion_token_ids; change the thinking-length validation method
* Rename the continued-generation parameter generated_token_ids to completion_token_ids; change the thinking-length validation method
* Rename the continued-generation parameter generated_token_ids to completion_token_ids; change the thinking-length validation method
* add completion_token_ids
* add logger
* fix reasoning_max_tokens ParameterError
* add unittest
* add unittest
* add unittest
* add unittest
* add unittest
* add unit test
* fix
* [fix] update apply_chat_template (#4137)
* update apply_chat_template
* fix unittest
* fix unittest
* fix
* fix
* fix unit test
* fix
* fix unit test
* add unit test
1 parent 8fdb950 commit aebe12a

10 files changed

Lines changed: 146 additions & 109 deletions


fastdeploy/engine/engine.py

Lines changed: 3 additions & 1 deletion
@@ -222,7 +222,9 @@ def add_requests(self, task, sampling_params=None, **kwargs):
         if sampling_params is not None:
             request.sampling_params = sampling_params
         request.preprocess_start_time = time.time()
-
+        chat_template_kwargs = kwargs.get("chat_template_kwargs") or {}
+        chat_template_kwargs["chat_template"] = kwargs.get("chat_template")
+        kwargs["chat_template_kwargs"] = chat_template_kwargs
         request = self.data_processor.process_request(request, self.cfg.max_model_len, **kwargs)
         request.prompt_token_ids_len = len(request.prompt_token_ids)
         request.need_prefill_tokens = request.prompt_token_ids_len
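The engine now folds the top-level chat_template kwarg into chat_template_kwargs before the request reaches the data processor. A minimal, self-contained sketch of that merge (the helper name and sample values are illustrative, not part of the diff):

# Sketch only: mirrors the merge that add_requests now performs (values are made up).
def merge_chat_template(kwargs):
    chat_template_kwargs = kwargs.get("chat_template_kwargs") or {}
    # The top-level chat_template is folded into chat_template_kwargs so that
    # downstream processors only need to look in one place.
    chat_template_kwargs["chat_template"] = kwargs.get("chat_template")
    kwargs["chat_template_kwargs"] = chat_template_kwargs
    return kwargs

print(merge_chat_template({"chat_template": "Hello", "chat_template_kwargs": {"enable_thinking": True}}))
# {'chat_template': 'Hello', 'chat_template_kwargs': {'enable_thinking': True, 'chat_template': 'Hello'}}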

fastdeploy/entrypoints/engine_client.py

Lines changed: 3 additions & 0 deletions
@@ -172,6 +172,9 @@ async def add_requests(self, task):
 
         task["preprocess_start_time"] = time.time()
         try:
+            chat_template_kwargs = task.get("chat_template_kwargs", {})
+            chat_template_kwargs.update({"chat_template": task.get("chat_template"), "tools": task.get("tools")})
+            task["chat_template_kwargs"] = chat_template_kwargs
             if inspect.iscoroutinefunction(self.data_processor.process_request_dict):
                 await self.data_processor.process_request_dict(task, self.max_model_len)
             else:
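EngineClient.add_requests performs the analogous merge for dict-style tasks and additionally carries tools, so tool schemas reach the chat template as well. A hedged sketch of the resulting payload (the sample task fields are hypothetical):

# Sketch only: the client-side merge applied to a dict-style task (values are made up).
task = {
    "messages": [{"role": "user", "content": "Hi"}],
    "chat_template": "Hello",
    "tools": [{"type": "function", "function": {"name": "get_weather"}}],
    "chat_template_kwargs": {"enable_thinking": True},
}

chat_template_kwargs = task.get("chat_template_kwargs", {})
chat_template_kwargs.update({"chat_template": task.get("chat_template"), "tools": task.get("tools")})
task["chat_template_kwargs"] = chat_template_kwargs

# chat_template_kwargs now carries enable_thinking, chat_template and tools,
# which is the single dict later expanded into apply_chat_template.
print(task["chat_template_kwargs"])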

fastdeploy/input/ernie4_5_processor.py

Lines changed: 6 additions & 7 deletions
@@ -88,7 +88,6 @@ def process_request(self, request, max_model_len=None, **kwargs):
             str: error message
         """
         data_processor_logger.info(f"Start processing request: {request}")
-        request.chat_template = kwargs.get("chat_template")
         request = self._apply_default_parameters(request)
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
@@ -127,15 +126,15 @@ def process_request(self, request, max_model_len=None, **kwargs):
             )
         elif request.messages is not None:
             task = request.to_dict()
-            chat_template_kwargs = kwargs.get("chat_template_kwargs")
+            chat_template_kwargs = kwargs.get("chat_template_kwargs", {})
             if chat_template_kwargs:
                 if isinstance(chat_template_kwargs, dict):
                     for k, v in chat_template_kwargs.items():
                         if k not in task:
                             task[k] = v
                 else:
                     raise ValueError("Invalid input: chat_template_kwargs must be a dict")
-            request.prompt_token_ids = self.messages2ids(task)
+            request.prompt_token_ids = self.messages2ids(task, **chat_template_kwargs)
         else:
             raise ValueError(f"The request should have `prompt_token_ids`, `prompt` or `messages`: {request}.")
 
@@ -205,15 +204,15 @@ def process_request_dict(self, request, max_model_len=None):
             req_id = request.get("request_id", None)
             data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
         elif request.get("messages"):
-            chat_template_kwargs = request.get("chat_template_kwargs")
+            chat_template_kwargs = request.get("chat_template_kwargs", {})
             if chat_template_kwargs:
                 if isinstance(chat_template_kwargs, dict):
                     for k, v in chat_template_kwargs.items():
                         if k not in request:
                             request[k] = v
                 else:
                     raise ValueError("Invalid input: chat_template_kwargs must be a dict")
-            request["prompt_token_ids"] = self.messages2ids(request)
+            request["prompt_token_ids"] = self.messages2ids(request, **chat_template_kwargs)
         else:
             raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
 
@@ -379,7 +378,7 @@ def process_response_dict_streaming(self, response_dict, **kwargs):
                 del self.tool_parser_dict[req_id]
         return response_dict
 
-    def messages2ids(self, request_or_messages):
+    def messages2ids(self, request_or_messages, **kwargs):
         """
         Convert multi-turn messages into ID sequences.
 
@@ -397,7 +396,7 @@ def messages2ids(self, request_or_messages):
             tokenize=False,
             split_special_tokens=False,
             add_special_tokens=False,
-            chat_template=request_or_messages.get("chat_template", None),
+            **kwargs,
         )
         request_or_messages["text_after_process"] = spliced_message
         req_id = None
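Because messages2ids now accepts **kwargs, every key in chat_template_kwargs (including chat_template itself and tools) is forwarded straight into the tokenizer's apply_chat_template instead of being read off the request object. A simplified sketch with a stub tokenizer standing in for Ernie4_5Tokenizer:

# Sketch only: how **kwargs forwarding reaches apply_chat_template.
# StubTokenizer is a stand-in for the real tokenizer, not FastDeploy code.
class StubTokenizer:
    def apply_chat_template(self, messages, tokenize=False, **kwargs):
        # A real tokenizer would render the (possibly custom) Jinja chat template here.
        return f"rendered with options {sorted(kwargs)}"

def messages2ids(tokenizer, task, **kwargs):
    # kwargs is the request's chat_template_kwargs, e.g. {"chat_template": ..., "tools": [...]}
    return tokenizer.apply_chat_template(
        task,
        tokenize=False,
        split_special_tokens=False,
        add_special_tokens=False,
        **kwargs,
    )

print(messages2ids(StubTokenizer(), [{"role": "user", "content": "Hi"}], chat_template="Hello", tools=[]))
# rendered with options ['add_special_tokens', 'chat_template', 'split_special_tokens', 'tools']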

fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py

Lines changed: 0 additions & 1 deletion
@@ -113,7 +113,6 @@ def set_value(req, key, value):
 
     def process_request(self, request, max_model_len=None, **kwargs):
         """process the input data"""
-        request.chat_template = kwargs.get("chat_template")
         task = request.to_dict()
         task["chat_template_kwargs"] = kwargs.get("chat_template_kwargs")
         self.process_request_dict(task, max_model_len)

fastdeploy/input/ernie4_5_vl_processor/process.py

Lines changed: 4 additions & 4 deletions
@@ -250,8 +250,8 @@ def request2ids(
                     "video",
                 ]:
                     image_message_list.append(item)
-
-        prompt_token_ids = self.apply_chat_template(request)
+        chat_template_kwargs = request.get("chat_template_kwargs", {})
+        prompt_token_ids = self.apply_chat_template(request, **chat_template_kwargs)
         if len(prompt_token_ids) == 0:
             raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")
         image_start_index = 0
@@ -480,7 +480,7 @@ def _load_tokenizer(self):
                 break
         self.tokenizer = Ernie4_5Tokenizer.from_pretrained(self.model_name_or_path)
 
-    def apply_chat_template(self, request):
+    def apply_chat_template(self, request, **kwargs):
         """
         Convert multi-turn messages into ID sequences.
 
@@ -498,7 +498,7 @@ def apply_chat_template(self, request):
             request,
             tokenize=False,
             add_generation_prompt=request.get("add_generation_prompt", True),
-            chat_template=request.get("chat_template", None),
+            **kwargs,
         )
         prompt_token_str = prompt_token_template.replace("<|image@placeholder|>", "").replace(
             "<|video@placeholder|>", ""

fastdeploy/input/text_processor.py

Lines changed: 6 additions & 7 deletions
@@ -208,7 +208,6 @@ def process_request(self, request, max_model_len=None, **kwargs):
             str: error message
         """
         data_processor_logger.info(f"Start processing request: {request}")
-        request.chat_template = kwargs.get("chat_template")
         request = self._apply_default_parameters(request)
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
@@ -242,7 +241,7 @@ def process_request(self, request, max_model_len=None, **kwargs):
             if self.tokenizer.chat_template is None:
                 raise ValueError("This model does not support chat_template.")
             task = request.to_dict()
-            chat_template_kwargs = kwargs.get("chat_template_kwargs")
+            chat_template_kwargs = kwargs.get("chat_template_kwargs", {})
             if chat_template_kwargs:
                 if isinstance(chat_template_kwargs, dict):
                     for k, v in chat_template_kwargs.items():
@@ -251,7 +250,7 @@ def process_request(self, request, max_model_len=None, **kwargs):
                 else:
                     raise ValueError("Invalid input: chat_template_kwargs must be a dict")
             task.setdefault("enable_thinking", True)
-            request.prompt_token_ids = self.messages2ids(task)
+            request.prompt_token_ids = self.messages2ids(task, **chat_template_kwargs)
         else:
             raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
 
@@ -316,7 +315,7 @@ def process_request_dict(self, request, max_model_len=None, **kwargs):
         elif request.get("messages"):
             if self.tokenizer.chat_template is None:
                 raise ValueError("This model does not support chat_template.")
-            chat_template_kwargs = request.get("chat_template_kwargs")
+            chat_template_kwargs = request.get("chat_template_kwargs", {})
             if chat_template_kwargs:
                 if isinstance(chat_template_kwargs, dict):
                     for k, v in chat_template_kwargs.items():
@@ -325,7 +324,7 @@ def process_request_dict(self, request, max_model_len=None, **kwargs):
                 else:
                     raise ValueError("Invalid input: chat_template_kwargs must be a dict")
             request.setdefault("enable_thinking", True)
-            request["prompt_token_ids"] = self.messages2ids(request)
+            request["prompt_token_ids"] = self.messages2ids(request, **chat_template_kwargs)
         else:
             raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
 
@@ -530,7 +529,7 @@ def text2ids(self, text, max_model_len):
 
         return tokens["input_ids"][0]
 
-    def messages2ids(self, request):
+    def messages2ids(self, request, **kwargs):
         """
         Convert multi-turn messages into ID sequences.
 
@@ -547,7 +546,7 @@ def messages2ids(self, request):
             split_special_tokens=False,
             add_special_tokens=False,
             return_tensors="pd",
-            chat_template=request.get("chat_template", None),
+            **kwargs,
         )
         request["text_after_process"] = spliced_message
         req_id = None
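Taken together, the change means per-request template options travel in a single chat_template_kwargs dict across all processors. A hedged caller-side example (the template string and flags are illustrative):

# Illustrative end-to-end request; only the chat_template_kwargs semantics come from this diff.
request = {
    "messages": [{"role": "user", "content": "What is 2 + 2?"}],
    "chat_template_kwargs": {
        "chat_template": "{% for m in messages %}{{ m['content'] }}{% endfor %}",
        "enable_thinking": False,
    },
}
# process_request_dict copies missing keys from chat_template_kwargs onto the request,
# then calls messages2ids(request, **chat_template_kwargs), so both the custom template
# and enable_thinking reach tokenizer.apply_chat_template in one hop.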
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+import unittest
+from unittest.mock import MagicMock, patch
+
+from fastdeploy.entrypoints.engine_client import EngineClient
+
+
+class TestEngineClient(unittest.IsolatedAsyncioTestCase):
+    async def asyncSetUp(self):
+        # Create a mocked EngineClient instance
+        with patch.object(EngineClient, "__init__", return_value=None) as mock_init:
+            self.engine_client = EngineClient("model_path")
+            mock_init.side_effect = lambda *args, **kwargs: print(f"__init__ called with {args}, {kwargs}")
+
+        self.engine_client.data_processor = MagicMock()
+        self.engine_client.zmq_client = MagicMock()
+        self.engine_client.max_model_len = 1024
+        self.engine_client.enable_mm = False
+
+    async def test_add_request(self):
+        request = {
+            "chat_template_kwargs": {"enable_thinking": True},
+            "prompt_token_ids": [1],
+            "chat_template": "Hello",
+            "max_tokens": 20,
+            "tools": [1],
+        }
+
+        await self.engine_client.add_requests(request)
+        assert "chat_template" in request["chat_template_kwargs"], "'chat_template' not found in 'chat_template_kwargs'"
+        assert "tools" in request["chat_template_kwargs"], "'tools' not found in 'chat_template_kwargs'"
+        assert request["chat_template_kwargs"]["chat_template"] == "Hello"
+        assert request["chat_template_kwargs"]["tools"] == [1]
+
+
+if __name__ == "__main__":
+    unittest.main()

tests/input/test_ernie_processor.py

Lines changed: 25 additions & 0 deletions
@@ -17,13 +17,27 @@ def setUp(self):
         self.processor.decode_status = {}
         self.processor.reasoning_end_dict = {}
         self.processor.tool_parser_dict = {}
+        self.processor.generation_config = MagicMock()
+        self.processor.eos_token_ids = [1]
 
         # Mock the ids2tokens method
         def mock_ids2tokens(token_ids, task_id):
             return "delta_text", [2, 3], "previous_texts"
 
         self.processor.ids2tokens = mock_ids2tokens
 
+        def mock_messages2ids(request, **kwargs):
+            if "chat_template" in kwargs:
+                return [1]
+            else:
+                return [0]
+
+        def mock_apply_default_parameters(request):
+            return request
+
+        self.processor.messages2ids = mock_messages2ids
+        self.processor._apply_default_parameters = mock_apply_default_parameters
+
         # Mock the reasoning parser
         self.mock_reasoning_parser = MagicMock()
         self.mock_reasoning_parser.__class__.__name__ = "ErnieX1ReasoningParser"
@@ -49,6 +63,17 @@ def test_process_response_dict_streaming_normal_case(self):
         # Verify the result
         self.assertEqual(result["outputs"]["raw_prediction"], "delta_text")
 
+    def test_process_request_dict(self):
+        request_dict = {
+            "messages": [{"role": "user", "content": "Hello!"}],
+            "chat_template_kwargs": {"chat_template": "Hello!"},
+            "eos_token_ids": [1],
+            "temperature": 1,
+            "top_p": 1,
+        }
+        result = self.processor.process_request_dict(request_dict, 100)
+        self.assertEqual(result["prompt_token_ids"], [1])
+
 
 if __name__ == "__main__":
     unittest.main()

tests/input/test_text_processor.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+import unittest
+from unittest.mock import MagicMock, patch
+
+from fastdeploy.engine.request import Request
+from fastdeploy.input.text_processor import DataProcessor
+
+
+class TestDataProcessorProcess(unittest.TestCase):
+    def setUp(self):
+        # Create a mocked DataProcessor instance
+        with patch.object(DataProcessor, "__init__", return_value=None) as mock_init:
+            self.processor = DataProcessor("model_path")
+            mock_init.side_effect = lambda *args, **kwargs: print(f"__init__ called with {args}, {kwargs}")
+
+        # Set the required attributes
+        self.processor.tokenizer = MagicMock()
+        self.processor.tokenizer.eos_token_id = 1
+        self.processor.decode_status = {}
+        self.processor.reasoning_end_dict = {}
+        self.processor.tool_parser_dict = {}
+        self.processor.generation_config = MagicMock()
+        self.processor.eos_token_ids = [1]
+
+        def mock_messages2ids(request, **kwargs):
+            if "chat_template" in kwargs:
+                return [1]
+            else:
+                return [0]
+
+        def mock_apply_default_parameters(request):
+            return request
+
+        self.processor.messages2ids = mock_messages2ids
+        self.processor._apply_default_parameters = mock_apply_default_parameters
+
+    def test_process_request(self):
+        request = Request.from_dict(
+            {
+                "request_id": "123",
+                "messages": [{"role": "user", "content": "Hello!"}],
+                "eos_token_ids": [1],
+                "temperature": 1,
+                "top_p": 1,
+            }
+        )
+        chat_template_kwargs = {"chat_template": "Hello!"}
+        result = self.processor.process_request(request, 100, chat_template_kwargs=chat_template_kwargs)
+        self.assertEqual(result.prompt_token_ids, [1])
+
+    def test_process_request_dict(self):
+        request_dict = {
+            "messages": [{"role": "user", "content": "Hello!"}],
+            "chat_template_kwargs": {"chat_template": "Hello!"},
+            "eos_token_ids": [1],
+            "temperature": 1,
+            "top_p": 1,
+        }
+        result = self.processor.process_request_dict(request_dict, 100)
+        self.assertEqual(result["prompt_token_ids"], [1])
+
+
+if __name__ == "__main__":
+    unittest.main()
