Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
from __future__ import annotations

import asyncio
from collections.abc import Set
import contextlib
import dataclasses
import functools
import json
import sys
import threading
import time
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
import uuid

from absl import logging
Expand All @@ -48,8 +49,10 @@
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.multihost import multihost
from orbax.checkpoint._src.path import async_path
from orbax.checkpoint._src.path import atomicity_types
from orbax.checkpoint._src.path import format_utils
from orbax.checkpoint._src.path import types as path_types
from orbax.checkpoint._src.path.snapshot import snapshot
from orbax.checkpoint._src.serialization import limits
from orbax.checkpoint._src.serialization import memory_regulator
from orbax.checkpoint._src.serialization import ocdbt_utils
Expand Down Expand Up @@ -81,6 +84,8 @@
# Module-level aliases for constants defined in other orbax modules, so the
# rest of this file can refer to them by short name.
PYTREE_METADATA_FILE = format_utils.PYTREE_METADATA_FILE
PLACEHOLDER = type_handlers.PLACEHOLDER
PLACEHOLDER_TYPESTR = type_handlers.PLACEHOLDER_TYPESTR
# Directory-name suffixes aliased from the atomicity and snapshot modules.
TMP_DIR_SUFFIX = atomicity_types.TMP_DIR_SUFFIX
PENDING_DIR_SUFFIX = snapshot.PENDING_DIR_SUFFIX

# NOTE(review): presumably a default concurrency budget expressed in
# gigabytes -- confirm against the code that consumes it.
DEFAULT_CONCURRENT_GB = 96

Expand Down Expand Up @@ -346,6 +351,136 @@ def _format_bytes(bytes_value: Optional[int]) -> str:
)


def _is_prefix(t1: Tuple[Any, ...], t2: Tuple[Any, ...]) -> bool:
  """Returns whether `t1` is a strict (proper) prefix of `t2`.

  Equal tuples are not prefixes of one another: `t1` must be strictly shorter
  than `t2` and must match the leading elements of `t2`.
  """
  prefix_len = len(t1)
  if prefix_len >= len(t2):
    return False
  return t2[:prefix_len] == t1


async def _get_partial_save_additions(
    directory: epath.Path,
    flat_item: Mapping[Any, Any],
    pytree_metadata_options: tree_metadata.PyTreeMetadataOptions,
) -> set[Any]:
  """Gets keys from the current save that are additions to the partial save.

  This function checks the keys in `flat_item` against metadata from previously
  completed partial saves within the same checkpoint session. It identifies
  which keys represent new additions and raises an error if any key attempts
  to overwrite or conflict with a key already present in the merged metadata
  from prior partial saves.

  Args:
    directory: The temporary directory for the current partial save. Its
      parent must be a directory whose name ends with `TMP_DIR_SUFFIX`.
    flat_item: A flattened dictionary of the PyTree being saved in the current
      operation.
    pytree_metadata_options: PyTreeMetadataOptions object.

  Returns:
    A set of keys from `flat_item` that are considered new additions.

  Raises:
    PartialSaveReplacementError: If any key in `flat_item` is found to be a
      replacement or a conflicting entry (e.g., a prefix) of a key already
      saved in a previous partial save within the same session.
    ValueError: If the parent directory name does not end with the expected
      TMP_DIR_SUFFIX.
  """
  # Reconstruct the base partial path from the temporary directory.
  # The temporary directory should be named
  # `{checkpoint_name}.partial_save.orbax-checkpoint-tmp...`
  # and we want to find all other pending saves for this checkpoint.
  tmp_dir = directory.parent
  if not tmp_dir.name.endswith(TMP_DIR_SUFFIX):
    raise ValueError(
        f'Expected temporary directory name to end with {TMP_DIR_SUFFIX}, '
        f'but got {tmp_dir.name}. Partial saving requires a TemporaryPath '
        'class that supports snapshots.'
    )
  base_name = tmp_dir.name[: -len(TMP_DIR_SUFFIX)]
  partial_path = tmp_dir.parent / base_name

  # Locate metadata files written by previous partial saves in this session.
  # Existence checks are independent, so issue them concurrently instead of
  # awaiting them one by one; `zip` keeps the original directory order.
  pending_dirs = await snapshot.list_pending_dirs(partial_path)
  candidate_files = [
      d / directory.name / PYTREE_METADATA_FILE for d in pending_dirs
  ]
  exists_flags = await asyncio.gather(
      *[async_path.exists(meta_file) for meta_file in candidate_files]
  )
  pending_metadata_files = [
      meta_file
      for meta_file, present in zip(candidate_files, exists_flags)
      if present
  ]

  async def get_tree_metadata(meta_file: epath.Path):
    return tree_metadata.InternalTreeMetadata.from_json(
        json.loads(await async_path.read_text(meta_file)),
        pytree_metadata_options=pytree_metadata_options,
    )

  internal_metas = await asyncio.gather(
      *[get_tree_metadata(meta_file) for meta_file in pending_metadata_files]
  )
  merge_trees = lambda a, b: a.merge(b, overwrite=True)
  merged_metadata = (
      functools.reduce(merge_trees, internal_metas).tree_metadata_entries
      if internal_metas
      else []
  )
  merged_tuples_set = {
      tree_utils.tuple_path_from_keypath(entry.jax_keypath())
      for entry in merged_metadata
  }

  # Index the previously saved keys once so each new key is validated in time
  # proportional to its own length rather than by scanning every saved key.
  # Only tuple keys take part in prefix comparisons.
  merged_tuple_keys = {
      mt for mt in merged_tuples_set if isinstance(mt, tuple)
  }
  # Every strict prefix (including the empty tuple) of every saved tuple key.
  merged_proper_prefixes = {
      mt[:i] for mt in merged_tuple_keys for i in range(len(mt))
  }

  def _conflicts(key) -> bool:
    """Returns True if `key` equals, contains, or is under a saved key."""
    if key in merged_tuples_set:
      return True
    if not isinstance(key, tuple):
      return False
    # `key` is a strict prefix (parent) of a previously saved key.
    if key in merged_proper_prefixes:
      return True
    # A previously saved key is a strict prefix (parent) of `key`.
    return any(key[:i] in merged_tuple_keys for i in range(len(key)))

  # Check for replacements vs. additions by comparing keys from the current
  # save request against the merged metadata of previous pending saves.
  additions = set()
  for key in flat_item:
    if _conflicts(key):
      raise PartialSaveReplacementError(
          f'Key "{key!r}" was found in a previous partial save in this session.'
          ' Partial saving currently does not support REPLACEMENT.'
      )
    additions.add(key)

  logging.info(
      '[process=%d] Found the following additions during partial save: %s',
      multihost.process_index(),
      additions,
  )
  return additions


