Skip to content

Commit 07a4723

Browse files
authored
Refactor step inputs (#4504)
* refactor step-inputs * update recordfunc * inline decode
1 parent 05bad87 commit 07a4723

20 files changed

Lines changed: 870 additions & 747 deletions

lmdeploy/pytorch/engine/model_agent/agent.py

Lines changed: 14 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from lmdeploy.pytorch.distributed import DistContext, get_dist_manager
2121
from lmdeploy.pytorch.engine.cache_engine import CacheEngine, StateCacheEngine
2222
from lmdeploy.pytorch.engine.guided_process import GuidedDecodingManager
23-
from lmdeploy.pytorch.engine.logits_process import FusedLogitsProcessor, SamplingInputs, SamplingInputsDelta
23+
from lmdeploy.pytorch.engine.logits_process import FusedLogitsProcessor, SamplingInputs
2424
from lmdeploy.pytorch.model_inputs import ModelInputs, ModelInputsDelta, step_ctx_manager
2525
from lmdeploy.pytorch.models.patch import BuildModelContext, add_adapters, build_patched_model, update_custom_module_map
2626
from lmdeploy.pytorch.spec_decode import build_spec_agent
@@ -229,92 +229,6 @@ async def async_wait(self, timeout: float = 0.001):
229229
SwapMap = dict[int, int]
230230

231231

232-
@dataclass
233-
class StepInputs:
234-
"""Step inputs."""
235-
model_inputs: ModelInputs = None
236-
extra_inputs: ExtraInputs = None
237-
stopping_criteria: StoppingCriteria = None
238-
sampling_delta: SamplingInputsDelta = None
239-
240-
@record_function('StepInputs.merge')
241-
def merge(
242-
self,
243-
inputs: ModelInputs,
244-
extra_inputs: ExtraInputs,
245-
stopping_criteria: StoppingCriteria,
246-
sampling_delta: SamplingInputsDelta,
247-
next_token_ids: torch.Tensor,
248-
model_metas,
249-
extra_outputs: ExtraOutputs,
250-
model_agent: 'BaseModelAgent',
251-
):
252-
"""Merge prefill inputs."""
253-
inputs, extra_inputs = model_agent.agent_strategy.update_prefill_for_next_step(
254-
inputs,
255-
extra_inputs,
256-
next_token_ids,
257-
model_metas,
258-
extra_outputs,
259-
)
260-
stopping_criteria = stopping_criteria.clone()
261-
sampling_delta = model_agent.sampling_strategy.step_sampling_delta(sampling_delta,
262-
next_token_ids,
263-
extra_inputs=extra_inputs)
264-
if self.model_inputs is None:
265-
self.model_inputs = inputs
266-
self.extra_inputs = extra_inputs
267-
self.stopping_criteria = stopping_criteria
268-
self.sampling_delta = sampling_delta
269-
else:
270-
self.model_inputs = model_agent.inputs_strategy.merge(self.model_inputs, inputs)
271-
self.extra_inputs = self.extra_inputs.merge(extra_inputs)
272-
self.stopping_criteria = self.stopping_criteria.merge(stopping_criteria)
273-
self.sampling_delta = model_agent.sampling_strategy.merge_sampling_delta(
274-
self.sampling_delta, sampling_delta)
275-
276-
def update_delta(
277-
self,
278-
delta: ModelInputsDelta,
279-
model_agent: 'BaseModelAgent',
280-
):
281-
"""Get inputs from delta."""
282-
self.model_inputs = model_agent.inputs_strategy.update_inputs(self.model_inputs, delta)
283-
self.extra_inputs = model_agent.agent_strategy.update_extra_inputs(self.extra_inputs, delta)
284-
self.stopping_criteria = self.stopping_criteria.update(delta)
285-
self.sampling_delta = model_agent.sampling_strategy.update_sampling_delta(self.sampling_delta, delta)
286-
287-
@record_function('StepInputs.step')
288-
def step(
289-
self,
290-
model_inputs: ModelInputs,
291-
extra_inputs: ExtraInputs,
292-
stopping_criteria: StoppingCriteria,
293-
sampling_delta: SamplingInputsDelta,
294-
next_token_ids: torch.Tensor,
295-
model_metas,
296-
extra_outputs: ExtraOutputs,
297-
model_agent: 'BaseModelAgent',
298-
):
299-
"""Update inputs."""
300-
# dp might change is_decoding of decoding inputs
301-
model_inputs.is_decoding = True
302-
(
303-
self.model_inputs,
304-
self.extra_inputs,
305-
) = model_agent.agent_strategy.update_decoding_for_next_step(
306-
model_inputs,
307-
next_token_ids=next_token_ids,
308-
model_metas=model_metas,
309-
extra_inputs=extra_inputs,
310-
extra_outputs=extra_outputs,
311-
)
312-
self.stopping_criteria = stopping_criteria.clone()
313-
self.sampling_delta = model_agent.sampling_strategy.step_sampling_delta(sampling_delta,
314-
next_token_ids,
315-
extra_inputs=extra_inputs)
316-
317-
318232
class BaseModelAgent:
319233
"""Base model agent.
320234
@@ -421,7 +335,7 @@ def __init__(
421335
self.state: SleepWakeupState = SleepWakeupState()
422336

423337
# decoding inputs
424-
self.step_inputs = StepInputs()
338+
self.step_inputs = self.strategy_factory.build_step_inputs()
425339

426340
# long context
427341
self._prev_chunk_output: dict = None
@@ -644,7 +558,7 @@ def _get_inputs_from_delta(
644558
sampling_inputs: SamplingInputs,
645559
):
646560
"""Get inputs from delta."""
647-
self.step_inputs.update_delta(delta, self)
561+
self.step_inputs.reindex(delta)
648562
inputs = self.step_inputs.model_inputs
649563
extra_inputs = self.step_inputs.extra_inputs
650564
stopping_criteria = self.step_inputs.stopping_criteria
@@ -661,7 +575,7 @@ def _prepare_inputs_prefill(
661575
if delta is not None:
662576
# update decoding inputs with delta
663577
# for second round chat
664-
self.step_inputs.update_delta(delta, self)
578+
self.step_inputs.reindex(delta)
665579

666580
if inputs.is_first_chunk:
667581
self._prev_chunk_output = None
@@ -768,29 +682,6 @@ async def _async_step(
768682
):
769683
"""Asyc forward task."""
770684

771-
@record_function('update_decoding_for_next_step')
772-
def __update_inputs(
773-
inputs,
774-
next_token_ids,
775-
model_metas,
776-
extra_inputs,
777-
extra_outputs,
778-
stopping_criteria,
779-
sampling_delta: SamplingInputsDelta = None,
780-
):
781-
"""Update inputs."""
782-
# dp might change is_decoding of decoding inputs
783-
self.step_inputs.step(
784-
inputs,
785-
extra_inputs,
786-
stopping_criteria,
787-
sampling_delta,
788-
next_token_ids,
789-
model_metas,
790-
extra_outputs,
791-
model_agent=self,
792-
)
793-
794685
dist_ctx = get_dist_manager().current_context()
795686
dist_config = dist_ctx.dist_config
796687
rank = self.rank
@@ -904,26 +795,27 @@ def __update_inputs(
904795

905796
sampling_delta = sampling_inputs.get_delta()
906797
if need_update_inputs:
907-
__update_inputs(inputs,
908-
next_token_ids,
909-
model_metas,
910-
extra_inputs,
911-
extra_outputs,
912-
stopping_criteria,
913-
sampling_delta=sampling_delta)
798+
self.step_inputs.step_decode(
799+
inputs,
800+
extra_inputs,
801+
stopping_criteria,
802+
sampling_delta,
803+
next_token_ids,
804+
model_metas,
805+
extra_outputs,
806+
)
914807
elif inputs.is_chunk and not inputs.is_last_chunk:
915808
# _prev_chunk_output is used to update model metas
916809
self._prev_chunk_output = output
917810
elif self.cache_config.role != EngineRole.Prefill:
918-
self.step_inputs.merge(
811+
self.step_inputs.merge_prefill(
919812
inputs,
920813
extra_inputs,
921814
stopping_criteria,
922815
sampling_delta,
923816
next_token_ids,
924817
model_metas,
925818
extra_outputs,
926-
model_agent=self,
927819
)
928820

929821
async def _async_loop_background(self, forward_event: asyncio.Event = None):

lmdeploy/pytorch/strategies/ar/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from lmdeploy.pytorch.strategies.base.model_agent import ModelAgentStrategy
1212
from lmdeploy.pytorch.strategies.base.model_inputs import ModelInputsStrategy
1313
from lmdeploy.pytorch.strategies.base.sampling import SamplingStrategy
14+
from lmdeploy.pytorch.strategies.base.step_inputs import StepInputs
1415

1516
from ..base import StrategyFactoryBase
1617

@@ -52,3 +53,10 @@ def build_engine_strategy(self, cache_config: 'CacheConfig',
5253
def build_sequence_strategy(self) -> SequenceStrategy:
5354
from .sequence import ARSequenceStrategy
5455
return ARSequenceStrategy()
56+
57+
def build_step_inputs(self) -> 'StepInputs':
58+
"""Build step inputs for the decoding loop."""
59+
from .step_inputs import ARStepInputs
60+
pad_token_id = self.model_config.bos_token_id
61+
pad_token_id = 0 if pad_token_id is None else pad_token_id
62+
return ARStepInputs(_pad_token_id=pad_token_id)

lmdeploy/pytorch/strategies/ar/model_agent.py

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from contextlib import contextmanager
33
from dataclasses import dataclass
4-
from typing import Any
54

65
import torch
76
import torch.distributed as dist
@@ -17,37 +16,6 @@
1716
SeqList = list[SchedulerSequence]
1817

1918

20-
def get_model_inputs_next_decoding(inputs: ModelInputs, input_ids: torch.Tensor, max_q_seqlen: int,
21-
model_metas) -> ModelInputs:
22-
"""Next decoding step."""
23-
if input_ids.dim() == 1:
24-
input_ids = input_ids[None, :]
25-
state_offsets = inputs.state_offsets
26-
if state_offsets is not None:
27-
state_offsets = state_offsets.clone()
28-
29-
# mrope
30-
mrope_pos_ids = inputs.mrope_pos_ids
31-
if mrope_pos_ids is not None:
32-
index = inputs.seq_length.cumsum(0) - 1
33-
mrope_pos_ids = mrope_pos_ids[:, index] + 1
34-
return ModelInputs(
35-
input_ids=input_ids,
36-
seq_length=torch.full_like(inputs.seq_length, max_q_seqlen),
37-
history_lengths=inputs.history_lengths + inputs.seq_length,
38-
block_offsets=inputs.block_offsets,
39-
is_decoding=True,
40-
num_ignored_history=inputs.num_ignored_history.clone(),
41-
max_q_seqlen=max_q_seqlen,
42-
max_kv_seqlen=inputs.max_kv_seqlen + max_q_seqlen,
43-
sum_kv_seqlen=inputs.sum_kv_seqlen + inputs.seq_length.numel() * inputs.max_q_seqlen,
44-
local_adapter_ids=inputs.local_adapter_ids,
45-
model_metas=model_metas,
46-
state_offsets=state_offsets,
47-
mrope_pos_ids=mrope_pos_ids,
48-
)
49-
50-
5119
@dataclass
5220
class ARExtraInputs(ExtraInputs):
5321
"""Ar extra inputs."""
@@ -145,26 +113,6 @@ def make_extra_outputs(self, extra_inputs: ARExtraInputs) -> ARExtraOutputs:
145113
"""Create extra outputs."""
146114
return ARExtraOutputs()
147115

148-
def update_prefill_for_next_step(
149-
self,
150-
model_inputs: 'ModelInputs',
151-
extra_inputs: ARExtraInputs,
152-
next_token_ids: torch.Tensor,
153-
model_metas: Any,
154-
extra_outputs: ARExtraOutputs,
155-
) -> tuple['ModelInputs', ARExtraInputs]:
156-
"""Step next decoding."""
157-
inputs = get_model_inputs_next_decoding(model_inputs, next_token_ids, max_q_seqlen=1, model_metas=model_metas)
158-
return inputs, extra_inputs
159-
160-
def update_decoding_for_next_step(self, model_inputs: 'ModelInputs', next_token_ids: torch.Tensor, model_metas: Any,
161-
extra_inputs: ARExtraInputs, **kwargs):
162-
"""Step next inputs."""
163-
model_inputs.model_metas = model_metas
164-
step_seqlens = model_inputs.seq_length
165-
model_inputs = model_inputs.step(next_token_ids, step_seqlens)
166-
return model_inputs, extra_inputs
167-
168116
def post_sampling(self, inputs: 'ModelInputs', logits: torch.Tensor, next_token_ids: torch.LongTensor,
169117
extra_inputs: ARExtraInputs):
170118
"""Post sampling."""

0 commit comments

Comments
 (0)