
Commit 47aacb5

add speculate_decoding framework
1 parent 97e541e commit 47aacb5

4 files changed

Lines changed: 331 additions & 18 deletions


llm/server/server/engine/config.py

Lines changed: 5 additions & 0 deletions
@@ -91,6 +91,11 @@ def read_from_env(self):
         self.block_size = int(env.get("BLOCK_SIZE", 64))
         self.use_cache_kv_int8 = int(os.getenv("USE_CACHE_KV_INT8", 0))
         self.use_cache_kv_int4 = int(os.getenv("USE_CACHE_KV_INT4", 0))
+
+        # speculate decoding config (an unset value stays None, so the `is not None` checks in infer.py work)
+        self.speculate_method = env.get("SPECULATE_METHOD", None)
+        self.speculate_max_draft_token_num = int(os.getenv("SPECULATE_MAX_DRAFT_TOKEN_NUM", 5))
+        self.speculate_max_ngram_size = int(os.getenv("SPECULATE_MAX_NGRAM_SIZE", 2))

         # infer config
         self.max_batch_size = int(env.get("BATCH_SIZE", 50))
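For reference, a minimal sketch of how these knobs could be set before the server reads its config. The environment-variable names come from the diff above; the launcher snippet itself is a hypothetical example, not part of this commit:

import os

# Hypothetical launcher snippet: export the speculative-decoding knobs before
# the server process constructs Config and calls read_from_env().
os.environ["SPECULATE_METHOD"] = "inference_with_reference"  # enable the ngram proposer
os.environ["SPECULATE_MAX_DRAFT_TOKEN_NUM"] = "5"            # k: max draft tokens per step
os.environ["SPECULATE_MAX_NGRAM_SIZE"] = "2"                 # n: max ngram match window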

llm/server/server/engine/infer.py

Lines changed: 91 additions & 15 deletions
@@ -26,9 +26,10 @@
 import paddle.distributed as dist
 import paddle.distributed.fleet as fleet
 from paddlenlp.trl.llm_utils import get_rotary_position_embedding
-from paddlenlp_ops import step_paddle
+from paddlenlp_ops import step_paddle, speculate_step_paddle
 from server.data.processor import DataProcessor
 from server.engine.config import Config
+from server.engine.proposers import InferenceWithReferenceProposer
 from server.utils import get_logger
 from task_queue_manager import TaskQueueManager

@@ -67,6 +68,15 @@ def __init__(self, args):
         self.cache_kvs = {}
         self.init_inputs()

+        # whether to use speculative decoding
+        if self.config.speculate_method == "inference_with_reference":
+            self.proposer = InferenceWithReferenceProposer(
+                self.config.speculate_max_draft_token_num,
+                self.config.speculate_max_ngram_size,
+                self.args.max_batch_size)
+        else:
+            self.proposer = None
+
         self.infer_queue = TaskQueueManager(rank=self.rank, mp_num=self.nranks, port=self.config.infer_port)

         model_rank_path = os.path.join(self.args.model_dir, f"rank_{self.rank}")

@@ -263,6 +273,20 @@ def init_inputs(self):
             shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
         self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
             shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
+        # speculate decoding inputs
+        if self.config.speculate_method is not None:
+            self.share_inputs["input_ids_cpu"] = paddle.full(
+                shape=[self.args.max_batch_size, self.args.max_seq_len], fill_value=1, dtype='int64').cpu()
+            self.share_inputs["accept_tokens"] = paddle.full(
+                shape=[self.args.max_batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
+            )
+            self.share_inputs["accept_num"] = paddle.full(shape=[self.args.max_batch_size], fill_value=0, dtype="int32")
+            self.share_inputs["draft_tokens"] = paddle.full(
+                shape=[self.args.max_batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
+            )
+            self.share_inputs["actual_draft_token_num"] = paddle.full(
+                shape=[self.args.max_batch_size], fill_value=self.config.speculate_max_draft_token_num, dtype="int32"
+            )

     def dy_input_preprocess(self, tasks):
         """
@@ -318,23 +342,46 @@ def dy_input_preprocess(self, tasks):
                     task["stop_seqs_len"], dtype="int32")
                 self.share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array(
                     task["stop_seqs"], dtype="int64")
+            if self.proposer is not None:
+                if self.config.speculate_method == "inference_with_reference":
+                    # copy this query's input_ids into the CPU-side buffer read by the proposer
+                    speculate_update_input_ids_cpu(self.share_inputs['input_ids_cpu'], task['input_ids'], idx, self.args.max_seq_len)
+                    self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.config.speculate_max_draft_token_num + 1])
+                    self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.config.speculate_max_draft_token_num])
+                    self.proposer.update(idx, length)

     def step_cuda(self, seq_lens_this_time):
         """
         step cuda
         """
-        step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
-                    self.share_inputs['step_seq_lens_encoder'],
-                    self.share_inputs['seq_lens_encoder'],
-                    self.share_inputs['seq_lens_decoder'], self.share_inputs["block_tables"],
-                    self.share_inputs['encoder_block_lens'],
-                    self.share_inputs["is_block_step"], self.share_inputs['step_block_list'],
-                    self.share_inputs['step_lens'], self.share_inputs['recover_block_list'],
-                    self.share_inputs['recover_lens'], self.share_inputs['need_block_list'],
-                    self.share_inputs['need_block_len'], self.share_inputs['used_list_len'],
-                    self.share_inputs['free_list'], self.share_inputs['free_list_len'],
-                    self.share_inputs['input_ids'], self.share_inputs['pre_ids'],
-                    self.share_inputs['step_idx'], self.share_inputs['next_tokens'],
-                    self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id)
+        if self.config.speculate_method is None:
+            step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
+                        self.share_inputs['step_seq_lens_encoder'],
+                        self.share_inputs['seq_lens_encoder'],
+                        self.share_inputs['seq_lens_decoder'], self.share_inputs["block_tables"],
+                        self.share_inputs['encoder_block_lens'],
+                        self.share_inputs["is_block_step"], self.share_inputs['step_block_list'],
+                        self.share_inputs['step_lens'], self.share_inputs['recover_block_list'],
+                        self.share_inputs['recover_lens'], self.share_inputs['need_block_list'],
+                        self.share_inputs['need_block_len'], self.share_inputs['used_list_len'],
+                        self.share_inputs['free_list'], self.share_inputs['free_list_len'],
+                        self.share_inputs['input_ids'], self.share_inputs['pre_ids'],
+                        self.share_inputs['step_idx'], self.share_inputs['next_tokens'],
+                        self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id)
+        else:
+            speculate_step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
+                                  self.share_inputs['step_seq_lens_encoder'],
+                                  self.share_inputs['seq_lens_encoder'],
+                                  self.share_inputs['seq_lens_decoder'], self.share_inputs["block_tables"],
+                                  self.share_inputs['encoder_block_lens'],
+                                  self.share_inputs["is_block_step"], self.share_inputs['step_block_list'],
+                                  self.share_inputs['step_lens'], self.share_inputs['recover_block_list'],
+                                  self.share_inputs['recover_lens'], self.share_inputs['need_block_list'],
+                                  self.share_inputs['need_block_len'], self.share_inputs['used_list_len'],
+                                  self.share_inputs['free_list'], self.share_inputs['free_list_len'],
+                                  self.share_inputs['input_ids'], self.share_inputs['pre_ids'],
+                                  self.share_inputs['step_idx'], self.share_inputs['next_tokens'],
+                                  self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id,
+                                  self.config.speculate_max_draft_token_num)

     def initialize_engine_ready_check_flag(self):
         """
@@ -434,6 +481,9 @@ def run(self):
             self.share_inputs["seq_lens_this_time"][:real_bsz] = seq_lens_this_time

             tasks, read_finish = self.infer_queue.get()
+            logger.info(f'tasks: {tasks}')
+            logger.info(f'read_finish: {read_finish}')
+
             if read_finish:
                 flag_broadcast_array[0] = 0

@@ -442,7 +492,7 @@
                     real_bsz = int(bsz)
                     req_dicts.extend(req_dict)
                 logger.info(
-                    f'rank: {self.rank}, real_bsz: {real_bsz}, query_num: {len(req_dicts)}'
+                    f'req_dict: {req_dict} rank: {self.rank}, real_bsz: {real_bsz}, query_num: {len(req_dicts)}'
                 )

                 self.dy_input_preprocess(req_dicts)

@@ -459,10 +509,36 @@
                 time.sleep(0.001)
                 continue

+            if self.proposer is not None:
+                logger.info("start run proposer")
+                logger.info(f'before draft_tokens: {self.share_inputs["draft_tokens"]}')
+                logger.info(f'before accept_tokens: {self.share_inputs["accept_tokens"]}')
+
+                self.proposer.run(
+                    self.share_inputs,
+                    real_batch_size=self.args.max_batch_size,
+                    seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
+                )
+                logger.info(f'after draft_tokens: {self.share_inputs["draft_tokens"]}')
+                logger.info("finish run proposer")
+                logger.info(f'input_ids: {self.share_inputs["input_ids"]}')
+                logger.info(f'input_ids_cpu: {self.share_inputs["input_ids_cpu"]}')
+                logger.info(f'seq_lens_this_time: {self.share_inputs["seq_lens_this_time"]}')
+                logger.info(f'seq_lens_encoder: {self.share_inputs["seq_lens_encoder"]}')
+                logger.info(f'seq_lens_decoder: {self.share_inputs["seq_lens_decoder"]}')
+                logger.info(f'step_idx: {self.share_inputs["step_idx"]}')
+                logger.info(f'next_tokens: {self.share_inputs["next_tokens"]}')
+                logger.info(f'before block_tables: {self.share_inputs["block_tables"]}')
+
             self.infer_engine.predictor.run()
+            logger.info(f'after accept_tokens: {self.share_inputs["accept_tokens"]}')
+            logger.info(f'after accept_num: {self.share_inputs["accept_num"]}')
+            logger.info(f'after block_tables: {self.share_inputs["block_tables"]}')
+
             self.share_inputs['infer_seed'].add_(infer_seed_increment)
             self.share_inputs['infer_seed'][:] %= self.MAX_INFER_SEED
             if self.free_list_len > 0:
+                logger.info('free_list_len > 0')
                 self.step_cuda(seq_lens_this_time)
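Taken together, run() implements the usual draft-then-verify loop: the proposer fills draft_tokens, a single predictor.run() pass lets the target model score them, and the verified prefix comes back through accept_tokens and accept_num. As a toy illustration of one standard acceptance rule (exact-match verification; the server's real check happens inside the fused ops and is not shown in this commit), here is a framework-free sketch:

from typing import List

def greedy_verify(drafts: List[int], target_preds: List[int]) -> List[int]:
    """Toy acceptance rule: target_preds[i] is the target model's greedy
    prediction after the prefix plus drafts[:i], so len(target_preds) must be
    len(drafts) + 1. Accept drafts while they agree with the target, then
    emit the target's own token, so each step yields at least one token."""
    accepted: List[int] = []
    for draft, target in zip(drafts, target_preds):
        if draft == target:
            accepted.append(draft)   # draft confirmed by the target model
        else:
            accepted.append(target)  # first mismatch: keep the target token, stop
            break
    else:
        accepted.append(target_preds[len(drafts)])  # all drafts accepted: bonus token
    return accepted

# Example: drafts [5, 9, 2] against target predictions [5, 9, 7, 4]
# accepts [5, 9], corrects with 7, and would set accept_num to 3.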

llm/server/server/engine/proposers.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+
+import paddle
+from paddlenlp_ops import ngram_match
+
+
+class Proposer(ABC):
+    """
+    Abstract base class for all proposers that can be used in the speculative
+    decoding framework. Subclasses must implement the run method to produce
+    the draft tokens generated by the proposer.
+    """
+
+    def __init__(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def run(self, model_inputs: dict[str, paddle.Tensor], **kargs):
+        """
+        Get the draft tokens generated by the proposer.
+        """
+        raise NotImplementedError()
+
+
+class InferenceWithReferenceProposer(Proposer):
+    """
+    InferenceWithReference (https://arxiv.org/pdf/2304.04487) is one of the
+    speculative decoding methods. It matches n-grams between the input and
+    the generated output and uses the matched continuations as draft tokens.
+    """
+
+    def __init__(self, max_draft_token_num: int, max_ngram_size: int, max_batch_size: int):
+        """
+        Args:
+            max_draft_token_num (int):
+                Maximum number of tokens the proposer can generate at one time.
+                The hyperparameter k in the paper.
+            max_ngram_size (int):
+                Maximum size of the window used to match inputs and outputs.
+                The hyperparameter n in the paper.
+            max_batch_size (int):
+                The maximum batch size.
+        """
+        super().__init__()
+        self.max_ngram_size = max_ngram_size
+        self.input_ids_len = paddle.zeros(shape=[max_batch_size, 1], dtype="int64").cpu()
+        self.max_batch_size = max_batch_size
+        self.max_draft_token_num = max_draft_token_num
+        # self.input_ids_cpu = paddle.full(shape=[max_batch_size, max_seq_len], fill_value=1, dtype="int64").cpu()
+
+    def update(self, bid: int, seq_len: int):
+        """
+        Update the recorded input_ids length when a new query is inserted.
+        """
+        self.input_ids_len[bid] = seq_len
+
+    def run(self, share_inputs: dict[str, paddle.Tensor], **kargs):
+        """
+        Use ngram_match to get draft tokens from the input and output.
+        """
+        draft_tokens = share_inputs["draft_tokens"].cpu()
+        seq_lens_this_time = kargs["seq_lens_this_time"].cpu()
+        seq_lens_encoder = share_inputs["seq_lens_encoder"].cpu()
+        seq_lens_decoder = share_inputs["seq_lens_decoder"].cpu()
+        ngram_match(
+            share_inputs["input_ids_cpu"],
+            self.input_ids_len.cpu(),
+            share_inputs["pre_ids"].cpu(),
+            share_inputs["step_idx"].cpu(),
+            share_inputs["actual_draft_token_num"].cpu(),
+            draft_tokens,
+            seq_lens_this_time,
+            seq_lens_encoder,
+            seq_lens_decoder,
+            kargs["real_batch_size"],
+            self.max_ngram_size,
+            self.max_draft_token_num,
+        )
+        # copy the CPU results back into the GPU-side shared buffers
+        share_inputs["draft_tokens"][:] = draft_tokens.cuda()
+        share_inputs["seq_lens_encoder"][:] = seq_lens_encoder.cuda()
+        kargs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
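For intuition about what the fused ngram_match op does with input_ids_cpu, pre_ids, and step_idx, here is a toy pure-Python sketch of the matching described in the paper: find the most recent n-gram of generated text inside the reference (here, the prompt) and copy the tokens that follow it as drafts. toy_ngram_match and its flat-list signature are illustrative assumptions, not the op's real interface; the real op also matches against the previously generated tokens (pre_ids) and operates batch-wide:

from typing import List

def toy_ngram_match(prompt: List[int], generated: List[int],
                    max_ngram_size: int, max_draft_token_num: int) -> List[int]:
    """Match the longest recent n-gram of `generated` inside `prompt` and
    return the continuation after the match as draft tokens."""
    for n in range(max_ngram_size, 0, -1):  # prefer longer matches
        if len(generated) < n:
            continue
        suffix = generated[-n:]
        # scan right-to-left so the most recent occurrence wins
        for start in range(len(prompt) - n, -1, -1):
            if prompt[start:start + n] == suffix:
                drafts = prompt[start + n:start + n + max_draft_token_num]
                if drafts:
                    return drafts
    return []  # no match: fall back to ordinary one-token decoding

# Example: prompt [1, 2, 3, 4, 5] with generated output ending in [2, 3]
# matches the bigram (2, 3) and proposes [4, 5] as draft tokens.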
