Skip to content

Commit 934e8d8

Browse files
update
1 parent ce3c09d commit 934e8d8

2 files changed

Lines changed: 59 additions & 42 deletions

File tree

llm/server/server/engine/config.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from dataclasses import dataclass
from typing import Optional

from paddlenlp.generation import GenerationConfig

from server.utils import model_server_logger
2323

2424
class Config:
2525
"""
@@ -203,6 +203,26 @@ def get_model_config(self):
203203
model_config_json = json.load(open(self.model_config_path, 'r', encoding='utf-8'))
204204
return model_config_json
205205

206+
def get_speculate_config(self):
    """
    Get the speculative-decoding related config.

    Reads the optional speculate_* keys from the loaded model config
    (`self.model_cfg`); when "speculate_method" is absent, the defaults
    declared on SpeculateConfig are returned unchanged.

    Returns:
        SpeculateConfig: the speculate related config
    """
    speculate_config = SpeculateConfig()
    if self.model_cfg.get("speculate_method") is not None:
        speculate_config.speculate_method = self.model_cfg["speculate_method"]
        speculate_config.speculate_max_draft_token_num = self.model_cfg[
            "speculate_max_draft_token_num"]
        speculate_config.speculate_max_ngram_size = self.model_cfg[
            "speculate_max_ngram_size"]

    # fix: original used `is not in`, which is a SyntaxError; `not in` is the
    # intended membership test.
    # NOTE(review): when speculate_method is None (key absent) this also logs
    # the error; callers elsewhere treat None as "speculation disabled" —
    # confirm whether the log is intended in that case.
    if speculate_config.speculate_method not in ["none", "inference_with_reference"]:
        model_server_logger.error(
            f"Unsupport speculate method: {speculate_config.speculate_method}")

    return speculate_config
206226
def read_from_config(self):
207227
"""
208228
reset model config from json file
@@ -234,3 +254,10 @@ def get_unique_name(self, name):
234254

235255
def __str__(self) -> str:
236256
return json.dumps(self.__dict__, indent=4)
257+
258+
259+
@dataclass
class SpeculateConfig:
    """Configuration for speculative decoding.

    Field names mirror the optional ``speculate_*`` keys of the model
    config JSON; the defaults apply when speculative decoding is disabled.
    """

    # decoding strategy name; None means speculative decoding is disabled
    # (fix: annotation was `str = None`, which contradicts the None default)
    speculate_method: Optional[str] = None
    # maximum number of draft tokens proposed per step
    speculate_max_draft_token_num: int = 1
    # maximum n-gram size used when drafting tokens
    speculate_max_ngram_size: int = 1

llm/server/server/engine/infer.py

Lines changed: 31 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import paddle.distributed as dist
2727
import paddle.distributed.fleet as fleet
2828
from paddlenlp.trl.llm_utils import get_rotary_position_embedding
29-
from paddlenlp_ops import step_paddle, speculate_step_paddle
29+
from paddlenlp_ops import step_paddle
3030
from server.data.processor import DataProcessor
3131
from server.engine.config import Config
3232
from paddlenlp.experimental.transformers import InferenceWithReferenceProposer
@@ -47,7 +47,8 @@ def __init__(self, args):
4747

4848
self.config = Config()
4949
self.model_cfg = self.config.get_model_config()
50-
self.is_speculate_decoding = self.model_cfg.get("speculate_method") is not None
50+
self.speculate_config = self.config.get_speculate_config()
51+
self.is_speculate_decoding = self.speculate_config.speculate_method is not None
5152
self.format_print_configuration()
5253

5354
self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"])
@@ -70,15 +71,13 @@ def __init__(self, args):
7071
self.init_inputs()
7172