def _filter_batch_requests(
    batch_requests: Sequence[_BatchRequest],
    additions: Set[Any],
) -> list[_BatchRequest]:
  """Filters batch requests to include only items matching the additions.

  An item is kept when at least one entry of `additions` equals its key or is
  a leading prefix (parent) of its key. Each item is kept at most once, even
  if several additions match it. Requests left with no items are dropped.

  Args:
    batch_requests: Requests whose `keys`, `values`, `infos` and `args` are
      parallel sequences.
    additions: Keys (or key prefixes) that should be saved.

  Returns:
    New request objects, in the original order, containing only matching
    items. The input requests are not modified.
  """
  filtered_requests = []
  for request in batch_requests:
    kept_items = [
        (key, value, info, arg)
        for key, value, info, arg in zip(
            request.keys, request.values, request.infos, request.args
        )
        # Additions may be a prefix/parent of the key. `any` stops at the
        # first match so an item is never emitted more than once.
        if any(add == key[: len(add)] for add in additions)
    ]
    if not kept_items:
      continue
    keys, values, infos, args = zip(*kept_items)
    filtered_requests.append(
        dataclasses.replace(
            request,
            keys=list(keys),
            values=list(values),
            infos=list(infos),
            args=list(args),
        )
    )
  return filtered_requests


class BasePyTreeCheckpointHandler(
async_checkpoint_handler.DeferredPathAsyncCheckpointHandler
):
Expand Down Expand Up @@ -570,62 +705,17 @@ async def _async_partial_save(
batch_requests: list[_BatchRequest],
param_infos: PyTree,
save_args: BasePyTreeSaveArgs,
):
value_metadata_tree = (
await self._read_metadata_file(directory)
).as_nested_tree()

tree_diff = tree_structure_utils.tree_difference(item, value_metadata_tree)

additions = set()

def _handle_diffs(keypath, diff):
keypath = tree_utils.tuple_path_from_keypath(keypath)
if diff.lhs is not None: # Leaf is present in the current item
if diff.rhs is None: # Leaf was not in the on-disk metadata
additions.add(keypath)
else: # Leaf was also in the on-disk metadata
raise PartialSaveReplacementError(
f'Key "{keypath}" was found in the on-disk PyTree metadata and'
' supplied item. Partial saving currently does not support'
' REPLACEMENT. Please reach out to the Orbax team if you need'
' this feature.'
)

jax.tree.map_with_path(
_handle_diffs,
tree_diff,
is_leaf=lambda x: isinstance(x, tree_structure_utils.Diff),
) -> Tuple[
List[asyncio.Coroutine[Any, Any, Sequence[future.Future]]],
int,
PyTree,
BasePyTreeSaveArgs,
]:
flat_item = tree_utils.to_flat_dict(item)
additions = await _get_partial_save_additions(
directory, flat_item, self._pytree_metadata_options
)

