diff --git a/src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py b/src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py index 572168972..8f476530d 100644 --- a/src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py +++ b/src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py @@ -103,17 +103,18 @@ def load_optimizer_checkpoint_(self, optimizer: Optimizer, model: FSDP, file_pat class DCPCheckpointLoading(DistributedCheckpointLoadingIF): """Distributed checkpoint loading interface for loading PyTorch models and optimizer checkpoints.""" - def __init__(self, global_rank: int): + def __init__(self, global_rank: int, allow_partial_load: bool = False): """ Initializes the DCPCheckpointLoading object. Args: global_rank (int): The global rank of the process. - + allow_partial_load (bool, optional): Whether to allow partial loading of the checkpoint. Defaults to False. Returns: None """ self._global_rank = global_rank + self._allow_partial_load = allow_partial_load @torch.no_grad() def load_checkpoint_(self, app_state: AppState, checkpoint_dir_path: Path): @@ -129,5 +130,6 @@ def load_checkpoint_(self, app_state: AppState, checkpoint_dir_path: Path): dcp.load( state_dict={"app": app_state}, checkpoint_id=checkpoint_dir_path, + planner=dcp.DefaultLoadPlanner(allow_partial_load=self._allow_partial_load), ) get_logger().info(f"Distributed checkpoint loaded on rank {self._global_rank}.") diff --git a/src/modalities/checkpointing/stateful/app_state.py b/src/modalities/checkpointing/stateful/app_state.py index 2da3ab236..97b65a569 100644 --- a/src/modalities/checkpointing/stateful/app_state.py +++ b/src/modalities/checkpointing/stateful/app_state.py @@ -37,7 +37,11 @@ class AppState(Stateful): """ def __init__( - self, model: nn.Module | list[nn.Module], optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None + self, + model: nn.Module | list[nn.Module], + optimizer: Optimizer, + lr_scheduler: Optional[LRScheduler] = None, + components_to_load: 
list[StatefulComponents] | None = None, ): """Initializes the AppState object. @@ -46,12 +50,22 @@ def __init__( a non-sharded model, FSDP1 or FSDP2 model. optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer. lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None. + components_to_load (list[StatefulComponents] | None, optional): The list of components to load from the + checkpoint. If None, all components are loaded. Defaults to None. """ self._model_parts = list(model) if isinstance(model, list) else [model] self._optimizer = optimizer self._lr_scheduler = lr_scheduler self._is_loaded = False + # policy for which components to load from the checkpoint. If None, defaults to loading all components. + if components_to_load is None: + self._components_to_load = [StatefulComponents.MODEL, StatefulComponents.OPTIMIZER] + if lr_scheduler is not None: + self._components_to_load.append(StatefulComponents.LR_SCHEDULER) + else: + self._components_to_load = components_to_load + @property def is_loaded(self) -> bool: """Returns whether the state dict has been loaded. @@ -106,12 +120,14 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: "Cannot call load_state_dict twice on the same AppState object. " "State dict has already been loaded." 
) - ModelStateRetriever.load_state_dict_(app_state=self, state_dict=state_dict[StatefulComponents.MODEL.value]) - OptimizerStateRetriever.load_state_dict_( - app_state=self, - state_dict=state_dict[StatefulComponents.OPTIMIZER.value], - ) - if self._lr_scheduler is not None: + if StatefulComponents.MODEL in self._components_to_load: + ModelStateRetriever.load_state_dict_(app_state=self, state_dict=state_dict[StatefulComponents.MODEL.value]) + if StatefulComponents.OPTIMIZER in self._components_to_load: + OptimizerStateRetriever.load_state_dict_( + app_state=self, + state_dict=state_dict[StatefulComponents.OPTIMIZER.value], + ) + if self._lr_scheduler is not None and StatefulComponents.LR_SCHEDULER in self._components_to_load: LRSchedulerStateRetriever.load_state_dict_( app_state=self, state_dict=state_dict[StatefulComponents.LR_SCHEDULER.value] ) diff --git a/src/modalities/checkpointing/stateful/app_state_factory.py b/src/modalities/checkpointing/stateful/app_state_factory.py index 8f6e63d8a..618794ac5 100644 --- a/src/modalities/checkpointing/stateful/app_state_factory.py +++ b/src/modalities/checkpointing/stateful/app_state_factory.py @@ -7,7 +7,7 @@ from torch.optim.lr_scheduler import LRScheduler from modalities.checkpointing.fsdp.fsdp_checkpoint_loading import DCPCheckpointLoading -from modalities.checkpointing.stateful.app_state import AppState +from modalities.checkpointing.stateful.app_state import AppState, StatefulComponents class AppStateFactory: @@ -15,7 +15,10 @@ class AppStateFactory: @staticmethod def get_raw_app_state( - model: nn.Module | list[nn.Module], optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None + model: nn.Module | list[nn.Module], + optimizer: Optimizer, + lr_scheduler: Optional[LRScheduler] = None, + components_to_load: list[StatefulComponents] | None = None, ) -> AppState: """Creates a new (non-checkpoint loaded) AppState object from an instantiated model, optimizer, and optional learning rate scheduler. 
@@ -25,24 +28,35 @@ def get_raw_app_state( a non-sharded model, FSDP1 or FSDP2 model. optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer. lr_scheduler (Optional[LRScheduler], optional): Lr scheduler used during training. Defaults to None. + components_to_load (list[StatefulComponents] | None, optional): Subset of components that should + be restored from a checkpoint when ``load_state_dict`` is later invoked. If None, all + available components are loaded. Defaults to None. Returns: AppState: The AppState object. """ - app_state = AppState(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler) + app_state = AppState( + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + components_to_load=components_to_load, + ) return app_state @staticmethod def get_dcp_checkpointed_app_state_( raw_app_state: AppState, checkpoint_dir_path: Path, + allow_partial_load: bool = True, ) -> AppState: """Loads the checkpointed state dict into the raw AppState object (i.e., non-checkpoint loaded AppState) in-place. Args: - raw_app_state (AppState): The raw AppState object. + raw_app_state (AppState): The raw AppState object. Its ``components_to_load`` policy + determines which components are restored. checkpoint_dir_path (Path): The path to the checkpoint directory. + allow_partial_load (bool, optional): Whether to allow partial loading of the checkpoint. Defaults to True. Raises: RuntimeError: Raises an error if the state dict has already been loaded. @@ -52,8 +66,9 @@ def get_dcp_checkpointed_app_state_( """ if raw_app_state.is_loaded: raise RuntimeError( - "Cannot call load_state_dict twice on the same AppState object. " "State dict has already been loaded." + "Cannot call load_state_dict twice on the same AppState object. State dict has already been loaded." 
) - cp_loading = DCPCheckpointLoading(global_rank=dist.get_rank()) + + cp_loading = DCPCheckpointLoading(global_rank=dist.get_rank(), allow_partial_load=allow_partial_load) cp_loading.load_checkpoint_(app_state=raw_app_state, checkpoint_dir_path=checkpoint_dir_path) return raw_app_state diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 42a19b99a..18d3629ca 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -11,6 +11,7 @@ from transformers import LlamaTokenizer as LlamaTokenizerFast from typing_extensions import deprecated +from modalities.checkpointing.stateful.app_state import StatefulComponents from modalities.config.lookup_enum import LookupEnum from modalities.config.pydantic_if_types import ( PydanticAppStateType, @@ -124,8 +125,9 @@ def parse_sharding_strategy_by_name(cls, name: str) -> ShardingStrategy: return parse_enum_by_name(name=name, enum_type=ShardingStrategy) -class DCPCheckpointLoadingConfig(BaseModel): - global_rank: Annotated[int, Field(strict=True, ge=0)] +# class DCPCheckpointLoadingConfig(BaseModel): +# global_rank: Annotated[int, Field(strict=True, ge=0)] +# allow_partial_load: bool = True class FSDP1CheckpointSavingConfig(BaseModel): @@ -382,11 +384,13 @@ class RawAppStateConfig(BaseModel): model: PydanticPytorchModuleOrListType optimizer: PydanticOptimizerIFType lr_scheduler: Optional[PydanticLRSchedulerIFType] = None + components_to_load: Optional[list[StatefulComponents]] = None class DCPAppStateConfig(BaseModel): raw_app_state: PydanticAppStateType checkpoint_dir_path: Path + allow_partial_load: bool = False class PreTrainedHFTokenizerConfig(BaseModel): diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 26df9b432..5833ea728 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -13,7 +13,7 @@ SaveEveryKStepsCheckpointingStrategy, SaveKMostRecentCheckpointsStrategy, ) -from 
modalities.checkpointing.fsdp.fsdp_checkpoint_loading import DCPCheckpointLoading, FSDP1CheckpointLoading +from modalities.checkpointing.fsdp.fsdp_checkpoint_loading import FSDP1CheckpointLoading from modalities.checkpointing.fsdp.fsdp_checkpoint_saving import DCPCheckpointSaving, FSDP1CheckpointSaving from modalities.checkpointing.stateful.app_state_factory import AppStateFactory from modalities.checkpointing.torch.torch_checkpoint_loading import TorchCheckpointLoading @@ -29,7 +29,6 @@ ConstantLRSchedulerConfig, CosineAnnealingLRSchedulerConfig, DCPAppStateConfig, - DCPCheckpointLoadingConfig, DCPCheckpointSavingConfig, DebuggingEnrichedModelConfig, DistributedSamplerConfig, @@ -358,7 +357,7 @@ class ComponentEntity: ComponentEntity("checkpoint_saving_execution", "dcp", DCPCheckpointSaving, DCPCheckpointSavingConfig), # checkpoint loading ComponentEntity("checkpoint_loading", "fsdp1", FSDP1CheckpointLoading, FSDP1CheckpointLoadingConfig), - ComponentEntity("checkpoint_loading", "dcp", DCPCheckpointLoading, DCPCheckpointLoadingConfig), + # ComponentEntity("checkpoint_loading", "dcp", DCPCheckpointLoading, DCPCheckpointLoadingConfig), ComponentEntity("checkpoint_loading", "torch", TorchCheckpointLoading, TorchCheckpointLoadingConfig), # Progress subscriber ComponentEntity( diff --git a/tests/checkpointing/test_app_state_components_to_load.py b/tests/checkpointing/test_app_state_components_to_load.py new file mode 100644 index 000000000..acd842d86 --- /dev/null +++ b/tests/checkpointing/test_app_state_components_to_load.py @@ -0,0 +1,137 @@ +from unittest.mock import MagicMock + +import pytest +import torch.nn as nn +from torch.optim import SGD +from torch.optim.lr_scheduler import StepLR + +from modalities.checkpointing.stateful import app_state as app_state_module +from modalities.checkpointing.stateful.app_state import AppState, StatefulComponents + + +@pytest.fixture +def model() -> nn.Module: + return nn.Linear(4, 2) + + +@pytest.fixture +def optimizer(model: 
nn.Module) -> SGD: + return SGD(model.parameters(), lr=0.1) + + +@pytest.fixture +def lr_scheduler(optimizer: SGD) -> StepLR: + return StepLR(optimizer, step_size=1) + + +@pytest.fixture +def patched_retrievers(monkeypatch: pytest.MonkeyPatch) -> dict[StatefulComponents, MagicMock]: + """Replace each retriever's ``load_state_dict_`` with a mock so we can assert which ones were invoked.""" + mocks = { + StatefulComponents.MODEL: MagicMock(), + StatefulComponents.OPTIMIZER: MagicMock(), + StatefulComponents.LR_SCHEDULER: MagicMock(), + } + monkeypatch.setattr(app_state_module.ModelStateRetriever, "load_state_dict_", mocks[StatefulComponents.MODEL]) + monkeypatch.setattr( + app_state_module.OptimizerStateRetriever, "load_state_dict_", mocks[StatefulComponents.OPTIMIZER] + ) + monkeypatch.setattr( + app_state_module.LRSchedulerStateRetriever, "load_state_dict_", mocks[StatefulComponents.LR_SCHEDULER] + ) + return mocks + + +def _make_state_dict() -> dict: + return { + StatefulComponents.MODEL.value: {"model_payload": True}, + StatefulComponents.OPTIMIZER.value: {"optimizer_payload": True}, + StatefulComponents.LR_SCHEDULER.value: {"lr_scheduler_payload": True}, + } + + +class TestComponentsToLoad: + def test_default_without_lr_scheduler_loads_model_and_optimizer( + self, model: nn.Module, optimizer: SGD, patched_retrievers: dict[StatefulComponents, MagicMock] + ) -> None: + app_state = AppState(model=model, optimizer=optimizer) + + app_state.load_state_dict(_make_state_dict()) + + patched_retrievers[StatefulComponents.MODEL].assert_called_once() + patched_retrievers[StatefulComponents.OPTIMIZER].assert_called_once() + patched_retrievers[StatefulComponents.LR_SCHEDULER].assert_not_called() + assert app_state.is_loaded + + def test_default_with_lr_scheduler_loads_all_three( + self, + model: nn.Module, + optimizer: SGD, + lr_scheduler: StepLR, + patched_retrievers: dict[StatefulComponents, MagicMock], + ) -> None: + app_state = AppState(model=model, optimizer=optimizer, 
lr_scheduler=lr_scheduler) + + app_state.load_state_dict(_make_state_dict()) + + patched_retrievers[StatefulComponents.MODEL].assert_called_once() + patched_retrievers[StatefulComponents.OPTIMIZER].assert_called_once() + patched_retrievers[StatefulComponents.LR_SCHEDULER].assert_called_once() + + @pytest.mark.parametrize( + "selected", + [ + [StatefulComponents.MODEL], + [StatefulComponents.OPTIMIZER], + [StatefulComponents.LR_SCHEDULER], + [StatefulComponents.MODEL, StatefulComponents.OPTIMIZER], + [StatefulComponents.MODEL, StatefulComponents.LR_SCHEDULER], + [], + ], + ) + def test_explicit_selection_only_loads_chosen_components( + self, + model: nn.Module, + optimizer: SGD, + lr_scheduler: StepLR, + patched_retrievers: dict[StatefulComponents, MagicMock], + selected: list[StatefulComponents], + ) -> None: + app_state = AppState(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, components_to_load=selected) + + app_state.load_state_dict(_make_state_dict()) + + for component, mock in patched_retrievers.items(): + if component in selected: + mock.assert_called_once() + else: + mock.assert_not_called() + + def test_lr_scheduler_in_components_but_no_scheduler_attached_is_skipped( + self, model: nn.Module, optimizer: SGD, patched_retrievers: dict[StatefulComponents, MagicMock] + ) -> None: + # Guards against the lr_scheduler branch firing when no scheduler is attached — the + # state_dict won't carry a scheduler entry, so the retriever must not be called. 
+ app_state = AppState( + model=model, + optimizer=optimizer, + components_to_load=[StatefulComponents.MODEL, StatefulComponents.LR_SCHEDULER], + ) + + state_dict = _make_state_dict() + state_dict.pop(StatefulComponents.LR_SCHEDULER.value) + + app_state.load_state_dict(state_dict) + + patched_retrievers[StatefulComponents.MODEL].assert_called_once() + patched_retrievers[StatefulComponents.OPTIMIZER].assert_not_called() + patched_retrievers[StatefulComponents.LR_SCHEDULER].assert_not_called() + + def test_double_load_raises( + self, model: nn.Module, optimizer: SGD, patched_retrievers: dict[StatefulComponents, MagicMock] + ) -> None: + app_state = AppState(model=model, optimizer=optimizer) + app_state.load_state_dict(_make_state_dict()) + + with pytest.raises(RuntimeError): + app_state.load_state_dict(_make_state_dict())