@@ -265,15 +265,21 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
265265 else :
266266 position_ids = None
267267
268- enable_thinking = request .get ("enable_thinking" , True )
269- enable_thinking = enable_thinking if enable_thinking is not None else True
270- self .share_inputs ["enable_thinking" ][:] = enable_thinking
271- self .share_inputs ["need_think_end" ][idx : idx + 1 , :] = 1 if enable_thinking else 0
272- self .share_inputs ["reasoning_index" ][idx : idx + 1 , :] = request .get ("reasoning_max_tokens" , 2048 )
273268 self .share_inputs ["rope_emb" ][idx : idx + 1 , :] = self .prepare_rope3d (
274269 position_ids , request .get ("max_tokens" , 2048 )
275270 )
276271
272+ if request .get ("enable_thinking" , False ) and request .get ("reasoning_max_tokens" ) is not None :
273+ # Thinking on: needs a truthy "enable_thinking" AND a non-None "reasoning_max_tokens". NOTE(review): the [:] write sets the whole enable_thinking buffer (allocated with shape [1]), not just slot idx -- confirm last-writer-wins across requests is intended.
274+ self .share_inputs ["enable_thinking" ][:] = True
275+ self .share_inputs ["need_think_end" ][idx : idx + 1 , :] = 1
276+ self .share_inputs ["reasoning_index" ][idx : idx + 1 , :] = request .get ("reasoning_max_tokens" )
277+ else :
278+ # Thinking off (flag missing/falsy, or no "reasoning_max_tokens" supplied): clear the whole enable_thinking buffer and zero this slot's need_think_end and reasoning_index.
279+ self .share_inputs ["enable_thinking" ][:] = False
280+ self .share_inputs ["need_think_end" ][idx : idx + 1 , :] = 0
281+ self .share_inputs ["reasoning_index" ][idx : idx + 1 , :] = 0
282+
277283 if isinstance (request .prompt_token_ids , np .ndarray ):
278284 prompt_token_ids = request .prompt_token_ids .tolist ()
279285 else :
@@ -495,16 +501,22 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
495501 self .share_inputs ["prompt_lens" ][idx : idx + 1 ] = length
496502
497503 if self .enable_mm :
498- enable_thinking = request .get ("enable_thinking" , True )
499- enable_thinking = enable_thinking if enable_thinking is not None else True
500- self .share_inputs ["enable_thinking" ][:] = enable_thinking
501- self .share_inputs ["need_think_end" ][idx : idx + 1 , :] = 1 if enable_thinking else 0
502- self .share_inputs ["reasoning_index" ][idx : idx + 1 , :] = request .get ("reasoning_max_tokens" , 2048 )
503504 self .share_inputs ["rope_emb" ][idx : idx + 1 , :] = self .prepare_rope3d (
504505 position_ids , request .get ("max_tokens" , 2048 )
505506 )
506507 self .share_inputs ["seq_lens_decoder" ][idx : idx + 1 ] = 0
507508
509+ if request .get ("enable_thinking" , False ) and request .get ("reasoning_max_tokens" ) is not None :
510+ # Thinking on: needs a truthy "enable_thinking" AND a non-None "reasoning_max_tokens". NOTE(review): the [:] write sets the whole enable_thinking buffer (allocated with shape [1]), not just slot idx -- confirm last-writer-wins across requests is intended.
511+ self .share_inputs ["enable_thinking" ][:] = True
512+ self .share_inputs ["need_think_end" ][idx : idx + 1 , :] = 1
513+ self .share_inputs ["reasoning_index" ][idx : idx + 1 , :] = request .get ("reasoning_max_tokens" )
514+ else :
515+ # Thinking off (flag missing/falsy, or no "reasoning_max_tokens" supplied): clear the whole enable_thinking buffer and zero this slot's need_think_end and reasoning_index.
516+ self .share_inputs ["enable_thinking" ][:] = False
517+ self .share_inputs ["need_think_end" ][idx : idx + 1 , :] = 0
518+ self .share_inputs ["reasoning_index" ][idx : idx + 1 , :] = 0
519+
508520 def get_attr_from_request (request , attr , default_value = None ):
509521 res = request .get (attr , default_value )
510522 if res is not None :
@@ -735,6 +747,11 @@ def _init_share_inputs(self, max_num_seqs: int):
735747 # Initialize rotary position embedding
736748 tmp_position_ids = paddle .arange (self .parallel_config .max_model_len ).reshape ((1 , - 1 ))
737749
750+ # Initialize thinking-related buffers: need_think_end and reasoning_index are int32 tensors of shape [max_num_seqs, 1] filled with 0; enable_thinking is a single bool of shape [1] that starts as False.
751+ self .share_inputs ["need_think_end" ] = paddle .full (shape = [max_num_seqs , 1 ], fill_value = 0 , dtype = "int32" )
752+ self .share_inputs ["enable_thinking" ] = paddle .full (shape = [1 ], fill_value = False , dtype = "bool" )
753+ self .share_inputs ["reasoning_index" ] = paddle .full (shape = [max_num_seqs , 1 ], fill_value = 0 , dtype = "int32" )
754+
738755 # TODO(gongshaotian): move to models
739756 if not self .enable_mm :
740757 self .share_inputs ["rope_emb" ] = get_rope (
@@ -827,11 +844,6 @@ def _init_share_inputs(self, max_num_seqs: int):
827844 dtype = "float32" ,
828845 )
829846 self .share_inputs ["image_features" ] = None
830- self .share_inputs ["need_think_end" ] = paddle .full (shape = [max_num_seqs , 1 ], fill_value = 0 , dtype = "int32" )
831- self .share_inputs ["enable_thinking" ] = paddle .full (
832- shape = [1 ], fill_value = ("ernie" in self .model_config .model_type ), dtype = "bool"
833- )
834- self .share_inputs ["reasoning_index" ] = paddle .full (shape = [max_num_seqs , 1 ], fill_value = 0 , dtype = "int32" )
835847
836848 def _prepare_inputs (self ) -> None :
837849 """Prepare the model inputs"""
@@ -1220,10 +1232,10 @@ def _dummy_run(
12201232 ),
12211233 accept_tokens = (self .share_inputs ["accept_tokens" ] if self .speculative_decoding else None ),
12221234 accept_num = (self .share_inputs ["accept_num" ] if self .speculative_decoding else None ),
1223- enable_thinking = ( self .share_inputs ["enable_thinking" ] if self . enable_mm else None ) ,
1224- think_end_id = ( getattr ( self .model_config , "think_end_id" , - 1 ) if self . enable_mm else - 1 ) ,
1225- need_think_end = ( self .share_inputs ["need_think_end" ] if self . enable_mm else None ) ,
1226- reasoning_index = ( self .share_inputs ["reasoning_index" ] if self . enable_mm else None ) ,
1235+ enable_thinking = self .share_inputs ["enable_thinking" ],
1236+ think_end_id = self .model_config . think_end_id ,
1237+ need_think_end = self .share_inputs ["need_think_end" ],
1238+ reasoning_index = self .share_inputs ["reasoning_index" ],
12271239 stop_token_ids = self .share_inputs ["stop_seqs" ],
12281240 stop_seqs_len = self .share_inputs ["stop_seqs_len" ],
12291241 )
@@ -1515,10 +1527,10 @@ class at the server level, which is too granular for ModelRunner.
15151527 ),
15161528 accept_tokens = (self .share_inputs ["accept_tokens" ] if self .speculative_decoding else None ),
15171529 accept_num = (self .share_inputs ["accept_num" ] if self .speculative_decoding else None ),
1518- enable_thinking = ( self .share_inputs ["enable_thinking" ] if self . enable_mm else None ) ,
1519- think_end_id = ( getattr ( self .model_config , "think_end_id" , - 1 ) if self . enable_mm else - 1 ) ,
1520- need_think_end = ( self .share_inputs ["need_think_end" ][:num_running_requests ] if self . enable_mm else None ) ,
1521- reasoning_index = ( self .share_inputs ["reasoning_index" ][:num_running_requests ] if self . enable_mm else None ) ,
1530+ enable_thinking = self .share_inputs ["enable_thinking" ],
1531+ think_end_id = self .model_config . think_end_id ,
1532+ need_think_end = self .share_inputs ["need_think_end" ][:num_running_requests ],
1533+ reasoning_index = self .share_inputs ["reasoning_index" ][:num_running_requests ],
15221534 stop_token_ids = self .share_inputs ["stop_seqs" ],
15231535 stop_seqs_len = self .share_inputs ["stop_seqs_len" ],
15241536 )
0 commit comments