7273
if self.is_speculate_decoding:
73-
logger.info(f'Using speculating decoding, method: {self.model_cfg["speculate_method"]}.')
74-
if self.model_cfg["speculate_method"] == "inference_with_reference":
74+
logger.info(f'Using speculating decoding, method: {self.speculate_config.speculate_method}.')
75+
if self.speculate_config.speculate_method == "inference_with_reference":
7576
self.proposer = InferenceWithReferenceProposer(
76-
self.model_cfg["speculate_max_draft_token_num"],
77-
self.model_cfg["speculate_max_ngram_size"],
77+
self.speculate_config.speculate_max_draft_token_num,
78+
self.speculate_config.speculate_max_ngram_size,
7879
self.args.max_batch_size,
7980
self.args.max_seq_len)
80-
else:
81-
raise NotImplementedError(f'Not support {self.model_cfg["speculate_method"]}, only support inference_with_reference now.')
8281
else:
8382
self.proposer = None
8483

@@ -281,14 +280,14 @@ def init_inputs(self):
281280
# speculate decoding input
282281
if self.is_speculate_decoding:
283282
self.share_inputs["accept_tokens"] = paddle.full(
284-
shape=[self.args.max_batch_size, self.model_cfg["speculate_max_draft_token_num"] + 1], fill_value=0, dtype="int64"
283+
shape=[self.args.max_batch_size, self.speculate_config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
285284
)
286285
self.share_inputs["accept_num"] = paddle.full(shape=[self.args.max_batch_size], fill_value=0, dtype="int32")
287286
self.share_inputs["draft_tokens"] = paddle.full(
288-
shape=[self.args.max_batch_size, self.model_cfg["speculate_max_draft_token_num"] + 1], fill_value=0, dtype="int64"
287+
shape=[self.args.max_batch_size, self.speculate_config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
289288
)
290289
self.share_inputs["actual_draft_token_num"] = paddle.full(
291-
shape=[self.args.max_batch_size], fill_value=self.model_cfg["speculate_max_draft_token_num"], dtype="int32"
290+
shape=[self.args.max_batch_size], fill_value=self.speculate_config.speculate_max_draft_token_num, dtype="int32"
292291
)
293292

294293
def dy_input_preprocess(self, tasks):
@@ -347,42 +346,33 @@ def dy_input_preprocess(self, tasks):
347346
task["stop_seqs"], dtype="int64")
348347

349348
if self.is_speculate_decoding:
350-
self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.model_cfg["speculate_max_draft_token_num"] + 1])
351-
self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.model_cfg["speculate_max_draft_token_num"]])
349+
self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.speculate_config.speculate_max_draft_token_num + 1])
350+
self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.speculate_config.speculate_max_draft_token_num])
352351

353352
def step_cuda(self, seq_lens_this_time):
    """
    Run one scheduling step on GPU via the ``step_paddle`` custom op.

    Args:
        seq_lens_this_time: per-request sequence lengths for this step,
            passed straight through to ``step_paddle``.

    Returns:
        None. All effects happen in-place on the shared input tensors.
    """
    # Tokens a sequence may advance per step: max draft tokens + 1 verified
    # token when speculative decoding is enabled, 0 otherwise. The unified
    # step_paddle op takes this as its last argument.
    if self.is_speculate_decoding:
        speculate_step_token_num = self.speculate_config.speculate_max_draft_token_num + 1
    else:
        speculate_step_token_num = 0

    # fix: original was missing the comma after `self.args.first_token_id`,
    # which made this call a SyntaxError.
    step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
                self.share_inputs['step_seq_lens_encoder'],
                self.share_inputs['seq_lens_encoder'],
                self.share_inputs['seq_lens_decoder'], self.share_inputs["block_tables"],
                self.share_inputs['encoder_block_lens'],
                self.share_inputs["is_block_step"], self.share_inputs['step_block_list'],
                self.share_inputs['step_lens'], self.share_inputs['recover_block_list'],
                self.share_inputs['recover_lens'], self.share_inputs['need_block_list'],
                self.share_inputs['need_block_len'], self.share_inputs['used_list_len'],
                self.share_inputs['free_list'], self.share_inputs['free_list_len'],
                self.share_inputs['input_ids'], self.share_inputs['pre_ids'],
                self.share_inputs['step_idx'], self.share_inputs['next_tokens'],
                self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id,
                speculate_step_token_num)
386376

387377
def initialize_engine_ready_check_flag(self):
388378
"""

0 commit comments

Comments
 (0)