2020from datetime import datetime
2121
2222import numpy as np
23- from paddlenlp_ops import get_output
23+ from paddlenlp_ops import get_output , speculate_get_output
2424from server .utils import datetime_diff , model_server_logger , monitor_logger
25+ from paddlenlp .utils .env import MAX_DRAFT_TOKENS , SPECULATE_MAX_BSZ
2526
2627
2728class TokenProcessor (object ):
@@ -37,7 +38,12 @@ def __init__(self, cfg):
3738 self .all_tokens = [[] for _ in range (self .cfg .max_batch_size )]
3839
3940 self .tokens_counter = Counter ()
40- self .output_tokens = paddle .full (shape = [self .cfg .max_batch_size + 2 , 1 ], fill_value = 2 , dtype = "int64" )
41+
42+ self .is_speculate_decoding = self .cfg .get_speculate_config ().speculate_method != "None"
43+ if self .is_speculate_decoding :
44+ self .output_tokens = paddle .full (shape = [SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2 , 1 ], fill_value = 2 , dtype = "int64" )
45+ else :
46+ self .output_tokens = paddle .full (shape = [self .cfg .max_batch_size + 2 , 1 ], fill_value = 2 , dtype = "int64" )
4147 self .worker = None
4248
4349 self .record_time_interval = int (os .getenv ("RECORD_TIME_INTERVAL" , "600" ))
@@ -77,10 +83,14 @@ def process_sampling_results(self):
7783 try :
7884 rank_id = 0
7985 is_blocking = True
80- get_output (self .output_tokens , rank_id , is_blocking )
86+ if self .is_speculate_decoding :
87+ speculate_get_output (self .output_tokens , rank_id , is_blocking )
88+ else :
89+ get_output (self .output_tokens , rank_id , is_blocking )
8190
8291 if self .output_tokens [0 , 0 ] == - 2 :
8392 continue
93+
8494 self ._process_batch_output ()
8595 except Exception as e :
8696 model_server_logger .info ("while get input_data error: {0} {1}" .format (e , str (traceback .format_exc ())))
@@ -101,14 +111,14 @@ def postprocess(self, batch_result, exist_finished_task=False):
101111 with open (result_file , "a" ) as f :
102112 f .write ("{}\n " .format (result ))
103113
104- def _get_single_result (self , i , task_id , token_id , task ):
114+ def _get_single_result (self , i , task_id , token_ids , task ):
105115 """
106116 processing single results
107117
108118 Args:
109119 i (int): batch index
110120 task_id (str): task id
111- token_id (int ): token id
121+ token_ids (list ): token ids
112122 task (dict): task information
113123
114124 Returns:
@@ -121,7 +131,7 @@ def _get_single_result(self, i, task_id, token_id, task):
121131 result = {
122132 "req_id" : task_id ,
123133 "is_end" : 0 ,
124- "token_ids" : [ token_id ] ,
134+ "token_ids" : token_ids ,
125135 "send_idx" : self .tokens_counter [task_id ],
126136 "inference_time_cost" : inference_time_cost ,
127137 "infer_seed" : task ["infer_seed" ],
@@ -137,26 +147,31 @@ def _get_single_result(self, i, task_id, token_id, task):
137147 result [key ] = str (task [key ])
138148
139149 # fill some extra information
140- if token_id in task ["eos_token_ids" ]:
141- result ["is_end" ] = 1
142- result ["token_ids" ] = []
143- result ["tokens_all_num" ] = len (self .all_tokens [i ]) + 1
144- result ["tokens_all_ids" ] = self .all_tokens [i ]
145-
146- info_dict = {}
147- info_dict ["req_id" ] = task ["req_id" ]
148- info_dict ["input_token_num" ] = len (task ["input_ids" ])
149- info_dict ["output_token_num" ] = len (self .all_tokens [i ])
150- if hasattr (task , "preprocess_start_time" ) and hasattr (task , "preprocess_end_time" ):
151- info_dict ["preprocess_cost_time" ] = datetime_diff (task ["preprocess_start_time" ],
152- task ["preprocess_end_time" ])
153- if hasattr (task , "preprocess_end_time" ) and hasattr (task , "schedule_start_time" ):
154- info_dict ["cache_waiting_cost_time" ] = datetime_diff (task ["preprocess_end_time" ],
155- task ["schedule_start_time" ])
156- info_dict ["inference_time_cost" ] = task ["inference_time_cost" ]
157- info_dict ["version" ] = "4.6"
158- info_dict ["timestamp" ] = time .time ()
159- monitor_logger .info (f"{ info_dict } " )
150+ result ["token_ids" ] = []
151+ for token_id in token_ids :
152+ if token_id in task ["eos_token_ids" ]:
153+ result ["is_end" ] = 1
154+ result ["token_ids" ] = []
155+ result ["tokens_all_num" ] = len (self .all_tokens [i ]) + 1
156+ result ["tokens_all_ids" ] = self .all_tokens [i ]
157+
158+ info_dict = {}
159+ info_dict ["req_id" ] = task ["req_id" ]
160+ info_dict ["input_token_num" ] = len (task ["input_ids" ])
161+ info_dict ["output_token_num" ] = len (self .all_tokens [i ])
162+ if hasattr (task , "preprocess_start_time" ) and hasattr (task , "preprocess_end_time" ):
163+ info_dict ["preprocess_cost_time" ] = datetime_diff (task ["preprocess_start_time" ],
164+ task ["preprocess_end_time" ])
165+ if hasattr (task , "preprocess_end_time" ) and hasattr (task , "schedule_start_time" ):
166+ info_dict ["cache_waiting_cost_time" ] = datetime_diff (task ["preprocess_end_time" ],
167+ task ["schedule_start_time" ])
168+ info_dict ["inference_time_cost" ] = task ["inference_time_cost" ]
169+ info_dict ["version" ] = "OpenSource"
170+ info_dict ["timestamp" ] = time .time ()
171+ monitor_logger .info (f"{ info_dict } " )
172+ break
173+ else :
174+ result ["token_ids" ].append (token_id )
160175
161176 return result
162177
@@ -177,33 +192,42 @@ def _process_batch_output(self):
177192 """
178193 tokens = self .output_tokens .numpy ()
179194 batch = self .output_tokens [1 , 0 ]
180- tokens = tokens [2 :batch + 2 ]
195+ if not self .is_speculate_decoding :
196+ tokens = tokens [2 :batch + 2 ]
197+ else :
198+ accept_num = tokens [2 :batch + 2 ]
181199
182200 batch_result = list ()
183201 exist_finished_task = False
184202 for i in range (batch ):
185203 if self .resource_manager .stop_flags [i ]:
186204 continue
187205
188- token_id = int (tokens [i , 0 ])
189- if token_id < 0 :
206+ if not self .is_speculate_decoding :
207+ token_ids = [int (tokens [i , 0 ])]
208+ else :
209+ token_ids = tokens [2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS : 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS + accept_num [i , 0 ], 0 ].tolist ()
210+
211+ if any (token_id < 0 for token_id in token_ids ):
190212 continue
191213
192214 task = self .resource_manager .tasks_list [i ]
193215
194216 task_id = task ["req_id" ]
195- result = self ._get_single_result (i , task_id , token_id , task )
196-
197- self .tokens_counter [task_id ] += 1
198- if token_id not in task ["eos_token_ids" ]:
199- self .all_tokens [i ].append (token_id )
200-
201- self .number_of_output_tokens += 1
202- if token_id in task ["eos_token_ids" ]:
203- self ._recycle_resources (task_id , i , task )
204- model_server_logger .info ("req_id: {0} finished" .format (task_id ))
205- model_server_logger .info (f"{ self .resource_manager .info ()} " )
206- exist_finished_task = True
217+ result = self ._get_single_result (i , task_id , token_ids , task )
218+
219+ for token_id in token_ids :
220+ self .tokens_counter [task_id ] += 1
221+ if token_id not in task ["eos_token_ids" ]:
222+ self .all_tokens [i ].append (token_id )
223+
224+ self .number_of_output_tokens += 1
225+ if token_id in task ["eos_token_ids" ]:
226+ self ._recycle_resources (task_id , i , task )
227+ model_server_logger .info ("req_id: {0} finished" .format (task_id ))
228+ model_server_logger .info (f"{ self .resource_manager .info ()} " )
229+ exist_finished_task = True
230+ break
207231 batch_result .append (result )
208232
209233 self .postprocess (batch_result , exist_finished_task )
@@ -228,7 +252,10 @@ def process_sampling_results(self):
228252 while self ._is_running :
229253 try :
230254 rank_id = 0
231- get_output (self .output_tokens , rank_id , self ._is_blocking )
255+ if self .is_speculate_decoding :
256+ speculate_get_output (self .output_tokens , rank_id , self ._is_blocking )
257+ else :
258+ get_output (self .output_tokens , rank_id , self ._is_blocking )
232259
233260 if self .output_tokens [0 , 0 ] == - 2 :
234261 continue
0 commit comments