@@ -98,11 +98,37 @@ def _push_mode_sender_thread(self):
9898 except Exception as e :
9999 model_server_logger .error ("Unexcepted error happend: {}, {}" .format (e , str (traceback .format_exc ())))
100100
101+ def _cache_special_tokens (self , batch_result ):
102+ for i in range (len (batch_result )):
103+ is_end = batch_result [i ].get ("is_end" , 0 )
104+ token_ids = batch_result [i ]["token_ids" ]
105+ return_all_tokens = batch_result [i ].get ("return_all_tokens" , False )
106+ cache_special_token = False if is_end == 1 else (13 <= token_ids [0 ] <= 268 )
107+ if is_end != 1 and (cache_special_token or return_all_tokens or self .cfg .disable_streaming ):
108+ if batch_result [i ]["req_id" ] not in self .token_buffer :
109+ self .token_buffer [batch_result [i ]["req_id" ]] = list ()
110+ self .score_buffer [batch_result [i ]["req_id" ]] = list ()
111+ self .token_buffer [batch_result [i ]["req_id" ]].extend (token_ids )
112+ self .score_buffer [batch_result [i ]["req_id" ]].extend (batch_result [i ].get ("token_scores" , []))
113+ batch_result [i ]["token_ids" ] = []
114+ if "token_scores" in batch_result [i ]:
115+ batch_result [i ]["token_scores" ] = []
116+ else :
117+ if batch_result [i ]["req_id" ] in self .token_buffer :
118+ batch_result [i ]["token_ids" ] = self .token_buffer [batch_result [i ]
119+ ["req_id" ]] + batch_result [i ]["token_ids" ]
120+ del self .token_buffer [batch_result [i ]["req_id" ]]
121+ if "token_scores" in batch_result [i ]:
122+ batch_result [i ]["token_scores" ] = self .score_buffer [batch_result [i ]
123+ ["req_id" ]] + batch_result [i ]["token_scores" ]
124+ del self .score_buffer [batch_result [i ]["req_id" ]]
125+
101126 def postprocess (self , batch_result , exist_finished_task = False ):
102127 """
103128 single postprocess for triton
104129 """
105130 try :
131+ self ._cache_special_tokens (batch_result )
106132 self .cached_generated_tokens .put (batch_result )
107133 except Exception as e :
108134 model_server_logger .info (
0 commit comments