Skip to content

Commit 577b7a7

Browse files
committed
support reduce_dialogue_repetition
1 parent 608d4be commit 577b7a7

2 files changed

Lines changed: 13 additions & 1 deletion

File tree

llm/server/server/data/processor.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -282,7 +282,7 @@ def _load_tokenizer(self):
282282
"""
283283
if self.config.use_hf_tokenizer:
284284
from transformers import AutoTokenizer
285-
return AutoTokenizer.from_pretrained(self.config.model_dir, use_fast=False, vocab_file=os.path.join(self.config.model_dir, "sentencepiece.bpe.model"))
285+
return AutoTokenizer.from_pretrained(self.config.model_dir, use_fast=False)
286286
else:
287287
from paddlenlp.transformers import AutoTokenizer
288288
return AutoTokenizer.from_pretrained(self.config.model_dir)

llm/server/server/engine/infer.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,8 @@ def __init__(self, args):
5252
self.args.num_attention_heads = self.get_value(self.model_cfg, ["num_attention_heads", "n_head"])
5353
self.args.hidden_size = self.model_cfg["hidden_size"]
5454

55+
self.reduce_dialogue_repetition = int(os.environ.get("REDUCE_DIALOGUE_REPETITION", 0))
56+
5557
self.nranks = dist.get_world_size()
5658
self.init_dist_env()
5759
self.rank = fleet.worker_index()
@@ -246,6 +248,12 @@ def init_inputs(self):
246248
self.share_inputs['free_list_len'] = paddle.full(
247249
shape=[1], fill_value=self.free_list_len, dtype="int32")
248250

251+
if self.reduce_dialogue_repetition:
252+
self.share_inputs["first_token_ids"] = paddle.full(
253+
shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
254+
self.share_inputs["ori_seq_lens_encoder"] = paddle.full(
255+
shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
256+
249257
def dy_input_preprocess(self, tasks):
250258
"""
251259
dynamic insertion
@@ -279,6 +287,10 @@ def dy_input_preprocess(self, tasks):
279287
self.share_inputs['max_length'][idx:idx + 1] = max_dec_len
280288
self.share_inputs['stop_flags'][idx:idx + 1] = False
281289

290+
if self.reduce_dialogue_repetition:
291+
self.share_inputs['first_token_ids'][idx:idx + 1] = self.share_inputs['input_ids'][idx:idx + 1, :1]
292+
self.share_inputs["ori_seq_lens_encoder"][idx:idx + 1] = length
293+
282294
if "infer_seed" in task:
283295
self.share_inputs['infer_seed'][idx:idx + 1] = task['infer_seed']
284296

0 commit comments

Comments (0)