Skip to content

Commit cbd7720

Browse files
committed
make return_all_tokens work
1 parent 577b7a7 commit cbd7720

1 file changed

Lines changed: 26 additions & 0 deletions

File tree

llm/server/server/triton_server.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,37 @@ def _push_mode_sender_thread(self):
9898
except Exception as e:
9999
model_server_logger.error("Unexcepted error happend: {}, {}".format(e, str(traceback.format_exc())))
100100

101+
def _cache_special_tokens(self, batch_result):
    """
    Hold back or release generated tokens per request, editing entries in place.

    Every entry of ``batch_result`` is a dict carrying ``req_id`` and
    ``token_ids``; it may also carry ``is_end``, ``return_all_tokens`` and
    ``token_scores``.

    A non-final entry is held back when its first token id lies in 13..268,
    when it sets ``return_all_tokens``, or when ``self.cfg.disable_streaming``
    is truthy: its tokens and scores are appended to ``self.token_buffer`` /
    ``self.score_buffer`` under its ``req_id`` and the entry is emptied.
    Any other entry, including the final one (``is_end == 1``), gets the
    tokens buffered for its request prepended and the buffers are released.

    Args:
        batch_result: list of per-request result dicts (modified in place).
    """
    for result in batch_result:
        token_ids = result["token_ids"]
        is_final = result.get("is_end", 0) == 1

        if not is_final and not token_ids:
            # An empty, non-final chunk has no first token to inspect
            # (indexing it would raise IndexError and abort the whole batch).
            # It carries nothing to buffer or emit, so leave the entry and
            # the buffers untouched.
            continue

        req_id = result["req_id"]

        # NOTE(review): 13..268 is presumably the tokenizer's range of
        # special/byte-level ids -- confirm against the tokenizer vocab.
        hold_back = not is_final and (
            13 <= token_ids[0] <= 268
            or result.get("return_all_tokens", False)
            or self.cfg.disable_streaming
        )

        if hold_back:
            self.token_buffer.setdefault(req_id, []).extend(token_ids)
            self.score_buffer.setdefault(req_id, []).extend(result.get("token_scores", []))
            result["token_ids"] = []
            if "token_scores" in result:
                result["token_scores"] = []
        elif req_id in self.token_buffer:
            # Release everything buffered for this request ahead of the
            # tokens that arrived with this entry.
            result["token_ids"] = self.token_buffer.pop(req_id) + result["token_ids"]
            buffered_scores = self.score_buffer.pop(req_id, [])
            if "token_scores" in result:
                result["token_scores"] = buffered_scores + result["token_scores"]
125+
101126
def postprocess(self, batch_result, exist_finished_task=False):
102127
"""
103128
single postprocess for triton
104129
"""
105130
try:
131+
self._cache_special_tokens(batch_result)
106132
self.cached_generated_tokens.put(batch_result)
107133
except Exception as e:
108134
model_server_logger.info(

0 commit comments

Comments
 (0)