logging.info(
'[process=%d] Found the following additions during partial save: %s',
multihost.process_index(),
additions,
)

# Filter out requests that don't have any additions.
filtered_requests = []
for request in batch_requests:
filtered_items = []
for key, value, info, arg in zip(
request.keys, request.values, request.infos, request.args
):
for add in additions:
# Additions may be a prefix/parent of the key.
if add == key[: len(add)]:
filtered_items.append((key, value, info, arg))
if filtered_items:
keys, values, infos, args = zip(*filtered_items)
filtered_requests.append(
dataclasses.replace(
request,
keys=list(keys),
values=list(values),
infos=list(infos),
args=list(args),
)
)
filtered_requests = _filter_batch_requests(batch_requests, additions)

serialize_ops = []
tree_memory_size = 0
Expand Down Expand Up @@ -733,12 +823,9 @@ async def async_save(
self._type_handler_registry,
)

is_partial_save = args.partial_save_mode and await async_path.exists(
directory / PYTREE_METADATA_FILE
)
batch_requests_ready_time = time.time()
with _memory_profiler_context():
if is_partial_save:
if args.partial_save_mode:
serialize_ops, tree_memory_size, param_infos, save_args = (
await self._async_partial_save(
directory, item, batch_requests, param_infos, save_args
Expand Down Expand Up @@ -784,7 +871,7 @@ async def async_save(
custom_metadata=custom_metadata,
use_ocdbt=self._use_ocdbt,
use_zarr3=self._use_zarr3,
partial_save=is_partial_save,
partial_save=args.partial_save_mode,
),
name='write_metadata_after_commits',
)
Expand Down Expand Up @@ -1246,6 +1333,7 @@ async def _write_metadata_file(
use_zarr3: bool,
partial_save: bool,
) -> None:
del partial_save # Unused
if utils.is_primary_host(self._primary_host):
metadata_write_start_time = time.time()
path = directory / PYTREE_METADATA_FILE
Expand All @@ -1258,12 +1346,6 @@ async def _write_metadata_file(
pytree_metadata_options=self._pytree_metadata_options,
)

if partial_save:
old_metadata = await self._read_metadata_file(directory)
metadata_content = tree_metadata.InternalTreeMetadata.merge(
old_metadata, metadata_content, overwrite=True
)

logging.vlog(
1,
'Writing pytree metadata file: %s with pytree_metadata_options: %s',
Expand Down
11 changes: 6 additions & 5 deletions checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@
from orbax.checkpoint._src.path import utils as ocp_path_utils


SNAPSHOTTING_TIME = "snapshotting_time"
# NOTE(review): PENDING_DIR_SUFFIX is assigned twice in this span; if both
# lines execute, the later value (".pending") is the one that takes effect.
# This looks like the before/after lines of a diff -- confirm which single
# definition is intended.
PENDING_DIR_SUFFIX = ".pending_"

PENDING_DIR_SUFFIX = ".pending"


# Builds a pending-directory name by appending PENDING_DIR_SUFFIX and a
# freshly generated identifier to `source_name`.
def get_pending_dir_name(source_name: str) -> str:
# NOTE(review): there are two consecutive return statements here; as written
# only the first can execute and the second (which adds a time.time_ns()
# component) is unreachable. This looks like the before/after lines of a
# diff -- confirm which one is intended and that `time` is imported.
return f"{source_name}{PENDING_DIR_SUFFIX}{uuid.uuid4().hex}"
return (
f"{source_name}{PENDING_DIR_SUFFIX}_{time.time_ns()}_{uuid.uuid4().hex}"
)


def get_uuid_from_pending_dir_name(pending_dir_name: str) -> str:
Expand Down Expand Up @@ -169,8 +171,7 @@ async def replace_source(self) -> None:

if not await async_path.exists(self._snapshot):
raise FileNotFoundError(f"Snapshot does not exist: {self._snapshot}")
if not await async_path.exists(self._source):
await async_path.mkdir(self._source, parents=True, exist_ok=True)
await async_path.mkdir(self._source, parents=True, exist_ok=True)
# Move files from inside the tmp snapshot into the original source
# directory under a pending suffix. This is to avoid potentially wiping
# out previous files.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,11 @@ def __init__(
)

async def _finalize(self, directory: path_types.Path):
# Keep non-finalized checkpoint state during partial saves to be merged
# later during partial save finalization.
if self._partial_save_mode:
return

if multihost.is_primary_host(self._multiprocessing_options.primary_host):
await self._handler_impl._finalize_async(directory) # pylint: disable=protected-access

Expand Down
Loading
Loading