Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,18 @@ def load_optimizer_checkpoint_(self, optimizer: Optimizer, model: FSDP, file_pat
class DCPCheckpointLoading(DistributedCheckpointLoadingIF):
"""Distributed checkpoint loading interface for loading PyTorch models and optimizer checkpoints."""

def __init__(self, global_rank: int):
def __init__(self, global_rank: int, allow_partial_load: bool = False):
"""
Initializes the DCPCheckpointLoading object.

Args:
global_rank (int): The global rank of the process.

allow_partial_load (bool, optional): Whether to allow partial loading of the checkpoint. Defaults to False.
Returns:
None
"""
self._global_rank = global_rank
self._allow_partial_load = allow_partial_load

@torch.no_grad()
def load_checkpoint_(self, app_state: AppState, checkpoint_dir_path: Path):
Expand All @@ -129,5 +130,6 @@ def load_checkpoint_(self, app_state: AppState, checkpoint_dir_path: Path):
dcp.load(
state_dict={"app": app_state},
checkpoint_id=checkpoint_dir_path,
planner=dcp.DefaultLoadPlanner(allow_partial_load=self._allow_partial_load),
)
get_logger().info(f"Distributed checkpoint loaded on rank {self._global_rank}.")
30 changes: 23 additions & 7 deletions src/modalities/checkpointing/stateful/app_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ class AppState(Stateful):
"""

def __init__(
self, model: nn.Module | list[nn.Module], optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None
self,
model: nn.Module | list[nn.Module],
optimizer: Optimizer,
lr_scheduler: Optional[LRScheduler] = None,
components_to_load: list[StatefulComponents] | None = None,
):
"""Initializes the AppState object.

Expand All @@ -46,12 +50,22 @@ def __init__(
a non-sharded model, FSDP1 or FSDP2 model.
optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer.
lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None.
components_to_load (list[StatefulComponents] | None, optional): The list of components to load from the
checkpoint. If None, all components are loaded. Defaults to None.
"""
self._model_parts = list(model) if isinstance(model, list) else [model]
self._optimizer = optimizer
self._lr_scheduler = lr_scheduler
self._is_loaded = False

# policy for which components to load from the checkpoint. If None, defaults to loading all components.
if components_to_load is None:
self._components_to_load = [StatefulComponents.MODEL, StatefulComponents.OPTIMIZER]
if lr_scheduler is not None:
self._components_to_load.append(StatefulComponents.LR_SCHEDULER)
else:
self._components_to_load = components_to_load

@property
def is_loaded(self) -> bool:
"""Returns whether the state dict has been loaded.
Expand Down Expand Up @@ -106,12 +120,14 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"Cannot call load_state_dict twice on the same AppState object. " "State dict has already been loaded."
)

ModelStateRetriever.load_state_dict_(app_state=self, state_dict=state_dict[StatefulComponents.MODEL.value])
OptimizerStateRetriever.load_state_dict_(
app_state=self,
state_dict=state_dict[StatefulComponents.OPTIMIZER.value],
)
if self._lr_scheduler is not None:
if StatefulComponents.MODEL in self._components_to_load:
ModelStateRetriever.load_state_dict_(app_state=self, state_dict=state_dict[StatefulComponents.MODEL.value])
if StatefulComponents.OPTIMIZER in self._components_to_load:
OptimizerStateRetriever.load_state_dict_(
app_state=self,
state_dict=state_dict[StatefulComponents.OPTIMIZER.value],
)
if self._lr_scheduler is not None and StatefulComponents.LR_SCHEDULER in self._components_to_load:
LRSchedulerStateRetriever.load_state_dict_(
app_state=self, state_dict=state_dict[StatefulComponents.LR_SCHEDULER.value]
)
Expand Down
27 changes: 21 additions & 6 deletions src/modalities/checkpointing/stateful/app_state_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@
from torch.optim.lr_scheduler import LRScheduler

from modalities.checkpointing.fsdp.fsdp_checkpoint_loading import DCPCheckpointLoading
from modalities.checkpointing.stateful.app_state import AppState
from modalities.checkpointing.stateful.app_state import AppState, StatefulComponents


class AppStateFactory:
"""Factory class to create AppState objects."""

@staticmethod
def get_raw_app_state(
model: nn.Module | list[nn.Module], optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None
model: nn.Module | list[nn.Module],
optimizer: Optimizer,
lr_scheduler: Optional[LRScheduler] = None,
components_to_load: list[StatefulComponents] | None = None,
) -> AppState:
"""Creates a new (non-checkpoint loaded) AppState object from an instantiated
model, optimizer, and optional learning rate scheduler.
Expand All @@ -25,24 +28,35 @@ def get_raw_app_state(
a non-sharded model, FSDP1 or FSDP2 model.
optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer.
lr_scheduler (Optional[LRScheduler], optional): Lr scheduler used during training. Defaults to None.
components_to_load (list[StatefulComponents] | None, optional): Subset of components that should
be restored from a checkpoint when ``load_state_dict`` is later invoked. If None, all
available components are loaded. Defaults to None.

Returns:
AppState: The AppState object.
"""
app_state = AppState(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler)
app_state = AppState(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
components_to_load=components_to_load,
)
return app_state

@staticmethod
def get_dcp_checkpointed_app_state_(
raw_app_state: AppState,
checkpoint_dir_path: Path,
allow_partial_load: bool = True,
) -> AppState:
"""Loads the checkpointed state dict into the raw AppState object
(i.e., non-checkpoint loaded AppState) in-place.

Args:
raw_app_state (AppState): The raw AppState object.
raw_app_state (AppState): The raw AppState object. Its ``components_to_load`` policy
determines which components are restored.
checkpoint_dir_path (Path): The path to the checkpoint directory.
allow_partial_load (bool, optional): Whether to allow partial loading of the checkpoint. Defaults to True.

Raises:
RuntimeError: Raises an error if the state dict has already been loaded.
Expand All @@ -52,8 +66,9 @@ def get_dcp_checkpointed_app_state_(
"""
if raw_app_state.is_loaded:
raise RuntimeError(
"Cannot call load_state_dict twice on the same AppState object. " "State dict has already been loaded."
"Cannot call load_state_dict twice on the same AppState object. State dict has already been loaded."
)
cp_loading = DCPCheckpointLoading(global_rank=dist.get_rank())

cp_loading = DCPCheckpointLoading(global_rank=dist.get_rank(), allow_partial_load=allow_partial_load)
cp_loading.load_checkpoint_(app_state=raw_app_state, checkpoint_dir_path=checkpoint_dir_path)
return raw_app_state
8 changes: 6 additions & 2 deletions src/modalities/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from transformers import LlamaTokenizer as LlamaTokenizerFast
from typing_extensions import deprecated

from modalities.checkpointing.stateful.app_state import StatefulComponents
from modalities.config.lookup_enum import LookupEnum
from modalities.config.pydantic_if_types import (
PydanticAppStateType,
Expand Down Expand Up @@ -124,8 +125,9 @@ def parse_sharding_strategy_by_name(cls, name: str) -> ShardingStrategy:
return parse_enum_by_name(name=name, enum_type=ShardingStrategy)


class DCPCheckpointLoadingConfig(BaseModel):
global_rank: Annotated[int, Field(strict=True, ge=0)]
# class DCPCheckpointLoadingConfig(BaseModel):
# global_rank: Annotated[int, Field(strict=True, ge=0)]
# allow_partial_load: bool = True


class FSDP1CheckpointSavingConfig(BaseModel):
Expand Down Expand Up @@ -382,11 +384,13 @@ class RawAppStateConfig(BaseModel):
model: PydanticPytorchModuleOrListType
optimizer: PydanticOptimizerIFType
lr_scheduler: Optional[PydanticLRSchedulerIFType] = None
components_to_load: Optional[list[StatefulComponents]] = None


class DCPAppStateConfig(BaseModel):
    """Config holding a raw app state, a checkpoint directory and a partial-load flag.

    NOTE(review): the field names mirror the parameters of
    ``AppStateFactory.get_dcp_checkpointed_app_state_``; presumably this config feeds that
    factory -- confirm against the component registry.
    """

    # App state whose checkpoint has not been loaded yet (field name taken at face value; TODO confirm).
    raw_app_state: PydanticAppStateType
    # Path of the checkpoint directory.
    checkpoint_dir_path: Path
    # NOTE(review): defaults to False here, whereas
    # ``AppStateFactory.get_dcp_checkpointed_app_state_`` declares ``allow_partial_load: bool = True``.
    # Confirm which default is intended and align the two.
    allow_partial_load: bool = False


class PreTrainedHFTokenizerConfig(BaseModel):
Expand Down
5 changes: 2 additions & 3 deletions src/modalities/registry/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
SaveEveryKStepsCheckpointingStrategy,
SaveKMostRecentCheckpointsStrategy,
)
from modalities.checkpointing.fsdp.fsdp_checkpoint_loading import DCPCheckpointLoading, FSDP1CheckpointLoading
from modalities.checkpointing.fsdp.fsdp_checkpoint_loading import FSDP1CheckpointLoading
from modalities.checkpointing.fsdp.fsdp_checkpoint_saving import DCPCheckpointSaving, FSDP1CheckpointSaving
from modalities.checkpointing.stateful.app_state_factory import AppStateFactory
from modalities.checkpointing.torch.torch_checkpoint_loading import TorchCheckpointLoading
Expand All @@ -29,7 +29,6 @@
ConstantLRSchedulerConfig,
CosineAnnealingLRSchedulerConfig,
DCPAppStateConfig,
DCPCheckpointLoadingConfig,
DCPCheckpointSavingConfig,
DebuggingEnrichedModelConfig,
DistributedSamplerConfig,
Expand Down Expand Up @@ -358,7 +357,7 @@ class ComponentEntity:
ComponentEntity("checkpoint_saving_execution", "dcp", DCPCheckpointSaving, DCPCheckpointSavingConfig),
# checkpoint loading
ComponentEntity("checkpoint_loading", "fsdp1", FSDP1CheckpointLoading, FSDP1CheckpointLoadingConfig),
ComponentEntity("checkpoint_loading", "dcp", DCPCheckpointLoading, DCPCheckpointLoadingConfig),
# ComponentEntity("checkpoint_loading", "dcp", DCPCheckpointLoading, DCPCheckpointLoadingConfig),
ComponentEntity("checkpoint_loading", "torch", TorchCheckpointLoading, TorchCheckpointLoadingConfig),
# Progress subscriber
ComponentEntity(
Expand Down
137 changes: 137 additions & 0 deletions tests/checkpointing/test_app_state_components_to_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from unittest.mock import MagicMock

import pytest
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

from modalities.checkpointing.stateful import app_state as app_state_module
from modalities.checkpointing.stateful.app_state import AppState, StatefulComponents


@pytest.fixture
def model() -> nn.Module:
    """Provide a small linear layer (4 input features, 2 output features) as the model under test."""
    linear_layer = nn.Linear(in_features=4, out_features=2)
    return linear_layer


@pytest.fixture
def optimizer(model: nn.Module) -> SGD:
    """Provide an SGD optimizer (lr=0.1) over the parameters of the ``model`` fixture."""
    trainable_params = model.parameters()
    return SGD(trainable_params, lr=0.1)


@pytest.fixture
def lr_scheduler(optimizer: SGD) -> StepLR:
    """Provide a ``StepLR`` scheduler with ``step_size=1`` bound to the ``optimizer`` fixture."""
    scheduler = StepLR(optimizer, step_size=1)
    return scheduler


@pytest.fixture
def patched_retrievers(monkeypatch: pytest.MonkeyPatch) -> dict[StatefulComponents, MagicMock]:
    """Swap every retriever's ``load_state_dict_`` for a mock and return the mocks keyed by component.

    The returned mapping lets a test assert which retrievers were invoked during loading.
    """
    # Insertion order (model, optimizer, lr scheduler) is kept for the returned mapping.
    retriever_by_component = {
        StatefulComponents.MODEL: app_state_module.ModelStateRetriever,
        StatefulComponents.OPTIMIZER: app_state_module.OptimizerStateRetriever,
        StatefulComponents.LR_SCHEDULER: app_state_module.LRSchedulerStateRetriever,
    }

    stubs: dict[StatefulComponents, MagicMock] = {}
    for component, retriever_cls in retriever_by_component.items():
        stub = MagicMock()
        monkeypatch.setattr(retriever_cls, "load_state_dict_", stub)
        stubs[component] = stub
    return stubs


def _make_state_dict() -> dict:
    """Build a dummy state dict with one single-entry payload per stateful component.

    Keys are the component enum values; each payload is ``{"<name>_payload": True}``.
    """
    payload_key_by_component = (
        (StatefulComponents.MODEL, "model_payload"),
        (StatefulComponents.OPTIMIZER, "optimizer_payload"),
        (StatefulComponents.LR_SCHEDULER, "lr_scheduler_payload"),
    )
    return {component.value: {payload_key: True} for component, payload_key in payload_key_by_component}


class TestComponentsToLoad:
    """Checks which retrievers ``AppState.load_state_dict`` invokes for a given ``components_to_load`` policy."""

    def test_default_without_lr_scheduler_loads_model_and_optimizer(
        self, model: nn.Module, optimizer: SGD, patched_retrievers: dict[StatefulComponents, MagicMock]
    ) -> None:
        """No explicit policy and no scheduler: model and optimizer retrievers run, the scheduler one does not."""
        state = AppState(model=model, optimizer=optimizer)
        checkpoint_payload = _make_state_dict()

        state.load_state_dict(checkpoint_payload)

        model_stub = patched_retrievers[StatefulComponents.MODEL]
        optimizer_stub = patched_retrievers[StatefulComponents.OPTIMIZER]
        scheduler_stub = patched_retrievers[StatefulComponents.LR_SCHEDULER]
        model_stub.assert_called_once()
        optimizer_stub.assert_called_once()
        scheduler_stub.assert_not_called()
        assert state.is_loaded

    def test_default_with_lr_scheduler_loads_all_three(
        self,
        model: nn.Module,
        optimizer: SGD,
        lr_scheduler: StepLR,
        patched_retrievers: dict[StatefulComponents, MagicMock],
    ) -> None:
        """No explicit policy but a scheduler attached: all three retrievers run exactly once."""
        state = AppState(model=model, optimizer=optimizer, lr_scheduler=lr_scheduler)
        checkpoint_payload = _make_state_dict()

        state.load_state_dict(checkpoint_payload)

        for component in (
            StatefulComponents.MODEL,
            StatefulComponents.OPTIMIZER,
            StatefulComponents.LR_SCHEDULER,
        ):
            patched_retrievers[component].assert_called_once()

    @pytest.mark.parametrize(
        "selected",
        [
            [StatefulComponents.MODEL],
            [StatefulComponents.OPTIMIZER],
            [StatefulComponents.LR_SCHEDULER],
            [StatefulComponents.MODEL, StatefulComponents.OPTIMIZER],
            [StatefulComponents.MODEL, StatefulComponents.LR_SCHEDULER],
            [],
        ],
    )
    def test_explicit_selection_only_loads_chosen_components(
        self,
        model: nn.Module,
        optimizer: SGD,
        lr_scheduler: StepLR,
        patched_retrievers: dict[StatefulComponents, MagicMock],
        selected: list[StatefulComponents],
    ) -> None:
        """An explicit policy triggers exactly the selected retrievers and no others."""
        state = AppState(
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            components_to_load=selected,
        )
        checkpoint_payload = _make_state_dict()

        state.load_state_dict(checkpoint_payload)

        for component, stub in patched_retrievers.items():
            # Pick the assertion matching whether this component was requested.
            verify = stub.assert_called_once if component in selected else stub.assert_not_called
            verify()

    def test_lr_scheduler_in_components_but_no_scheduler_attached_is_skipped(
        self, model: nn.Module, optimizer: SGD, patched_retrievers: dict[StatefulComponents, MagicMock]
    ) -> None:
        """A policy naming the scheduler must not call its retriever when no scheduler is attached."""
        # Without an attached scheduler the state dict carries no scheduler entry,
        # so the scheduler retriever must stay untouched even though the policy names it.
        requested = [StatefulComponents.MODEL, StatefulComponents.LR_SCHEDULER]
        state = AppState(model=model, optimizer=optimizer, components_to_load=requested)

        checkpoint_payload = _make_state_dict()
        checkpoint_payload.pop(StatefulComponents.LR_SCHEDULER.value)

        state.load_state_dict(checkpoint_payload)

        patched_retrievers[StatefulComponents.MODEL].assert_called_once()
        patched_retrievers[StatefulComponents.OPTIMIZER].assert_not_called()
        patched_retrievers[StatefulComponents.LR_SCHEDULER].assert_not_called()

    def test_double_load_raises(
        self, model: nn.Module, optimizer: SGD, patched_retrievers: dict[StatefulComponents, MagicMock]
    ) -> None:
        """Loading a second time into the same ``AppState`` raises ``RuntimeError``."""
        state = AppState(model=model, optimizer=optimizer)
        first_payload = _make_state_dict()
        state.load_state_dict(first_payload)

        with pytest.raises(RuntimeError):
            state.load_state_dict(_make_state_dict())
Loading