@@ -432,6 +432,11 @@ async def chat_completion_stream_generator(
432432 delta = delta_message ,
433433 logprobs = logprobs_res ,
434434 draft_logprobs = draft_logprobs_res ,
435+ sampling_mask = (
436+ self ._make_sampling_mask_list (output ["sampling_mask" ])
437+ if output .get ("sampling_mask" ) is not None
438+ else None
439+ ),
435440 arrival_time = arrival_time ,
436441 speculate_metrics = output_speculate_metrics ,
437442 )
@@ -577,6 +582,7 @@ async def chat_completion_full_generator(
577582 decoder_base_url = self .tokenizer_base_url ,
578583 )
579584 prompt_logprobs_res_list = [[] for _ in range (num_choices )]
585+ sampling_mask_list = [[] for _ in range (num_choices )]
580586 speculate_metrics = [None for _ in range (num_choices )]
581587 choices = []
582588 while num_choices > 0 :
@@ -656,6 +662,9 @@ async def chat_completion_full_generator(
656662 )
657663 if prompt_logprobs_res :
658664 prompt_logprobs_res_list [idx ].extend (clamp_prompt_logprobs (prompt_logprobs_res ))
665+ output_sampling_mask = output .get ("sampling_mask" , None )
666+ if output_sampling_mask is not None :
667+ sampling_mask_list [idx ].append (self ._make_sampling_mask_list (output_sampling_mask ))
659668 speculate_metrics [idx ] = data ["metrics" ].get ("speculate_metrics" , None )
660669 if data ["finished" ]:
661670 trace_carrier = data .get ("trace_carrier" )
@@ -691,6 +700,7 @@ async def chat_completion_full_generator(
691700 draft_logprob_contents = draft_logprob_contents ,
692701 response_processor = response_processor ,
693702 prompt_logprobs_res_list = prompt_logprobs_res_list ,
703+ sampling_mask_list = sampling_mask_list ,
694704 max_tokens = max_tokens ,
695705 speculate_metrics = speculate_metrics [idx ],
696706 )
@@ -745,6 +755,7 @@ async def _create_chat_completion_choice(
745755 logprob_contents : list ,
746756 draft_logprob_contents : list ,
747757 prompt_logprobs_res_list : list ,
758+ sampling_mask_list : list ,
748759 response_processor : ChatResponseProcessor ,
749760 max_tokens : int ,
750761 speculate_metrics : SpeculateMetrics | None ,
@@ -782,6 +793,10 @@ async def _create_chat_completion_choice(
782793 draft_logprobs_full_res = LogProbs (content = draft_logprob_contents [idx ])
783794 if prompt_logprobs_res_list [idx ]:
784795 prompt_logprobs_full_res = prompt_logprobs_res_list [idx ]
796+ # Flatten per-step List[List[int]] into a single List[List[int]] over all tokens.
797+ sampling_mask_full_res = None
798+ if sampling_mask_list and sampling_mask_list [idx ]:
799+ sampling_mask_full_res = [mask for step in sampling_mask_list [idx ] for mask in step ]
785800
786801 num_cached_tokens [idx ] = data .get ("num_cached_tokens" , 0 )
787802 num_input_image_tokens [idx ] = data .get ("num_input_image_tokens" , 0 )
@@ -806,6 +821,7 @@ async def _create_chat_completion_choice(
806821 logprobs = logprobs_full_res ,
807822 draft_logprobs = draft_logprobs_full_res ,
808823 prompt_logprobs = prompt_logprobs_full_res ,
824+ sampling_mask = sampling_mask_full_res ,
809825 finish_reason = finish_reason ,
810826 speculate_metrics = speculate_metrics ,
811827 )
@@ -989,3 +1005,18 @@ def _make_logprob_dict(
9891005 )
9901006 for token_id , logprob , rank , token in zip (logprob_token_ids , logprobs , ranks , decoded_tokens )
9911007 }
1008+
@staticmethod
def _make_sampling_mask_list(sampling_mask) -> List[List[int]]:
    """Normalize ``sampling_mask`` to a uniform ``List[List[int]]`` layout.

    The mask is expected to arrive already in sparse-index form, so no
    bool-to-index conversion is performed here:

    * Non-MTP: ``List[int]`` (indices for 1 token/step) -> ``[[idx, ...]]``
    * MTP: ``List[List[int]]`` (indices for N tokens/step) -> returned as-is

    Args:
        sampling_mask: Either a flat list of indices or a list of such
            lists. Must not be ``None``.

    Returns:
        A list of per-token index lists. When the input is already nested,
        the very same object is returned (no copy is made). An empty input
        is treated as the flat form and yields ``[[]]``.

    Raises:
        ValueError: If ``sampling_mask`` is ``None``.
    """
    # Raise explicitly instead of using ``assert``: assertions are removed
    # under ``python -O``, which would let ``None`` slip through and be
    # wrapped as ``[None]``.
    if sampling_mask is None:
        raise ValueError("sampling_mask must not be None")

    # A truthy list is guaranteed to have a first element, so no separate
    # length check is needed before inspecting it.
    if sampling_mask and isinstance(sampling_mask[0], list):
        # MTP: already List[List[int]].
        return sampling_mask

    # Non-MTP (or empty): wrap the flat list so callers see one shape.
    return [sampling_mask]
0 commit comments