
Commit 093614e

support stop_seqs
1 parent cbd7720 commit 093614e

2 files changed

Lines changed: 42 additions & 0 deletions


llm/server/server/data/processor.py

Lines changed: 19 additions & 0 deletions
@@ -143,6 +143,9 @@ def process_request(self, request, max_seq_len=None):
         request["eos_token_ids"] = []
         request["eos_token_ids"].extend(get_eos_token_id(self.tokenizer, self.config.generation_config))
 
+        if "stop_seqs" not in request or (isinstance(request["stop_seqs"], (list, tuple)) and len(request["stop_seqs"]) == 0):
+            self.update_stop_seq(request)
+
         if "input_ids" not in request or \
                 (isinstance(request["input_ids"], (list, tuple)) and len(request["input_ids"]) == 0):
             if "text" in request:
@@ -334,3 +337,19 @@ def get_pad_id(self):
         if isinstance(self.tokenizer, (LlamaTokenizer, Llama3Tokenizer)) and not self.tokenizer.pad_token_id:
             return self.tokenizer.eos_token
         return self.tokenizer.pad_token_id
+
+    def update_stop_seq(self, request):
+        """
+        Update stop sequences from request.
+        """
+        stop_seqs = [[2], [100273]]
+        for seq in request.get("stop_sequences", []):
+            if seq != self._get_eos_token_id():
+                stop_seqs.append(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(seq)))
+        request["stop_seqs"], request["stop_seqs_len"] = self.pad_batch_data(
+            stop_seqs,
+            pad_id=-1,
+            return_seq_len=True,
+            return_array=False
+        )
+        data_processor_logger.debug(f"processed request: {request['stop_seqs'], request['stop_seqs_len']}")
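
pad_batch_data itself is not part of this commit; from the call site it must right-pad the token-id lists to a common length with pad_id and also return the per-sequence lengths. A minimal sketch consistent with that call (the server's real helper may accept more options):

    import numpy as np

    def pad_batch_data(insts, pad_id=0, return_seq_len=False, return_array=True):
        # Right-pad each token-id list to the batch maximum (sketch that
        # mirrors only the keywords used by update_stop_seq above).
        max_len = max(len(inst) for inst in insts)
        padded = [list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts]
        seq_lens = [len(inst) for inst in insts]
        if return_array:
            padded = np.array(padded, dtype="int64")
            seq_lens = np.array(seq_lens, dtype="int32")
        return (padded, seq_lens) if return_seq_len else padded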

llm/server/server/engine/infer.py

Lines changed: 23 additions & 0 deletions
@@ -54,6 +54,9 @@ def __init__(self, args):
 
         self.reduce_dialogue_repetition = int(os.environ.get("REDUCE_DIALOGUE_REPETITION", 0))
 
+        self.max_stop_seqs_num = int(os.getenv("MAX_STOP_SEQS_NUM", 5))
+        self.stop_seqs_max_len = int(os.getenv("STOP_SEQS_MAX_LEN", 8))
+
         self.nranks = dist.get_world_size()
         self.init_dist_env()
         self.rank = fleet.worker_index()
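
Both caps are read from the environment with the defaults shown (at most 5 stop sequences of at most 8 tokens each), so a deployment can raise them without touching code. Illustrative override, set before the engine module is imported:

    import os

    # Hypothetical values; only the variable names come from the diff.
    os.environ["MAX_STOP_SEQS_NUM"] = "8"
    os.environ["STOP_SEQS_MAX_LEN"] = "16"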
@@ -248,6 +251,13 @@ def init_inputs(self):
         self.share_inputs['free_list_len'] = paddle.full(
             shape=[1], fill_value=self.free_list_len, dtype="int32")
 
+        self.share_inputs['stop_seqs_len'] = paddle.full(shape=[self.max_stop_seqs_num,],
+                                                         fill_value=0,
+                                                         dtype="int32")
+        self.share_inputs['stop_seqs'] = paddle.full(shape=[self.max_stop_seqs_num, self.stop_seqs_max_len],
+                                                     fill_value=-1,
+                                                     dtype="int64")
+
         if self.reduce_dialogue_repetition:
             self.share_inputs["first_token_ids"] = paddle.full(
                 shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
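
With the defaults, stop_seqs is a 5 x 8 int64 buffer filled with -1 and stop_seqs_len a length-5 int32 vector of zeros; dy_input_preprocess later writes each request's padded sequences into the top-left corner. A NumPy sketch of the resulting layout (token ids are made up):

    import numpy as np

    max_stop_seqs_num, stop_seqs_max_len = 5, 8   # defaults from __init__
    stop_seqs = np.full((max_stop_seqs_num, stop_seqs_max_len), -1, dtype="int64")
    stop_seqs_len = np.zeros((max_stop_seqs_num,), dtype="int32")

    for i, s in enumerate([[15, 27, 93], [100]]):  # two made-up stop sequences
        stop_seqs[i, :len(s)] = s
        stop_seqs_len[i] = len(s)
    print(stop_seqs_len)   # [3 1 0 0 0]; unused rows stay -1 / 0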
@@ -300,6 +310,14 @@ def dy_input_preprocess(self, tasks):
             self.share_inputs["block_tables"][idx:idx + 1, :encoder_block_num] = np.array(
                 task['block_tables'], dtype="int32")
 
+            if "stop_seqs_len" in task:
+                stop_seqs_num = len(task["stop_seqs_len"])
+                for i in range(stop_seqs_num, self.max_stop_seqs_num):
+                    task["stop_seqs_len"].append(0)
+                self.share_inputs['stop_seqs_len'][:] = np.array(
+                    task["stop_seqs_len"], dtype="int32")
+                self.share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array(
+                    task["stop_seqs"], dtype="int64")
     def step_cuda(self, seq_lens_this_time):
         """
         step cuda
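
The commit only stages these buffers; the decode-time check that actually halts generation lives in the inference kernels and is not shown here. A plain-Python sketch of the rule the buffers support, assuming a match means the generated ids end with one of the active stop sequences:

    def hits_stop_seq(generated_ids, stop_seqs, stop_seqs_len):
        # Sketch only: compare the tail of the output against each active
        # stop sequence; rows with length 0 are skipped, and -1 padding
        # never participates because only seq[:seq_len] is compared.
        for seq, seq_len in zip(stop_seqs, stop_seqs_len):
            if seq_len == 0 or seq_len > len(generated_ids):
                continue
            if list(generated_ids[-seq_len:]) == list(seq[:seq_len]):
                return True
        return False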
@@ -486,6 +504,11 @@ def _init_predictor(self):
             config.switch_ir_optim(False)
             config.enable_use_gpu(100, device_id)
 
+        pir_flag = int(os.environ.get("FLAGS_enable_pir_api", 0))
+        if pir_flag == 1:
+            config.enable_new_executor()
+            config.enable_new_ir()
+
         # distributed config
         if self.mp_degree > 1:
             trainer_endpoints = fleet.worker_endpoints()
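
FLAGS_enable_pir_api is Paddle's own switch for the PIR program representation, so enable_new_executor() and enable_new_ir() are applied only when the deployment has already opted into the new IR; otherwise the predictor keeps the legacy executor path.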
