
Commit 5c42585

update

1 parent ae57a2b
3 files changed: 39 additions & 163 deletions

llm/server/server/engine/config.py

Lines changed: 7 additions & 6 deletions
@@ -211,14 +211,15 @@ def get_speculate_config(self):
             SpeculateConfig: the speculate related config
         """
         speculate_config = SpeculateConfig()
-        if self.model_cfg.get("speculate_method") is not None:
-            speculate_config.speculate_method = self.model_cfg["speculate_method"]
-            speculate_config.speculate_max_draft_token_num = self.model_cfg[
+        model_cfg = self.get_model_config()
+        if model_cfg.get("speculate_method", "None") != "None":
+            speculate_config.speculate_method = str(model_cfg["speculate_method"])
+            speculate_config.speculate_max_draft_token_num = model_cfg[
                 "speculate_max_draft_token_num"]
-            speculate_config.speculate_max_ngram_size = self.model_cfg[
+            speculate_config.speculate_max_ngram_size = model_cfg[
                 "speculate_max_ngram_size"]
 
-        if speculate_config.speculate_method is not in ["none", "inference_with_reference"]:
+        if speculate_config.speculate_method not in ["None", "inference_with_reference"]:
             model_server_logger.error(f"Unsupport speculate method: {speculate_config.speculate_method}")
 
         return speculate_config
@@ -258,6 +259,6 @@ def __str__(self) -> str:
 
 @dataclass
 class SpeculateConfig:
-    speculate_method: str = None
+    speculate_method: str = "None"
     speculate_max_draft_token_num: int = 1
     speculate_max_ngram_size: int = 1
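
Note on the change above: SpeculateConfig.speculate_method is annotated str, so this commit settles on the string "None" as the "speculation off" sentinel rather than Python's None, and reads the model config with that sentinel as the default. A minimal, runnable sketch of the resulting pattern (the dataclass mirrors the diff; the standalone function is illustrative, not the server's actual method):

from dataclasses import dataclass

@dataclass
class SpeculateConfig:
    # "None" (the string) is the off switch; the field is typed str
    # because model config files carry the value as text.
    speculate_method: str = "None"
    speculate_max_draft_token_num: int = 1
    speculate_max_ngram_size: int = 1

def get_speculate_config(model_cfg: dict) -> SpeculateConfig:
    config = SpeculateConfig()
    if model_cfg.get("speculate_method", "None") != "None":
        config.speculate_method = str(model_cfg["speculate_method"])
        config.speculate_max_draft_token_num = model_cfg["speculate_max_draft_token_num"]
        config.speculate_max_ngram_size = model_cfg["speculate_max_ngram_size"]
    return config

# decoding stays non-speculative unless the sentinel is overridden
assert get_speculate_config({}).speculate_method == "None"
assert get_speculate_config({
    "speculate_method": "inference_with_reference",
    "speculate_max_draft_token_num": 5,
    "speculate_max_ngram_size": 2,
}).speculate_method == "inference_with_reference"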

llm/server/server/engine/infer.py

Lines changed: 3 additions & 3 deletions
@@ -48,7 +48,7 @@ def __init__(self, args):
         self.config = Config()
         self.model_cfg = self.config.get_model_config()
         self.speculate_config = self.config.get_speculate_config()
-        self.is_speculate_decoding = self.speculate_config.speculate_method is not None
+        self.is_speculate_decoding = self.speculate_config.speculate_method != "None"
         self.format_print_configuration()
 
         self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"])
@@ -71,7 +71,7 @@ def __init__(self, args):
         self.init_inputs()
 
         if self.is_speculate_decoding:
-            logger.info(f'Using speculating decoding, method: {self.speculate_config.speculate_method}.')
+            logger.info(f'Using speculate decoding, method: {self.speculate_config.speculate_method}.')
             if self.speculate_config.speculate_method == "inference_with_reference":
                 self.proposer = InferenceWithReferenceProposer(
                     self.speculate_config.speculate_max_draft_token_num,
@@ -371,7 +371,7 @@ def step_cuda(self, seq_lens_this_time):
                 self.share_inputs['free_list'], self.share_inputs['free_list_len'],
                 self.share_inputs['input_ids'], self.share_inputs['pre_ids'],
                 self.share_inputs['step_idx'], self.share_inputs['next_tokens'],
-                self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id
+                self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id,
                 speculate_step_token_num)
 
     def initialize_engine_ready_check_flag(self):
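
Note on the step_cuda hunk: the added trailing comma is a correctness fix, not style. Inside parentheses Python joins lines implicitly, so two argument expressions with no comma between them do not parse. A tiny standalone reproduction (illustrative, not repo code):

def step(*args):
    return args

# step(1, 2, 3
#      4)        # SyntaxError: missing comma between arguments
print(step(1, 2, 3,
           4))   # OK: prints (1, 2, 3, 4)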

llm/server/server/engine/token_processor.py

Lines changed: 29 additions & 154 deletions
@@ -39,9 +39,9 @@ def __init__(self, cfg):
 
         self.tokens_counter = Counter()
 
-        self.is_speculate_decoding = self.cfg.get_model_config().get("speculate_method") is not None
+        self.is_speculate_decoding = self.cfg.get_speculate_config().speculate_method != "None"
         if self.is_speculate_decoding:
-            self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], fill_value=2, dtype="int64")
+            self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1], fill_value=2, dtype="int64")
         else:
             self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")
         self.worker = None
@@ -71,10 +71,7 @@ def run(self):
         if self.worker is not None:
             raise Exception("Worker is already running!")
 
-        if self.is_speculate_decoding:
-            self.worker = threading.Thread(target=self.process_speculate_results, args=())
-        else:
-            self.worker = threading.Thread(target=self.process_sampling_results, args=())
+        self.worker = threading.Thread(target=self.process_sampling_results, args=())
         self.worker.daemon = True
         self.worker.start()
 
@@ -86,30 +83,18 @@ def process_sampling_results(self):
             try:
                 rank_id = 0
                 is_blocking = True
-                get_output(self.output_tokens, rank_id, is_blocking)
+                if self.is_speculate_decoding:
+                    speculate_get_output(self.output_tokens, rank_id, is_blocking)
+                else:
+                    get_output(self.output_tokens, rank_id, is_blocking)
 
                 if self.output_tokens[0, 0] == -2:
                     continue
+
                 self._process_batch_output()
             except Exception as e:
                 model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc())))
 
-    def process_speculate_results(self):
-        """
-        read tokens from paddle inference engine and process
-        """
-        while True:
-            try:
-                rank_id = 0
-                is_blocking = True
-                speculate_get_output(self.output_tokens, rank_id, is_blocking)
-
-                if self.output_tokens[0] == -2:
-                    continue
-                self._process_speculate_output()
-            except Exception as e:
-                model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc())))
-
     def postprocess(self, batch_result, exist_finished_task=False):
         """
         single post-processing function
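
Note on the two hunks above: run() now always starts a single process_sampling_results worker, and the speculative/regular split happens inside the read loop instead of at thread creation. A runnable sketch of that dispatch pattern with stand-in fetchers (the real get_output/speculate_get_output are the engine's native readers):

import threading

def reader_loop(output_tokens, is_speculate_decoding, get_output, speculate_get_output,
                process_batch, iterations=3):
    # the server loops forever; a bounded loop keeps the sketch runnable
    for _ in range(iterations):
        if is_speculate_decoding:
            speculate_get_output(output_tokens)
        else:
            get_output(output_tokens)
        if output_tokens[0][0] == -2:  # -2 marks "no batch ready yet"
            continue
        process_batch(output_tokens)

buffer = [[-2]]
def fake_fetch(buf):
    buf[0][0] = 7  # pretend one token arrived

worker = threading.Thread(target=reader_loop,
                          args=(buffer, False, fake_fetch, fake_fetch, print))
worker.daemon = True
worker.start()
worker.join()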
@@ -126,73 +111,14 @@ def postprocess(self, batch_result, exist_finished_task=False):
         with open(result_file, "a") as f:
             f.write("{}\n".format(result))
 
-    def _get_single_result(self, i, task_id, token_id, task):
+    def _get_single_result(self, i, task_id, token_ids, task):
         """
         processing single results
 
         Args:
             i (int): batch index
             task_id (str): task id
-            token_id (int): token id
-            task (dict): task information
-
-        Returns:
-            dict: result
-        """
-        inference_time_cost = time.time() - task["inference_start_time"]
-        task["inference_time_cost"] = inference_time_cost
-        task["tokens_all_num"] = len(self.all_tokens[i])
-        task["inference_current_step_time"] = datetime.now()
-        result = {
-            "req_id": task_id,
-            "is_end": 0,
-            "token_ids": [token_id],
-            "send_idx": self.tokens_counter[task_id],
-            "inference_time_cost": inference_time_cost,
-            "infer_seed": task["infer_seed"],
-            "return_all_tokens": task.get("return_all_tokens", False),
-        }
-
-        # get benchmark msg
-        if task.get("benchmark"):
-            keys = ["preprocess_start_time", "preprocess_end_time", "schedule_start_time",
-                    "inference_start_time", "inference_current_step_time"]
-            for key in keys:
-                if key in task:
-                    result[key] = str(task[key])
-
-        # fill some extra information
-        if token_id in task["eos_token_ids"]:
-            result["is_end"] = 1
-            result["token_ids"] = []
-            result["tokens_all_num"] = len(self.all_tokens[i]) + 1
-            result["tokens_all_ids"] = self.all_tokens[i]
-
-            info_dict = {}
-            info_dict["req_id"] = task["req_id"]
-            info_dict["input_token_num"] = len(task["input_ids"])
-            info_dict["output_token_num"] = len(self.all_tokens[i])
-            if hasattr(task, "preprocess_start_time") and hasattr(task, "preprocess_end_time"):
-                info_dict["preprocess_cost_time"] = datetime_diff(task["preprocess_start_time"],
-                                                                  task["preprocess_end_time"])
-            if hasattr(task, "preprocess_end_time") and hasattr(task, "schedule_start_time"):
-                info_dict["cache_waiting_cost_time"] = datetime_diff(task["preprocess_end_time"],
-                                                                     task["schedule_start_time"])
-            info_dict["inference_time_cost"] = task["inference_time_cost"]
-            info_dict["version"] = "4.6"
-            info_dict["timestamp"] = time.time()
-            monitor_logger.info(f"{info_dict}")
-
-        return result
-
-    def _get_speculate_result(self, i, task_id, token_ids, task):
-        """
-        processing single speculate results
-
-        Args:
-            i (int): batch index
-            task_id (str): task id
-            token_ids (int): tokens id
+            token_ids (list): token ids
             task (dict): task information
 
         Returns:
@@ -220,23 +146,23 @@ def _get_speculate_result(self, i, task_id, token_ids, task):
                 if key in task:
                     result[key] = str(task[key])
 
-
-        # fill some extra information when generate eos token
+        # fill some extra information
         result["token_ids"] = []
         for token_id in token_ids:
             if token_id in task["eos_token_ids"]:
                 result["is_end"] = 1
+                result["token_ids"] = []
                 result["tokens_all_num"] = len(self.all_tokens[i]) + 1
                 result["tokens_all_ids"] = self.all_tokens[i]
 
                 info_dict = {}
                 info_dict["req_id"] = task["req_id"]
                 info_dict["input_token_num"] = len(task["input_ids"])
                 info_dict["output_token_num"] = len(self.all_tokens[i])
-                if "preprocess_start_time" in task and "preprocess_end_time" in task:
+                if hasattr(task, "preprocess_start_time") and hasattr(task, "preprocess_end_time"):
                     info_dict["preprocess_cost_time"] = datetime_diff(task["preprocess_start_time"],
                                                                       task["preprocess_end_time"])
-                if "preprocess_end_time" in task and "schedule_start_time" in task:
+                if hasattr(task, "preprocess_end_time") and hasattr(task, "schedule_start_time"):
                     info_dict["cache_waiting_cost_time"] = datetime_diff(task["preprocess_end_time"],
                                                                          task["schedule_start_time"])
                 info_dict["inference_time_cost"] = task["inference_time_cost"]
@@ -266,74 +192,36 @@ def _process_batch_output(self):
         """
         tokens = self.output_tokens.numpy()
         batch = self.output_tokens[1, 0]
-        tokens = tokens[2:batch + 2]
+        if not self.is_speculate_decoding:
+            tokens = tokens[2:batch + 2]
+        else:
+            accept_num = tokens[2:batch + 2]
 
         batch_result = list()
         exist_finished_task = False
         for i in range(batch):
             if self.resource_manager.stop_flags[i]:
                 continue
 
-            token_id = int(tokens[i, 0])
-            if token_id < 0:
+            if not self.is_speculate_decoding:
+                token_ids = [int(tokens[i, 0])]
+            else:
+                token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS + accept_num[i, 0], 0].tolist()
+
+            if any(token_id < 0 for token_id in token_ids):
                 continue
 
             task = self.resource_manager.tasks_list[i]
 
             task_id = task["req_id"]
-            result = self._get_single_result(i, task_id, token_id, task)
-
-            self.tokens_counter[task_id] += 1
-            if token_id not in task["eos_token_ids"]:
-                self.all_tokens[i].append(token_id)
+            result = self._get_single_result(i, task_id, token_ids, task)
 
-            self.number_of_output_tokens += 1
-            if token_id in task["eos_token_ids"]:
-                self._recycle_resources(task_id, i, task)
-                model_server_logger.info("req_id: {0} finished".format(task_id))
-                model_server_logger.info(f"{self.resource_manager.info()}")
-                exist_finished_task = True
-            batch_result.append(result)
-
-        self.postprocess(batch_result, exist_finished_task)
-
-    def _process_speculate_output(self):
-        """
-        batch post-processing function
-        """
-        tokens = self.output_tokens.numpy()
-        batch = self.output_tokens[1]
-        output_token_msg_id = int(self.output_tokens[0])
-        accept_num = tokens[2 : batch + 2]
-        batch_result = list()
-        # flag for whether any finished task exists in this batch of results
-        exist_finished_task = False
-        prefill_mode = False
-        tasks_prefill = []
-
-        for i in range(batch):
-            # skip the corresponding task if it has already finished
-            if self.resource_manager.stop_flags[i]:
-                continue
-
-            token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS + accept_num[i]].tolist()
-            # skip illegal tokens
-            if len(token_ids) == 0 or token_ids[-1] == 0:
-                continue
-
-            task = self.resource_manager.tasks_list[i]
-
-            # will be moved to the data server
-            task_id = task["req_id"]
-            result = self._get_speculate_result(i, task_id, token_ids, task)
-
             for token_id in token_ids:
                 self.tokens_counter[task_id] += 1
                 if token_id not in task["eos_token_ids"]:
                     self.all_tokens[i].append(token_id)
 
                 self.number_of_output_tokens += 1
-                # reset related variables when an eos token is generated
                 if token_id in task["eos_token_ids"]:
                     self._recycle_resources(task_id, i, task)
                     model_server_logger.info("req_id: {0} finished".format(task_id))
@@ -342,7 +230,6 @@ def _process_speculate_output(self):
                     break
             batch_result.append(result)
 
-        # call the post-processing function
         self.postprocess(batch_result, exist_finished_task)
 
 
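The merged _process_batch_output above assumes a fixed flat layout for the speculative output buffer: two header slots (message flag and batch size), then SPECULATE_MAX_BSZ per-request accept counts, then one MAX_DRAFT_TOKENS-wide token block per request. A runnable numpy sketch of that layout, with stand-in values for the two constants:

import numpy as np

SPECULATE_MAX_BSZ = 4    # stand-in value for illustration
MAX_DRAFT_TOKENS = 3     # stand-in value for illustration

def decode_speculate_buffer(tokens: np.ndarray) -> list:
    """tokens has shape [2 + SPECULATE_MAX_BSZ * (1 + MAX_DRAFT_TOKENS), 1]:
    tokens[0, 0] is the message flag (-2 means nothing ready),
    tokens[1, 0] is the live batch size, tokens[2:2+bsz, 0] hold per-request
    accept counts, and request i owns a MAX_DRAFT_TOKENS-wide slice."""
    batch = int(tokens[1, 0])
    accept_num = tokens[2:batch + 2]
    result = []
    for i in range(batch):
        start = 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS
        result.append(tokens[start:start + int(accept_num[i, 0]), 0].tolist())
    return result

buf = np.full([2 + SPECULATE_MAX_BSZ * (1 + MAX_DRAFT_TOKENS), 1], 2, dtype="int64")
buf[1, 0] = 2                        # two live requests
buf[2, 0], buf[3, 0] = 3, 1          # request 0 accepted 3 tokens, request 1 accepted 1
print(decode_speculate_buffer(buf))  # [[2, 2, 2], [2]]
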
@@ -365,29 +252,17 @@ def process_sampling_results(self):
         while self._is_running:
             try:
                 rank_id = 0
-                get_output(self.output_tokens, rank_id, self._is_blocking)
+                if self.is_speculate_decoding:
+                    speculate_get_output(self.output_tokens, rank_id, self._is_blocking)
+                else:
+                    get_output(self.output_tokens, rank_id, self._is_blocking)
 
                 if self.output_tokens[0, 0] == -2:
                     continue
                 self._process_batch_output()
             except Exception as e:
                 model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc())))
 
-    def process_speculate_results(self):
-        """
-        read tokens from paddle inference engine and process
-        """
-        while self._is_running:
-            try:
-                rank_id = 0
-                speculate_get_output(self.output_tokens, rank_id, self._is_blocking)
-
-                if self.output_tokens[0] == -2:
-                    continue
-                self._process_speculate_output()
-            except Exception as e:
-                model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc())))
-
     def stop(self):
         """
         stop warm up thread