2020from lmdeploy .pytorch .distributed import DistContext , get_dist_manager
2121from lmdeploy .pytorch .engine .cache_engine import CacheEngine , StateCacheEngine
2222from lmdeploy .pytorch .engine .guided_process import GuidedDecodingManager
23- from lmdeploy .pytorch .engine .logits_process import FusedLogitsProcessor , SamplingInputs , SamplingInputsDelta
23+ from lmdeploy .pytorch .engine .logits_process import FusedLogitsProcessor , SamplingInputs
2424from lmdeploy .pytorch .model_inputs import ModelInputs , ModelInputsDelta , step_ctx_manager
2525from lmdeploy .pytorch .models .patch import BuildModelContext , add_adapters , build_patched_model , update_custom_module_map
2626from lmdeploy .pytorch .spec_decode import build_spec_agent
@@ -229,92 +229,6 @@ async def async_wait(self, timeout: float = 0.001):
229229SwapMap = dict [int , int ]
230230
231231
232- @dataclass
233- class StepInputs :
234- """Step inputs: model inputs, extra inputs, stopping criteria and sampling delta stored as one unit."""
235- model_inputs : ModelInputs = None
236- extra_inputs : ExtraInputs = None
237- stopping_criteria : StoppingCriteria = None
238- sampling_delta : SamplingInputsDelta = None
239-
240- @record_function ('StepInputs.merge' )
241- def merge (
242- self ,
243- inputs : ModelInputs ,
244- extra_inputs : ExtraInputs ,
245- stopping_criteria : StoppingCriteria ,
246- sampling_delta : SamplingInputsDelta ,
247- next_token_ids : torch .Tensor ,
248- model_metas ,
249- extra_outputs : ExtraOutputs ,
250- model_agent : 'BaseModelAgent' ,
251- ):
252- """Advance prefill inputs for the next step, then store them if nothing is held yet or merge them into the held state."""
253- inputs , extra_inputs = model_agent .agent_strategy .update_prefill_for_next_step (
254- inputs ,
255- extra_inputs ,
256- next_token_ids ,
257- model_metas ,
258- extra_outputs ,
259- )
260- stopping_criteria = stopping_criteria .clone ()
261- sampling_delta = model_agent .sampling_strategy .step_sampling_delta (sampling_delta ,
262- next_token_ids ,
263- extra_inputs = extra_inputs )
264- if self .model_inputs is None :
265- self .model_inputs = inputs
266- self .extra_inputs = extra_inputs
267- self .stopping_criteria = stopping_criteria
268- self .sampling_delta = sampling_delta
269- else :
270- self .model_inputs = model_agent .inputs_strategy .merge (self .model_inputs , inputs )
271- self .extra_inputs = self .extra_inputs .merge (extra_inputs )
272- self .stopping_criteria = self .stopping_criteria .merge (stopping_criteria )
273- self .sampling_delta = model_agent .sampling_strategy .merge_sampling_delta (
274- self .sampling_delta , sampling_delta )
275-
276- def update_delta (
277- self ,
278- delta : ModelInputsDelta ,
279- model_agent : 'BaseModelAgent' ,
280- ):
281- """Update every stored field (model inputs, extra inputs, stopping criteria, sampling delta) with delta."""
282- self .model_inputs = model_agent .inputs_strategy .update_inputs (self .model_inputs , delta )
283- self .extra_inputs = model_agent .agent_strategy .update_extra_inputs (self .extra_inputs , delta )
284- self .stopping_criteria = self .stopping_criteria .update (delta )
285- self .sampling_delta = model_agent .sampling_strategy .update_sampling_delta (self .sampling_delta , delta )
286-
287- @record_function ('StepInputs.step' )
288- def step (
289- self ,
290- model_inputs : ModelInputs ,
291- extra_inputs : ExtraInputs ,
292- stopping_criteria : StoppingCriteria ,
293- sampling_delta : SamplingInputsDelta ,
294- next_token_ids : torch .Tensor ,
295- model_metas ,
296- extra_outputs : ExtraOutputs ,
297- model_agent : 'BaseModelAgent' ,
298- ):
299- """Replace the stored state with the given decoding inputs advanced for the next step."""
300- # dp might change is_decoding of decoding inputs, so it is reset to True here
301- model_inputs .is_decoding = True
302- (
303- self .model_inputs ,
304- self .extra_inputs ,
305- ) = model_agent .agent_strategy .update_decoding_for_next_step (
306- model_inputs ,
307- next_token_ids = next_token_ids ,
308- model_metas = model_metas ,
309- extra_inputs = extra_inputs ,
310- extra_outputs = extra_outputs ,
311- )
312- self .stopping_criteria = stopping_criteria .clone ()
313- self .sampling_delta = model_agent .sampling_strategy .step_sampling_delta (sampling_delta ,
314- next_token_ids ,
315- extra_inputs = extra_inputs )
316-
317-
317-
318232class BaseModelAgent :
319233 """Base model agent.
320234
@@ -421,7 +335,7 @@ def __init__(
421335 self .state : SleepWakeupState = SleepWakeupState ()
422336
423337 # decoding inputs
424- self .step_inputs = StepInputs ()
338+ self .step_inputs = self . strategy_factory . build_step_inputs ()
425339
426340 # long context
427341 self ._prev_chunk_output : dict = None
@@ -644,7 +558,7 @@ def _get_inputs_from_delta(
644558 sampling_inputs : SamplingInputs ,
645559 ):
646560 """Get inputs from delta."""
647- self .step_inputs .update_delta (delta , self )
561+ self .step_inputs .reindex (delta )
648562 inputs = self .step_inputs .model_inputs
649563 extra_inputs = self .step_inputs .extra_inputs
650564 stopping_criteria = self .step_inputs .stopping_criteria
@@ -661,7 +575,7 @@ def _prepare_inputs_prefill(
661575 if delta is not None :
662576 # update decoding inputs with delta
663577 # for second round chat
664- self .step_inputs .update_delta (delta , self )
578+ self .step_inputs .reindex (delta )
665579
666580 if inputs .is_first_chunk :
667581 self ._prev_chunk_output = None
@@ -768,29 +682,6 @@ async def _async_step(
768682 ):
769683 """Async forward task."""
770684
771- @record_function ('update_decoding_for_next_step' )
772- def __update_inputs (
773- inputs ,
774- next_token_ids ,
775- model_metas ,
776- extra_inputs ,
777- extra_outputs ,
778- stopping_criteria ,
779- sampling_delta : SamplingInputsDelta = None ,
780- ):
781- """Update the stored decoding inputs by delegating to self.step_inputs.step."""
782- # dp might change is_decoding of decoding inputs; StepInputs.step sets it back to True
783- self .step_inputs .step (
784- inputs ,
785- extra_inputs ,
786- stopping_criteria ,
787- sampling_delta ,
788- next_token_ids ,
789- model_metas ,
790- extra_outputs ,
791- model_agent = self ,
792- )
793-
793-
794685 dist_ctx = get_dist_manager ().current_context ()
795686 dist_config = dist_ctx .dist_config
796687 rank = self .rank
@@ -904,26 +795,27 @@ def __update_inputs(
904795
905796 sampling_delta = sampling_inputs .get_delta ()
906797 if need_update_inputs :
907- __update_inputs (inputs ,
908- next_token_ids ,
909- model_metas ,
910- extra_inputs ,
911- extra_outputs ,
912- stopping_criteria ,
913- sampling_delta = sampling_delta )
798+ self .step_inputs .step_decode (
799+ inputs ,
800+ extra_inputs ,
801+ stopping_criteria ,
802+ sampling_delta ,
803+ next_token_ids ,
804+ model_metas ,
805+ extra_outputs ,
806+ )
914807 elif inputs .is_chunk and not inputs .is_last_chunk :
915808 # _prev_chunk_output is used to update model metas
916809 self ._prev_chunk_output = output
917810 elif self .cache_config .role != EngineRole .Prefill :
918- self .step_inputs .merge (
811+ self .step_inputs .merge_prefill (
919812 inputs ,
920813 extra_inputs ,
921814 stopping_criteria ,
922815 sampling_delta ,
923816 next_token_ids ,
924817 model_metas ,
925818 extra_outputs ,
926- model_agent = self ,
927819 )
928820
929821 async def _async_loop_background (self , forward_event : asyncio .Event = None ):
0 commit comments