
Commit e52155f

v1.0 align accuracy
1 parent 47aacb5 commit e52155f

4 files changed: 62 additions & 48 deletions


llm/server/server/engine/config.py

Lines changed: 1 addition & 3 deletions
@@ -93,9 +93,7 @@ def read_from_env(self):
         self.use_cache_kv_int4 = int(os.getenv("USE_CACHE_KV_INT4", 0))

         # speculate decoding config
-        self.speculate_method = str(env.get("SPECULATE_METHOD", None))
-        self.speculate_max_draft_token_num = int(os.getenv("SPECULATE_MAX_DRAFT_TOKEN_NUM", 5))
-        self.speculate_max_ngram_size = int(os.getenv("SPECULATE_MAX_NGRAM_SIZE", 2))
+        self.speculate_method = str(os.getenv("SPECULATE_METHOD", None))

         # infer config
         self.max_batch_size = int(env.get("BATCH_SIZE", 50))
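
After this change only SPECULATE_METHOD is still read from the environment; the draft-token and ngram limits are taken from the model config that infer.py consumes (model_cfg). A minimal sketch of that lookup, assuming a plain dict-shaped model config; the helper name and default values are illustrative and not part of this commit:

# Hypothetical helper, for illustration only: the speculative-decoding limits
# now come from the model config dict instead of environment variables.
def get_speculate_limits(model_cfg: dict) -> tuple[int, int]:
    max_draft_token_num = int(model_cfg.get("speculate_max_draft_token_num", 5))  # default assumed
    max_ngram_size = int(model_cfg.get("speculate_max_ngram_size", 2))            # default assumed
    return max_draft_token_num, max_ngram_size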

llm/server/server/engine/infer.py

Lines changed: 26 additions & 25 deletions
@@ -69,11 +69,16 @@ def __init__(self, args):
         self.init_inputs()

         # whether use speculate decoding
-        if self.config.speculate_method is not None and self.config.speculate_method == "inference_with_reference":
-            self.proposer = InferenceWithReferenceProposer(
-                self.config.speculate_max_draft_token_num,
-                self.config.speculate_max_ngram_size,
-                self.args.max_batch_size)
+        logger.info(f'speculate_method: {self.config.speculate_method}')
+        if self.config.speculate_method is not None:
+            if self.config.speculate_method == "inference_with_reference":
+                self.proposer = InferenceWithReferenceProposer(
+                    self.model_cfg["speculate_max_draft_token_num"],
+                    self.model_cfg["speculate_max_ngram_size"],
+                    self.args.max_batch_size,
+                    self.args.max_seq_len)
+            else:
+                raise NotImplementedError(f'Not support {self.config.speculate_method}, only support inference_with_reference now.')
         else:
             self.proposer = None

@@ -274,18 +279,17 @@ def init_inputs(self):
         self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
             shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
         # speculate decoding input
+        logger.info(f'Speculative method: {self.config.speculate_method}')
         if self.config.speculate_method is not None:
-            self.share_inputs["input_ids_cpu"] = paddle.full(
-                shape=[self.args.max_batch_size, self.args.max_seq_len], fill_value=1, dtype='int64').cpu()
             self.share_inputs["accept_tokens"] = paddle.full(
-                shape=[self.args.max_batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
+                shape=[self.args.max_batch_size, self.model_cfg["speculate_max_draft_token_num"] + 1], fill_value=0, dtype="int64"
             )
             self.share_inputs["accept_num"] = paddle.full(shape=[self.args.max_batch_size], fill_value=0, dtype="int32")
             self.share_inputs["draft_tokens"] = paddle.full(
-                shape=[self.args.max_batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
+                shape=[self.args.max_batch_size, self.model_cfg["speculate_max_draft_token_num"] + 1], fill_value=0, dtype="int64"
             )
             self.share_inputs["actual_draft_token_num"] = paddle.full(
-                shape=[self.args.max_batch_size], fill_value=self.config.speculate_max_draft_token_num, dtype="int32"
+                shape=[self.args.max_batch_size], fill_value=self.model_cfg["speculate_max_draft_token_num"], dtype="int32"
             )

     def dy_input_preprocess(self, tasks):
@@ -344,10 +348,8 @@ def dy_input_preprocess(self, tasks):
                     task["stop_seqs"], dtype="int64")
             if self.proposer is not None:
                 if self.config.speculate_method == "inference_with_reference":
-                    speculate_update_input_ids_cpu(self.share_inputs['input_ids_cpu'], task['input_ids'], idx, self.args.max_seq_len)
-                    self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.config.speculate_max_draft_token_num + 1])
-                    self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.config.speculate_max_draft_token_num])
-                    self.proposer.update(idx, length)
+                    self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.model_cfg["speculate_max_draft_token_num"] + 1])
+                    self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.model_cfg["speculate_max_draft_token_num"]])

     def step_cuda(self, seq_lens_this_time):
         """
@@ -381,7 +383,7 @@ def step_cuda(self, seq_lens_this_time):
             self.share_inputs['input_ids'], self.share_inputs['pre_ids'],
             self.share_inputs['step_idx'], self.share_inputs['next_tokens'],
             self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id,
-            self.config.speculate_max_draft_token_num)
+            self.model_cfg["speculate_max_draft_token_num"])

     def initialize_engine_ready_check_flag(self):
         """
@@ -512,7 +514,6 @@ def run(self):
                 if self.proposer is not None:
                     logger.info("start run proposer")
                     logger.info(f'before draft_tokens: {self.share_inputs["draft_tokens"]}')
-                    logger.info(f'before accept_tokens: {self.share_inputs["accept_tokens"]}')

                     self.proposer.run(
                         self.share_inputs,
@@ -521,19 +522,19 @@ def run(self):
                     )
                     logger.info(f'after draft_tokens: {self.share_inputs["draft_tokens"]}')
                     logger.info("finish run proposer")
-                logger.info(f'input_ids: {self.share_inputs["input_ids"]}')
-                logger.info(f'input_ids_cpu: {self.share_inputs["input_ids_cpu"]}')
-                logger.info(f'seq_lens_this_time: {self.share_inputs["seq_lens_this_time"]}')
-                logger.info(f'seq_lens_encoder: {self.share_inputs["seq_lens_encoder"]}')
-                logger.info(f'seq_lens_decoder: {self.share_inputs["seq_lens_decoder"]}')
-                logger.info(f'step_idx: {self.share_inputs["step_idx"]}')
-                logger.info(f'next_tokens: {self.share_inputs["next_tokens"]}')
-                logger.info(f'before block_tables: {self.share_inputs["block_tables"]}')
+                # logger.info(f'input_ids: {self.share_inputs["input_ids"]}')
+                # logger.info(f'input_ids_cpu: {self.share_inputs["input_ids_cpu"]}')
+                # logger.info(f'seq_lens_this_time: {self.share_inputs["seq_lens_this_time"]}')
+                # logger.info(f'seq_lens_encoder: {self.share_inputs["seq_lens_encoder"]}')
+                # logger.info(f'seq_lens_decoder: {self.share_inputs["seq_lens_decoder"]}')
+                # logger.info(f'step_idx: {self.share_inputs["step_idx"]}')
+                # logger.info(f'next_tokens: {self.share_inputs["next_tokens"]}')
+                # logger.info(f'before block_tables: {self.share_inputs["block_tables"]}')

                 self.infer_engine.predictor.run()
                 logger.info(f'after accept_tokens: {self.share_inputs["accept_tokens"]}')
                 logger.info(f'after accept_num: {self.share_inputs["accept_num"]}')
-                logger.info(f'after block_tables: {self.share_inputs["block_tables"]}')
+                # logger.info(f'after block_tables: {self.share_inputs["block_tables"]}')

                 self.share_inputs['infer_seed'].add_(infer_seed_increment)
                 self.share_inputs['infer_seed'][:] %= self.MAX_INFER_SEED
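
The speculative buffers are sized [max_batch_size, speculate_max_draft_token_num + 1] because a step can accept up to all drafted tokens plus the one token the target model itself produces; accept_num records how many were actually accepted per request. A hedged sketch of how a consumer could read the accepted tokens back out (the helper below is illustrative, not code from this commit):

import numpy as np

def read_accepted_tokens(share_inputs: dict, batch_size: int) -> list[np.ndarray]:
    # Illustrative only: slice each request's accepted tokens out of the
    # [max_batch_size, speculate_max_draft_token_num + 1] accept_tokens buffer.
    accept_tokens = share_inputs["accept_tokens"].numpy()
    accept_num = share_inputs["accept_num"].numpy()
    return [accept_tokens[i, : int(accept_num[i])] for i in range(batch_size)]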

llm/server/server/engine/proposers.py

Lines changed: 18 additions & 19 deletions
@@ -16,7 +16,6 @@
 from abc import ABC, abstractmethod

 import paddle
-from paddlenlp_ops import ngram_match


 class Proposer(ABC):
@@ -43,7 +42,7 @@ class InferenceWithReferenceProposer(Proposer):
     It match tokens in the input and output as draft tokens.
     """

-    def __init__(self, max_draft_token_num: int, max_ngram_size: int, max_batch_size: int):
+    def __init__(self, max_draft_token_num: int, max_ngram_size: int, max_batch_size: int, max_seq_len: int, **kwargs):
         """
         Args:
             max_draft_token_num (int):
@@ -54,34 +53,33 @@ def __init__(self, max_draft_token_num: int, max_ngram_size: int, max_batch_size
                 The hyperparameter of n in the paper.
             max_batch_size (int):
                 The maximum batch size.
+            max_seq_len (int):
+                The maximum sequence length.
         """
         super().__init__()
         self.max_ngram_size = max_ngram_size
         self.input_ids_len = paddle.zeros(shape=[max_batch_size, 1], dtype="int64").cpu()
+        self.input_ids_cpu = paddle.zeros(shape=[max_batch_size, max_seq_len], dtype="int64").cpu()
         self.max_batch_size = max_batch_size
         self.max_draft_token_num = max_draft_token_num
-        # self.input_ids_cpu = paddle.full(shape=[max_batch_size, max_seq_len], fill_value=1, dtype="int64").cpu()

-    def update(self, bid: int, seq_len: int):
-        """
-        Used when inserting a new query to update the length of the input_ids.
-        """
-        self.input_ids_len[bid] = seq_len
-
-    def run(self, share_inputs: dict[str, paddle.Tensor], **kargs):
+    def run(self, model_inputs: dict[str, paddle.Tensor], **kargs):
         """
         Use ngram_match to get draft tokens from the input and output.
         """
-        draft_tokens = share_inputs["draft_tokens"].cpu()
+        draft_tokens = model_inputs["draft_tokens"].cpu()
         seq_lens_this_time = kargs["seq_lens_this_time"].cpu()
-        seq_lens_encoder = share_inputs["seq_lens_encoder"].cpu()
-        seq_lens_decoder = share_inputs["seq_lens_decoder"].cpu()
+        seq_lens_encoder = model_inputs["seq_lens_encoder"].cpu()
+        seq_lens_decoder = model_inputs["seq_lens_decoder"].cpu()
+
+        from paddlenlp_ops import ngram_match
+
         ngram_match(
-            share_inputs["input_ids_cpu"],
+            self.input_ids_cpu,
             self.input_ids_len.cpu(),
-            share_inputs["pre_ids"].cpu(),
-            share_inputs["step_idx"].cpu(),
-            share_inputs["actual_draft_token_num"].cpu(),
+            model_inputs["pre_ids"].cpu(),
+            model_inputs["step_idx"].cpu(),
+            model_inputs["actual_draft_token_num"].cpu(),
             draft_tokens,
             seq_lens_this_time,
             seq_lens_encoder,
@@ -90,6 +88,7 @@ def run(self, share_inputs: dict[str, paddle.Tensor], **kargs):
             self.max_ngram_size,
             self.max_draft_token_num,
         )
-        share_inputs["draft_tokens"][:] = draft_tokens.cuda()
-        share_inputs["seq_lens_encoder"][:] = seq_lens_encoder.cuda()
+
+        model_inputs["draft_tokens"][:] = draft_tokens.cuda()
+        model_inputs["seq_lens_encoder"][:] = seq_lens_encoder.cuda()
         kargs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
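
With this change the proposer owns its input_ids_cpu buffer, and ngram_match is imported lazily inside run(), so merely importing proposers.py no longer requires paddlenlp_ops. For readers unfamiliar with the technique, a simplified pure-Python illustration of the ngram-match idea follows; this is not the paddlenlp_ops kernel, which operates on batched CPU tensors:

def ngram_match_sketch(context, max_ngram_size, max_draft_token_num):
    # Try the longest suffix of the context first; if it occurred earlier,
    # propose the tokens that followed that occurrence as draft tokens.
    for n in range(max_ngram_size, 0, -1):
        suffix = context[-n:]
        for start in range(len(context) - n - 1, -1, -1):  # most recent match first
            if context[start:start + n] == suffix:
                follow = context[start + n : start + n + max_draft_token_num]
                if follow:
                    return follow
    return []

# Example: the suffix [3, 4] appeared earlier, so the tokens after it become drafts.
print(ngram_match_sketch([1, 2, 3, 4, 5, 6, 2, 3, 4], max_ngram_size=2, max_draft_token_num=3))  # [5, 6, 2]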

llm/server/server/engine/token_processor.py

Lines changed: 17 additions & 1 deletion
@@ -41,7 +41,7 @@ def __init__(self, cfg):
         self.tokens_counter = Counter()

         if self.cfg.speculate_method is not None:
-            self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKEN_NUM + MAX_DRAFT_TOKEN_NUM + 2], fill_value=2, dtype="int64")
+            self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKEN_NUM + SPECULATE_MAX_BSZ + 2], fill_value=2, dtype="int64")
         else:
             self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")
         self.worker = None
@@ -302,6 +302,7 @@ def _process_speculate_output(self):
         batch post-processing function
         """
         tokens = self.output_tokens.numpy()
+        model_server_logger.info(f"speculate_result tokens: {self.output_tokens.tolist()}")
         batch = self.output_tokens[1]
         output_token_msg_id = int(self.output_tokens[0])
         accept_num = tokens[2 : batch + 2]
@@ -373,6 +374,21 @@ def process_sampling_results(self):
         except Exception as e:
             model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc())))

+    def process_speculate_results(self):
+        """
+        read tokens from paddle inference engine and process
+        """
+        while self._is_running:
+            try:
+                rank_id = 0
+                speculate_get_output(self.output_tokens, rank_id, self._is_blocking)
+
+                if self.output_tokens[0] == -2:
+                    continue
+                self._process_speculate_output()
+            except Exception as e:
+                model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc())))
+
     def stop(self):
         """
         stop warm up thread
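
The corrected size matches the layout _process_speculate_output reads from output_tokens: two header slots (message id and batch size), one accept-count slot per possible request, then up to MAX_DRAFT_TOKEN_NUM token slots per request, i.e. SPECULATE_MAX_BSZ * MAX_DRAFT_TOKEN_NUM + SPECULATE_MAX_BSZ + 2 entries in total; the old formula under-allocated whenever SPECULATE_MAX_BSZ exceeded MAX_DRAFT_TOKEN_NUM. A hedged parsing sketch (the per-request token offsets are inferred from that size formula and are not shown in this diff):

import numpy as np

SPECULATE_MAX_BSZ = 256      # illustrative values; the real constants live in the server code
MAX_DRAFT_TOKEN_NUM = 6

def parse_speculate_output(tokens: np.ndarray):
    # Assumed layout: [msg_id, batch, accept_num[0..SPECULATE_MAX_BSZ), per-request token blocks]
    msg_id = int(tokens[0])
    batch = int(tokens[1])
    accept_num = tokens[2 : 2 + batch]
    body = 2 + SPECULATE_MAX_BSZ
    accepted = [
        tokens[body + i * MAX_DRAFT_TOKEN_NUM : body + i * MAX_DRAFT_TOKEN_NUM + int(accept_num[i])]
        for i in range(batch)
    ]
    return msg_id, accepted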
