
Commit 72bf3db

[KSM] support keep sampling mask (#7146)
* [KSM] support keep sampling mask
* Remove Comments
* remove logz_per_batch
* fix the description and checking
1 parent 44ef7b6 commit 72bf3db

23 files changed

Lines changed: 519 additions & 8 deletions

fastdeploy/config.py

Lines changed: 1 addition & 0 deletions
@@ -208,6 +208,7 @@ def __init__(
         self.enable_logprob = False
         self.max_logprobs = 20
         self.logprobs_mode = "raw_logprobs"
+        self.enable_keep_sampling_mask = False
         self.redundant_experts_num = 0
         self.seed = 0
         self.quantization = None

fastdeploy/engine/args_utils.py

Lines changed: 20 additions & 0 deletions
@@ -459,6 +459,14 @@ class EngineArgs:
     Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
     """

+    enable_keep_sampling_mask: bool = False
+    """
+    When enabled, the server returns a sparse index list for each generated token, indicating
+    which vocabulary positions were retained after top_p/top_k sampling, and streams it to
+    the client. In MTP (multi-token prediction) scenarios this field is a List[List[int]],
+    where each inner list contains the retained vocabulary indices for a predicted token.
+    """
+
     max_logprobs: int = 20
     """
     Maximum number of log probabilities to return when `enable_logprob` is True. The default value comes from the default for the
@@ -872,6 +880,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default=EngineArgs.enable_logprob,
             help="Enable output of token-level log probabilities.",
         )
+        model_group.add_argument(
+            "--enable-keep-sampling-mask",
+            action="store_true",
+            default=EngineArgs.enable_keep_sampling_mask,
+            help=(
+                "Enable output of sampling mask as a sparse index list over the vocabulary. "
+                "For non-MTP decoding, this is a list[int] per token step indicating which "
+                "vocabulary indices were kept after top_p/top_k sampling. "
+                "For MTP decoding, this is a list[list[int]] per token step, where each inner "
+                "list corresponds to one MTP group."
+            ),
+        )
         model_group.add_argument(
             "--max-logprobs",
             type=int,

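Taken together, the config and CLI changes above make the feature opt-in at server start. Below is a minimal client-side sketch, assuming an OpenAI-compatible server launched with `--enable-keep-sampling-mask`; the base URL, port, and model name are placeholders, not values from this commit.

# Hypothetical usage sketch: read sampling_mask from a non-streaming chat completion.
# Assumes the server was started with --enable-keep-sampling-mask; URL and model are placeholders.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": False,
    },
).json()

choice = resp["choices"][0]
# One inner list of retained vocab indices per generated token (see protocol.py below).
print(choice.get("sampling_mask"))  # e.g. [[15, 220, 9906], [42, 77], ...]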
fastdeploy/engine/common_engine.py

Lines changed: 1 addition & 0 deletions
@@ -2371,6 +2371,7 @@ def _start_worker_service(self):
             "lm_head_fp32": self.cfg.model_config.lm_head_fp32,
             "enable_entropy": self.cfg.model_config.enable_entropy,
             "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
+            "enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask,
         }
         for worker_flag, value in worker_store_true_flag.items():
             if value:

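For context, `worker_store_true_flag` collects boolean engine options and forwards the enabled ones to the worker process as store_true CLI flags. A rough sketch of that pattern is below; the underscore-to-dash conversion is an assumption for illustration and the exact formatting FastDeploy uses may differ.

# Sketch of the flag-forwarding pattern suggested by the loop above (assumed formatting).
worker_store_true_flag = {
    "enable_logprob": True,
    "enable_keep_sampling_mask": True,
    "enable_entropy": False,
}
worker_args: list[str] = []
for worker_flag, value in worker_store_true_flag.items():
    if value:
        worker_args.append(f"--{worker_flag.replace('_', '-')}")
print(worker_args)  # ['--enable-logprob', '--enable-keep-sampling-mask']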
fastdeploy/engine/engine.py

Lines changed: 1 addition & 0 deletions
@@ -621,6 +621,7 @@ def _start_worker_service(self):
             "shutdown_comm_group_if_worker_idle": self.cfg.parallel_config.shutdown_comm_group_if_worker_idle,
             "enable_entropy": self.cfg.model_config.enable_entropy,
             "enable_overlap_schedule": self.cfg.scheduler_config.enable_overlap_schedule,
+            "enable_keep_sampling_mask": self.cfg.model_config.enable_keep_sampling_mask,
         }
         for worker_flag, value in worker_store_true_flag.items():
             if value:

fastdeploy/engine/request.py

Lines changed: 5 additions & 0 deletions
@@ -727,6 +727,10 @@ class CompletionOutput:
     multipart: Optional[list[Any]] = None
     num_image_tokens: Optional[int] = None
     enable_parser: bool = False
+    # Sparse indices of retained vocab ids:
+    # - Non-MTP: list[int]
+    # - MTP: list[list[int]]
+    sampling_mask: Optional[Any] = None

     def to_dict(self):
         """
@@ -745,6 +749,7 @@ def to_dict(self):
             "text": self.text,
             "reasoning_content": self.reasoning_content,
             "reasoning_token_num": self.reasoning_token_num,
+            "sampling_mask": self.sampling_mask,
         }

     @classmethod

fastdeploy/entrypoints/openai/protocol.py

Lines changed: 5 additions & 0 deletions
@@ -268,6 +268,8 @@ class ChatCompletionResponseChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     draft_logprobs: Optional[LogProbs] = None
     prompt_logprobs: Optional[PromptLogprobs] = None
+    # Per-token retained vocab indices from top_p/top_k sampling: List[List[int]], one list of vocab indices per token
+    sampling_mask: Optional[List[List[int]]] = None
     finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort"]]
     speculate_metrics: Optional[SpeculateMetrics] = None

@@ -333,6 +335,9 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     logprobs: Optional[LogProbs] = None
     draft_logprobs: Optional[LogProbs] = None
     prompt_logprobs: Optional[PromptLogprobs] = None
+    # Per-token index list of retained positions after top_p sampling.
+    # Non-MTP: [[idx, ...]] (1 token/step). MTP: [[idx, ...], ...] (N accepted tokens/step).
+    sampling_mask: Optional[List[List[int]]] = None
     finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop", "abort"]] = None
     arrival_time: Optional[float] = None
     speculate_metrics: Optional[SpeculateMetrics] = None

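To make the field shapes concrete, here is an illustrative sketch (not captured from a real server; token ids and text are invented) of how `sampling_mask` appears in a response choice, following the comments above: one inner list per generated token in normal decoding, several inner lists per step when MTP accepts multiple draft tokens.

# Illustrative payload shapes only; values are invented.
# Non-MTP: each step emits one token, so one inner list per step.
stream_choice_non_mtp = {
    "delta": {"content": " Paris"},
    "sampling_mask": [[1234, 5678, 9012]],  # vocab ids kept by top_p/top_k for this token
    "finish_reason": None,
}
# MTP: a step can accept several draft tokens, one inner list per accepted token.
stream_choice_mtp = {
    "delta": {"content": " is the"},
    "sampling_mask": [[1234, 5678], [42, 77, 90]],
    "finish_reason": None,
}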
fastdeploy/entrypoints/openai/serving_chat.py

Lines changed: 31 additions & 0 deletions
@@ -432,6 +432,11 @@ async def chat_completion_stream_generator(
                         delta=delta_message,
                         logprobs=logprobs_res,
                         draft_logprobs=draft_logprobs_res,
+                        sampling_mask=(
+                            self._make_sampling_mask_list(output["sampling_mask"])
+                            if output.get("sampling_mask") is not None
+                            else None
+                        ),
                         arrival_time=arrival_time,
                         speculate_metrics=output_speculate_metrics,
                     )
@@ -577,6 +582,7 @@ async def chat_completion_full_generator(
             decoder_base_url=self.tokenizer_base_url,
         )
         prompt_logprobs_res_list = [[] for _ in range(num_choices)]
+        sampling_mask_list = [[] for _ in range(num_choices)]
         speculate_metrics = [None for _ in range(num_choices)]
         choices = []
         while num_choices > 0:
@@ -656,6 +662,9 @@
                     )
                     if prompt_logprobs_res:
                         prompt_logprobs_res_list[idx].extend(clamp_prompt_logprobs(prompt_logprobs_res))
+                    output_sampling_mask = output.get("sampling_mask", None)
+                    if output_sampling_mask is not None:
+                        sampling_mask_list[idx].append(self._make_sampling_mask_list(output_sampling_mask))
                     speculate_metrics[idx] = data["metrics"].get("speculate_metrics", None)
                     if data["finished"]:
                         trace_carrier = data.get("trace_carrier")
@@ -691,6 +700,7 @@
                         draft_logprob_contents=draft_logprob_contents,
                         response_processor=response_processor,
                         prompt_logprobs_res_list=prompt_logprobs_res_list,
+                        sampling_mask_list=sampling_mask_list,
                         max_tokens=max_tokens,
                         speculate_metrics=speculate_metrics[idx],
                     )
@@ -745,6 +755,7 @@ async def _create_chat_completion_choice(
         logprob_contents: list,
         draft_logprob_contents: list,
         prompt_logprobs_res_list: list,
+        sampling_mask_list: list,
         response_processor: ChatResponseProcessor,
         max_tokens: int,
         speculate_metrics: SpeculateMetrics | None,
@@ -782,6 +793,10 @@
             draft_logprobs_full_res = LogProbs(content=draft_logprob_contents[idx])
         if prompt_logprobs_res_list[idx]:
             prompt_logprobs_full_res = prompt_logprobs_res_list[idx]
+        # Flatten per-step List[List[int]] into a single List[List[int]] over all tokens.
+        sampling_mask_full_res = None
+        if sampling_mask_list and sampling_mask_list[idx]:
+            sampling_mask_full_res = [mask for step in sampling_mask_list[idx] for mask in step]

         num_cached_tokens[idx] = data.get("num_cached_tokens", 0)
         num_input_image_tokens[idx] = data.get("num_input_image_tokens", 0)
@@ -806,6 +821,7 @@
             logprobs=logprobs_full_res,
             draft_logprobs=draft_logprobs_full_res,
             prompt_logprobs=prompt_logprobs_full_res,
+            sampling_mask=sampling_mask_full_res,
             finish_reason=finish_reason,
             speculate_metrics=speculate_metrics,
         )
@@ -989,3 +1005,18 @@ def _make_logprob_dict(
             )
             for token_id, logprob, rank, token in zip(logprob_token_ids, logprobs, ranks, decoded_tokens)
         }
+
+    @staticmethod
+    def _make_sampling_mask_list(sampling_mask) -> List[List[int]]:
+        """Wrap sampling_mask into a uniform List[List[int]] format.
+
+        sampling_mask is already in sparse-index form (no bool-to-index conversion needed):
+            Non-MTP: List[int] (indices for 1 token/step) → [[idx, ...]]
+            MTP: List[List[int]] (indices for N tokens/step) → [[idx, ...], ...]
+        """
+        assert sampling_mask is not None
+        if sampling_mask and len(sampling_mask) > 0 and isinstance(sampling_mask[0], list):
+            # MTP: already List[List[int]], return as-is
+            return sampling_mask
+        # Non-MTP: already List[int], wrap in outer list for uniform format
+        return [sampling_mask]

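The helper above only normalizes shape; no boolean mask is converted. A self-contained sketch of the same behavior, with invented index values:

# Standalone sketch mirroring _make_sampling_mask_list: wrap a non-MTP
# list[int] so downstream code always sees list[list[int]].
def make_sampling_mask_list(sampling_mask):
    if sampling_mask and isinstance(sampling_mask[0], list):
        return sampling_mask        # MTP: one inner list per accepted token
    return [sampling_mask]          # non-MTP: a single token per step

assert make_sampling_mask_list([5, 17, 42]) == [[5, 17, 42]]
assert make_sampling_mask_list([[5, 17], [3, 9]]) == [[5, 17], [3, 9]]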
fastdeploy/model_executor/layers/sample/meta_data.py

Lines changed: 2 additions & 0 deletions
@@ -64,3 +64,5 @@ class SamplingMetadata:
     # Add for HPU post-processing
     seq_lens_encoder: Optional[paddle.Tensor] = None
     seq_lens_decoder: Optional[paddle.Tensor] = None
+    # Add for keep sampling mask
+    keep_sampling_mask: Optional[bool] = None

fastdeploy/model_executor/layers/sample/ops/top_k_top_p_sampling.py

Lines changed: 5 additions & 0 deletions
@@ -85,6 +85,11 @@ def top_k_top_p_sampling(

         _, ids = native_top_p_sampling(x, top_p)
     else:
+        if top_k_list and any(x > 0 for x in top_k_list):
+            from fastdeploy.model_executor.ops.gpu import top_k_renorm_probs
+
+            x = top_k_renorm_probs(x, top_k)
+
         _, ids = paddle.tensor.top_p_sampling(
             x,
             top_p,

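The fallback path above first renormalizes probabilities over the top-k candidates (via the `top_k_renorm_probs` GPU op) and then runs top-p sampling. Below is a rough NumPy sketch of what that filtering keeps, and hence what a sampling mask would report; it illustrates the idea with a scalar top_k and is not a rendering of the GPU kernels.

# Illustrative only: which vocab indices survive top-k renormalisation plus top-p truncation.
import numpy as np

def kept_indices(probs: np.ndarray, top_k: int, top_p: float) -> list[int]:
    order = np.argsort(-probs)                 # vocab ids sorted by descending probability
    if top_k > 0:
        order = order[:top_k]                  # top-k cut
    p = probs[order]
    p = p / p.sum()                            # renormalise over the surviving candidates
    cum = np.cumsum(p)
    # smallest prefix whose cumulative mass covers top_p
    keep = order[: int(np.searchsorted(cum, top_p) + 1)]
    return keep.tolist()

print(kept_indices(np.array([0.5, 0.2, 0.15, 0.1, 0.05]), top_k=4, top_p=0.8))  # [0, 1, 2]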