@@ -39,9 +39,9 @@ def __init__(self, cfg):
3939
4040 self .tokens_counter = Counter ()
4141
42- self .is_speculate_decoding = self .cfg .get_model_config ().get ( " speculate_method" ) is not None
42+ self .is_speculate_decoding = self .cfg .get_speculate_config ().speculate_method != " None"
4343 if self .is_speculate_decoding :
44- self .output_tokens = paddle .full (shape = [SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2 ], fill_value = 2 , dtype = "int64" )
44+ self .output_tokens = paddle .full (shape = [SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2 , 1 ], fill_value = 2 , dtype = "int64" )
4545 else :
4646 self .output_tokens = paddle .full (shape = [self .cfg .max_batch_size + 2 , 1 ], fill_value = 2 , dtype = "int64" )
4747 self .worker = None
@@ -71,10 +71,7 @@ def run(self):
7171 if self .worker is not None :
7272 raise Exception ("Worker is already running!" )
7373
74- if self .is_speculate_decoding :
75- self .worker = threading .Thread (target = self .process_speculate_results , args = ())
76- else :
77- self .worker = threading .Thread (target = self .process_sampling_results , args = ())
74+ self .worker = threading .Thread (target = self .process_sampling_results , args = ())
7875 self .worker .daemon = True
7976 self .worker .start ()
8077
@@ -86,30 +83,18 @@ def process_sampling_results(self):
8683 try :
8784 rank_id = 0
8885 is_blocking = True
89- get_output (self .output_tokens , rank_id , is_blocking )
86+ if self .is_speculate_decoding :
87+ speculate_get_output (self .output_tokens , rank_id , is_blocking )
88+ else :
89+ get_output (self .output_tokens , rank_id , is_blocking )
9090
9191 if self .output_tokens [0 , 0 ] == - 2 :
9292 continue
93+
9394 self ._process_batch_output ()
9495 except Exception as e :
9596 model_server_logger .info ("while get input_data error: {0} {1}" .format (e , str (traceback .format_exc ())))
9697
97- def process_speculate_results (self ):
98- """
99- read tokens from paddle inference engine and process
100- """
101- while True :
102- try :
103- rank_id = 0
104- is_blocking = True
105- speculate_get_output (self .output_tokens , rank_id , is_blocking )
106-
107- if self .output_tokens [0 ] == - 2 :
108- continue
109- self ._process_speculate_output ()
110- except Exception as e :
111- model_server_logger .info ("while get input_data error: {0} {1}" .format (e , str (traceback .format_exc ())))
112-
11398 def postprocess (self , batch_result , exist_finished_task = False ):
11499 """
115100 single post-processing function
@@ -126,73 +111,14 @@ def postprocess(self, batch_result, exist_finished_task=False):
126111 with open (result_file , "a" ) as f :
127112 f .write ("{}\n " .format (result ))
128113
129- def _get_single_result (self , i , task_id , token_id , task ):
114+ def _get_single_result (self , i , task_id , token_ids , task ):
130115 """
131116 processing single results
132117
133118 Args:
134119 i (int): batch index
135120 task_id (str): task id
136- token_id (int): token id
137- task (dict): task information
138-
139- Returns:
140- dict: result
141- """
142- inference_time_cost = time .time () - task ["inference_start_time" ]
143- task ["inference_time_cost" ] = inference_time_cost
144- task ["tokens_all_num" ] = len (self .all_tokens [i ])
145- task ["inference_current_step_time" ] = datetime .now ()
146- result = {
147- "req_id" : task_id ,
148- "is_end" : 0 ,
149- "token_ids" : [token_id ],
150- "send_idx" : self .tokens_counter [task_id ],
151- "inference_time_cost" : inference_time_cost ,
152- "infer_seed" : task ["infer_seed" ],
153- "return_all_tokens" : task .get ("return_all_tokens" , False ),
154- }
155-
156- # get benchmark msg
157- if task .get ("benchmark" ):
158- keys = ["preprocess_start_time" , "preprocess_end_time" , "schedule_start_time" ,
159- "inference_start_time" , "inference_current_step_time" ]
160- for key in keys :
161- if key in task :
162- result [key ] = str (task [key ])
163-
164- # fill some extra information
165- if token_id in task ["eos_token_ids" ]:
166- result ["is_end" ] = 1
167- result ["token_ids" ] = []
168- result ["tokens_all_num" ] = len (self .all_tokens [i ]) + 1
169- result ["tokens_all_ids" ] = self .all_tokens [i ]
170-
171- info_dict = {}
172- info_dict ["req_id" ] = task ["req_id" ]
173- info_dict ["input_token_num" ] = len (task ["input_ids" ])
174- info_dict ["output_token_num" ] = len (self .all_tokens [i ])
175- if hasattr (task , "preprocess_start_time" ) and hasattr (task , "preprocess_end_time" ):
176- info_dict ["preprocess_cost_time" ] = datetime_diff (task ["preprocess_start_time" ],
177- task ["preprocess_end_time" ])
178- if hasattr (task , "preprocess_end_time" ) and hasattr (task , "schedule_start_time" ):
179- info_dict ["cache_waiting_cost_time" ] = datetime_diff (task ["preprocess_end_time" ],
180- task ["schedule_start_time" ])
181- info_dict ["inference_time_cost" ] = task ["inference_time_cost" ]
182- info_dict ["version" ] = "4.6"
183- info_dict ["timestamp" ] = time .time ()
184- monitor_logger .info (f"{ info_dict } " )
185-
186- return result
187-
188- def _get_speculate_result (self , i , task_id , token_ids , task ):
189- """
190- processing single speculate results
191-
192- Args:
193- i (int): batch index
194- task_id (str): task id
195- token_ids (int): tokens id
121+ token_ids (list): token ids
196122 task (dict): task information
197123
198124 Returns:
@@ -220,23 +146,23 @@ def _get_speculate_result(self, i, task_id, token_ids, task):
220146 if key in task :
221147 result [key ] = str (task [key ])
222148
223-
224- # fill some extra information when generate eos token
149+ # fill some extra information
225150 result ["token_ids" ] = []
226151 for token_id in token_ids :
227152 if token_id in task ["eos_token_ids" ]:
228153 result ["is_end" ] = 1
154+ result ["token_ids" ] = []
229155 result ["tokens_all_num" ] = len (self .all_tokens [i ]) + 1
230156 result ["tokens_all_ids" ] = self .all_tokens [i ]
231157
232158 info_dict = {}
233159 info_dict ["req_id" ] = task ["req_id" ]
234160 info_dict ["input_token_num" ] = len (task ["input_ids" ])
235161 info_dict ["output_token_num" ] = len (self .all_tokens [i ])
236- if "preprocess_start_time" in task and "preprocess_end_time" in task :
162+ if "preprocess_start_time" in task and "preprocess_end_time" in task :
237163 info_dict ["preprocess_cost_time" ] = datetime_diff (task ["preprocess_start_time" ],
238164 task ["preprocess_end_time" ])
239- if "preprocess_end_time" in task and "schedule_start_time" in task :
165+ if "preprocess_end_time" in task and "schedule_start_time" in task :
240166 info_dict ["cache_waiting_cost_time" ] = datetime_diff (task ["preprocess_end_time" ],
241167 task ["schedule_start_time" ])
242168 info_dict ["inference_time_cost" ] = task ["inference_time_cost" ]
@@ -266,74 +192,36 @@ def _process_batch_output(self):
266192 """
267193 tokens = self .output_tokens .numpy ()
268194 batch = self .output_tokens [1 , 0 ]
269- tokens = tokens [2 :batch + 2 ]
195+ if not self .is_speculate_decoding :
196+ tokens = tokens [2 :batch + 2 ]
197+ else :
198+ accept_num = tokens [2 :batch + 2 ]
270199
271200 batch_result = list ()
272201 exist_finished_task = False
273202 for i in range (batch ):
274203 if self .resource_manager .stop_flags [i ]:
275204 continue
276205
277- token_id = int (tokens [i , 0 ])
278- if token_id < 0 :
206+ if not self .is_speculate_decoding :
207+ token_ids = [int (tokens [i , 0 ])]
208+ else :
209+ token_ids = tokens [2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS : 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS + accept_num [i , 0 ], 0 ].tolist ()
210+
211+ if any (token_id < 0 for token_id in token_ids ):
279212 continue
280213
281214 task = self .resource_manager .tasks_list [i ]
282215
283216 task_id = task ["req_id" ]
284- result = self ._get_single_result (i , task_id , token_id , task )
285-
286- self .tokens_counter [task_id ] += 1
287- if token_id not in task ["eos_token_ids" ]:
288- self .all_tokens [i ].append (token_id )
217+ result = self ._get_single_result (i , task_id , token_ids , task )
289218
290- self .number_of_output_tokens += 1
291- if token_id in task ["eos_token_ids" ]:
292- self ._recycle_resources (task_id , i , task )
293- model_server_logger .info ("req_id: {0} finished" .format (task_id ))
294- model_server_logger .info (f"{ self .resource_manager .info ()} " )
295- exist_finished_task = True
296- batch_result .append (result )
297-
298- self .postprocess (batch_result , exist_finished_task )
299-
300- def _process_speculate_output (self ):
301- """
302- batch post-processing function
303- """
304- tokens = self .output_tokens .numpy ()
305- batch = self .output_tokens [1 ]
306- output_token_msg_id = int (self .output_tokens [0 ])
307- accept_num = tokens [2 : batch + 2 ]
308- batch_result = list ()
309- # 用于判断当前此批结果中是否存在已完成的任务
310- exist_finished_task = False
311- prefill_mode = False
312- tasks_prefill = []
313-
314- for i in range (batch ):
315- # 对应task如若已结束,跳过
316- if self .resource_manager .stop_flags [i ]:
317- continue
318-
319- token_ids = tokens [2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS : 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS + accept_num [i ]].tolist ()
320- # 跳过非法token
321- if len (token_ids ) == 0 or token_ids [- 1 ] == 0 :
322- continue
323-
324- task = self .resource_manager .tasks_list [i ]
325-
326- # 将会移至data server解决
327- task_id = task ["req_id" ]
328- result = self ._get_speculate_result (i , task_id , token_ids , task )
329-
330219 for token_id in token_ids :
331220 self .tokens_counter [task_id ] += 1
332221 if token_id not in task ["eos_token_ids" ]:
333222 self .all_tokens [i ].append (token_id )
334223
335224 self .number_of_output_tokens += 1
336- # 生成结束符时,重置相应变量
337225 if token_id in task ["eos_token_ids" ]:
338226 self ._recycle_resources (task_id , i , task )
339227 model_server_logger .info ("req_id: {0} finished" .format (task_id ))
@@ -342,7 +230,6 @@ def _process_speculate_output(self):
342230 break
343231 batch_result .append (result )
344232
345- # 后处理函数调用
346233 self .postprocess (batch_result , exist_finished_task )
347234
348235
@@ -365,29 +252,17 @@ def process_sampling_results(self):
365252 while self ._is_running :
366253 try :
367254 rank_id = 0
368- get_output (self .output_tokens , rank_id , self ._is_blocking )
255+ if self .is_speculate_decoding :
256+ speculate_get_output (self .output_tokens , rank_id , self ._is_blocking )
257+ else :
258+ get_output (self .output_tokens , rank_id , self ._is_blocking )
369259
370260 if self .output_tokens [0 , 0 ] == - 2 :
371261 continue
372262 self ._process_batch_output ()
373263 except Exception as e :
374264 model_server_logger .info ("while get input_data error: {0} {1}" .format (e , str (traceback .format_exc ())))
375265
376- def process_speculate_results (self ):
377- """
378- read tokens from paddle inference engine and process
379- """
380- while self ._is_running :
381- try :
382- rank_id = 0
383- speculate_get_output (self .output_tokens , rank_id , self ._is_blocking )
384-
385- if self .output_tokens [0 ] == - 2 :
386- continue
387- self ._process_speculate_output ()
388- except Exception as e :
389- model_server_logger .info ("while get input_data error: {0} {1}" .format (e , str (traceback .format_exc ())))
390-
391266 def stop (self ):
392267 """
393268 stop warm up thread
0 commit comments