
Commit 97e541e

Merge pull request #2584 from ming1753/internet
support return_all_tokens & stop_seqs
2 parents 608d4be + c249b98

4 files changed

Lines changed: 104 additions & 1 deletion

llm/server/server/data/processor.py

Lines changed: 44 additions & 1 deletion
@@ -143,6 +143,9 @@ def process_request(self, request, max_seq_len=None):
         request["eos_token_ids"] = []
         request["eos_token_ids"].extend(get_eos_token_id(self.tokenizer, self.config.generation_config))
 
+        if "stop_seqs" not in request or (isinstance(request["stop_seqs"], (list, tuple)) and len(request["stop_seqs"]) == 0):
+            self.update_stop_seq(request)
+
         if "input_ids" not in request or \
             (isinstance(request["input_ids"], (list, tuple)) and len(request["input_ids"]) == 0):
             if "text" in request:

@@ -282,7 +285,7 @@ def _load_tokenizer(self):
         """
         if self.config.use_hf_tokenizer:
             from transformers import AutoTokenizer
-            return AutoTokenizer.from_pretrained(self.config.model_dir, use_fast=False, vocab_file=os.path.join(self.config.model_dir, "sentencepiece.bpe.model"))
+            return AutoTokenizer.from_pretrained(self.config.model_dir, use_fast=False)
         else:
             from paddlenlp.transformers import AutoTokenizer
             return AutoTokenizer.from_pretrained(self.config.model_dir)

@@ -334,3 +337,43 @@ def get_pad_id(self):
         if isinstance(self.tokenizer, (LlamaTokenizer, Llama3Tokenizer)) and not self.tokenizer.pad_token_id:
             return self.tokenizer.eos_token
         return self.tokenizer.pad_token_id
+
+    def pad_batch_data(self, insts, pad_id=0, return_seq_len=False, return_array=True, pad_style="right"):
+        """Pad the instances to the max sequence length in batch."""
+        if len(insts) == 0:
+            padded_insts = np.array([[]], dtype=np.int64) if return_array else [[]]
+            if return_seq_len:
+                seq_len = np.array([], dtype=np.int64) if return_array else []
+                return padded_insts, seq_len
+            return padded_insts
+
+        max_len = max(map(len, insts))
+        if pad_style == "left":
+            padded_insts = [[pad_id] * (max_len - len(inst)) + list(inst) for inst in insts]
+        else:
+            padded_insts = [list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts]
+        if return_array:
+            padded_insts = np.array(padded_insts, dtype=np.int64).reshape([-1, max_len])
+
+        if return_seq_len:
+            seq_len = [len(inst) for inst in insts]
+            if return_array:
+                seq_len = np.array(seq_len, dtype=np.int64).reshape(-1, 1)
+            return padded_insts, seq_len
+        return padded_insts
+
+    def update_stop_seq(self, request):
+        """
+        Update stop sequences from request.
+        """
+        stop_seqs = []
+        for seq in request.get("stop_sequences", []):
+            if seq != self.tokenizer.eos_token_id:
+                stop_seqs.append(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(seq)))
+        request["stop_seqs"], request["stop_seqs_len"] = self.pad_batch_data(
+            stop_seqs,
+            pad_id=-1,
+            return_seq_len=True,
+            return_array=False
+        )
+        data_processor_logger.debug(f"processed request: {request['stop_seqs'], request['stop_seqs_len']}")
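
To make the new preprocessing concrete: update_stop_seq tokenizes each entry of request["stop_sequences"] and hands the resulting id lists to pad_batch_data, so each request ends up carrying equal-length stop_seqs plus their true lengths in stop_seqs_len. Below is a minimal, runnable sketch using a simplified, list-only variant of the helper (the return_array branch is dropped and the token ids are made up); it is illustrative only, not the server's code path.

# Simplified, list-only sketch of the new pad_batch_data helper.
# Token ids below are hypothetical; the real code gets them from the tokenizer.

def pad_batch_data(insts, pad_id=0, return_seq_len=False, pad_style="right"):
    """Pad each instance to the max length in the batch."""
    if len(insts) == 0:
        return ([[]], []) if return_seq_len else [[]]
    max_len = max(map(len, insts))
    if pad_style == "left":
        padded = [[pad_id] * (max_len - len(inst)) + list(inst) for inst in insts]
    else:
        padded = [list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts]
    if return_seq_len:
        return padded, [len(inst) for inst in insts]
    return padded

# Two stop sequences of different lengths, already converted to token ids.
stop_seqs = [[101, 202, 303], [404]]
padded, lens = pad_batch_data(stop_seqs, pad_id=-1, return_seq_len=True)
print(padded)  # [[101, 202, 303], [404, -1, -1]]
print(lens)    # [3, 1]

In the server itself the same call is made with pad_id=-1 and return_array=False, so later stages can tell real stop-token ids apart from padding.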

llm/server/server/engine/infer.py

Lines changed: 35 additions & 0 deletions
@@ -52,6 +52,11 @@ def __init__(self, args):
         self.args.num_attention_heads = self.get_value(self.model_cfg, ["num_attention_heads", "n_head"])
         self.args.hidden_size = self.model_cfg["hidden_size"]
 
+        self.reduce_dialogue_repetition = int(os.environ.get("REDUCE_DIALOGUE_REPETITION", 0))
+
+        self.max_stop_seqs_num = int(os.getenv("MAX_STOP_SEQS_NUM", 5))
+        self.stop_seqs_max_len = int(os.getenv("STOP_SEQS_MAX_LEN", 8))
+
         self.nranks = dist.get_world_size()
         self.init_dist_env()
         self.rank = fleet.worker_index()

@@ -246,6 +251,19 @@ def init_inputs(self):
         self.share_inputs['free_list_len'] = paddle.full(
             shape=[1], fill_value=self.free_list_len, dtype="int32")
 
+        self.share_inputs['stop_seqs_len'] = paddle.full(shape=[self.max_stop_seqs_num,],
+                                                         fill_value=0,
+                                                         dtype="int32")
+        self.share_inputs['stop_seqs'] = paddle.full(shape=[self.max_stop_seqs_num, self.stop_seqs_max_len],
+                                                     fill_value=-1,
+                                                     dtype="int64")
+
+        if self.reduce_dialogue_repetition:
+            self.share_inputs["first_token_ids"] = paddle.full(
+                shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
+            self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
+                shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
+
     def dy_input_preprocess(self, tasks):
         """
         dynamic insertion

@@ -279,6 +297,10 @@ def dy_input_preprocess(self, tasks):
             self.share_inputs['max_length'][idx:idx + 1] = max_dec_len
             self.share_inputs['stop_flags'][idx:idx + 1] = False
 
+            if self.reduce_dialogue_repetition:
+                self.share_inputs['first_token_ids'][idx:idx + 1] = self.share_inputs['input_ids'][idx:idx + 1, :1]
+                self.share_inputs["ori_seq_lens_encoder"][idx:idx + 1] = length
+
             if "infer_seed" in task:
                 self.share_inputs['infer_seed'][idx:idx + 1] = task['infer_seed']
 

@@ -288,6 +310,14 @@ def dy_input_preprocess(self, tasks):
                 self.share_inputs["block_tables"][idx:idx + 1, :encoder_block_num] = np.array(
                     task['block_tables'], dtype="int32")
 
+            if "stop_seqs_len" in task:
+                stop_seqs_num = len(task["stop_seqs_len"])
+                for i in range(stop_seqs_num, self.max_stop_seqs_num):
+                    task["stop_seqs_len"].append(0)
+                self.share_inputs['stop_seqs_len'][:] = np.array(
+                    task["stop_seqs_len"], dtype="int32")
+                self.share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array(
+                    task["stop_seqs"], dtype="int64")
     def step_cuda(self, seq_lens_this_time):
         """
         step cuda

@@ -474,6 +504,11 @@ def _init_predictor(self):
         config.switch_ir_optim(False)
         config.enable_use_gpu(100, device_id)
 
+        pir_flag = int(os.environ.get("FLAGS_enable_pir_api", 0))
+        if pir_flag == 1:
+            config.enable_new_executor()
+            config.enable_new_ir()
+
         # distributed config
         if self.mp_degree > 1:
             trainer_endpoints = fleet.worker_endpoints()
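
For orientation, here is a rough NumPy-only sketch of what dy_input_preprocess now does with a task's stop sequences: the per-request stop_seqs_len list is zero-padded up to the buffer size and copied in, and the token ids land in the top-left corner of a [MAX_STOP_SEQS_NUM, STOP_SEQS_MAX_LEN] buffer pre-filled with -1, mirroring the paddle.full tensors created in init_inputs. The buffer sizes and the example task contents below are assumptions based on the defaults in the diff.

import numpy as np

MAX_STOP_SEQS_NUM = 5   # default of MAX_STOP_SEQS_NUM in the diff
STOP_SEQS_MAX_LEN = 8   # default of STOP_SEQS_MAX_LEN in the diff

# Fixed-size buffers, standing in for the paddle.full tensors in init_inputs.
stop_seqs_len_buf = np.zeros([MAX_STOP_SEQS_NUM], dtype="int32")
stop_seqs_buf = np.full([MAX_STOP_SEQS_NUM, STOP_SEQS_MAX_LEN], -1, dtype="int64")

# A hypothetical task as produced by the data processor (already padded with -1).
task = {
    "stop_seqs": [[101, 202, 303], [404, -1, -1]],
    "stop_seqs_len": [3, 1],
}

stop_seqs_num = len(task["stop_seqs_len"])
# Zero-pad the length vector up to the buffer size, then copy it in.
task["stop_seqs_len"] += [0] * (MAX_STOP_SEQS_NUM - stop_seqs_num)
stop_seqs_len_buf[:] = np.array(task["stop_seqs_len"], dtype="int32")
# Copy the token ids into the top-left corner of the id buffer.
stop_seqs_buf[:stop_seqs_num, :len(task["stop_seqs"][0])] = np.array(
    task["stop_seqs"], dtype="int64")

print(stop_seqs_len_buf)  # [3 1 0 0 0]
print(stop_seqs_buf[:2])  # first two rows carry the stop-sequence ids, the rest stay -1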

llm/server/server/http_server/api.py

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ class Req(BaseModel):
     req_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
     input_ids: Optional[List[int]] = None
     text: Optional[str] = None
+    stop_sequences: Optional[List] = None
     messages: Optional[List] = None
     max_dec_len: Optional[int] = None
     seq_len: Optional[int] = None
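
With the new field in place, a caller can pass stop sequences straight through the HTTP request body. The snippet below is a trimmed-down copy of the Req model limited to the fields shown above, purely to illustrate how stop_sequences is populated; the real model has more fields and the server applies its own validation.

import uuid
from typing import List, Optional

from pydantic import BaseModel, Field


class Req(BaseModel):
    """Trimmed-down request model, for illustration only."""
    req_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    input_ids: Optional[List[int]] = None
    text: Optional[str] = None
    stop_sequences: Optional[List] = None  # new field added by this commit
    messages: Optional[List] = None
    max_dec_len: Optional[int] = None
    seq_len: Optional[int] = None


req = Req(text="Write a haiku about rain.", stop_sequences=["\n\n", "###"])
print(req.dict(exclude_none=True))  # use req.model_dump(exclude_none=True) on pydantic v2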

llm/server/server/triton_server.py

Lines changed: 24 additions & 0 deletions
@@ -98,11 +98,35 @@ def _push_mode_sender_thread(self):
             except Exception as e:
                 model_server_logger.error("Unexcepted error happend: {}, {}".format(e, str(traceback.format_exc())))
 
+    def _cache_special_tokens(self, batch_result):
+        for i in range(len(batch_result)):
+            is_end = batch_result[i].get("is_end", 0)
+            token_ids = batch_result[i]["token_ids"]
+            if is_end != 1:
+                if batch_result[i]["req_id"] not in self.token_buffer:
+                    self.token_buffer[batch_result[i]["req_id"]] = list()
+                    self.score_buffer[batch_result[i]["req_id"]] = list()
+                self.token_buffer[batch_result[i]["req_id"]].extend(token_ids)
+                self.score_buffer[batch_result[i]["req_id"]].extend(batch_result[i].get("token_scores", []))
+                batch_result[i]["token_ids"] = []
+                if "token_scores" in batch_result[i]:
+                    batch_result[i]["token_scores"] = []
+            else:
+                if batch_result[i]["req_id"] in self.token_buffer:
+                    batch_result[i]["token_ids"] = self.token_buffer[batch_result[i]
+                                                                     ["req_id"]] + batch_result[i]["token_ids"]
+                    del self.token_buffer[batch_result[i]["req_id"]]
+                    if "token_scores" in batch_result[i]:
+                        batch_result[i]["token_scores"] = self.score_buffer[batch_result[i]
+                                                                            ["req_id"]] + batch_result[i]["token_scores"]
+                        del self.score_buffer[batch_result[i]["req_id"]]
+
     def postprocess(self, batch_result, exist_finished_task=False):
         """
         single postprocess for triton
         """
         try:
+            self._cache_special_tokens(batch_result)
             self.cached_generated_tokens.put(batch_result)
         except Exception as e:
             model_server_logger.info(
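
The idea behind the new _cache_special_tokens step, reduced to a self-contained sketch: while a request is still generating (is_end != 1) its token ids are moved into a per-request buffer and the outgoing chunk is emptied, and once the final chunk arrives the buffered ids are prepended, so the client receives the full sequence in one piece. Scores are handled the same way and are omitted here; the class below is a simplification, not the Triton server class itself.

# Self-contained sketch of the token-buffering behaviour (scores omitted).

class TokenCache:
    def __init__(self):
        self.token_buffer = {}

    def cache(self, batch_result):
        for res in batch_result:
            req_id = res["req_id"]
            if res.get("is_end", 0) != 1:
                # Intermediate chunk: stash the tokens and hold the output back.
                self.token_buffer.setdefault(req_id, []).extend(res["token_ids"])
                res["token_ids"] = []
            elif req_id in self.token_buffer:
                # Final chunk: prepend everything buffered so far, drop the buffer.
                res["token_ids"] = self.token_buffer.pop(req_id) + res["token_ids"]


cache = TokenCache()
step1 = [{"req_id": "r1", "is_end": 0, "token_ids": [11, 12]}]
step2 = [{"req_id": "r1", "is_end": 1, "token_ids": [13]}]
cache.cache(step1)
cache.cache(step2)
print(step1[0]["token_ids"])  # [] -- intermediate chunks are held back
print(step2[0]["token_ids"])  # [11, 12, 13] -- full sequence released at the end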
