diff --git a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py index 014a88527..d403240c6 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py @@ -22,6 +22,7 @@ from __future__ import annotations import asyncio +from collections.abc import Set import contextlib import dataclasses import functools @@ -29,7 +30,7 @@ import sys import threading import time -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union import uuid from absl import logging @@ -48,8 +49,10 @@ from orbax.checkpoint._src.metadata import tree as tree_metadata from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.path import async_path +from orbax.checkpoint._src.path import atomicity_types from orbax.checkpoint._src.path import format_utils from orbax.checkpoint._src.path import types as path_types +from orbax.checkpoint._src.path.snapshot import snapshot from orbax.checkpoint._src.serialization import limits from orbax.checkpoint._src.serialization import memory_regulator from orbax.checkpoint._src.serialization import ocdbt_utils @@ -81,6 +84,8 @@ PYTREE_METADATA_FILE = format_utils.PYTREE_METADATA_FILE PLACEHOLDER = type_handlers.PLACEHOLDER PLACEHOLDER_TYPESTR = type_handlers.PLACEHOLDER_TYPESTR +TMP_DIR_SUFFIX = atomicity_types.TMP_DIR_SUFFIX +PENDING_DIR_SUFFIX = snapshot.PENDING_DIR_SUFFIX DEFAULT_CONCURRENT_GB = 96 @@ -346,6 +351,136 @@ def _format_bytes(bytes_value: Optional[int]) -> str: ) +def _is_prefix(t1: Tuple[Any, ...], t2: Tuple[Any, ...]) -> bool: + """Checks if tuple t1 is a prefix of tuple t2.""" + return len(t1) < len(t2) and t2[: len(t1)] == t1 + + +async def _get_partial_save_additions( + directory: epath.Path, + flat_item: Mapping[Any, Any], + pytree_metadata_options: tree_metadata.PyTreeMetadataOptions, +) -> set[Any]: + """Gets keys from the current save that are additions to the partial save. + + This method checks the keys in `flat_item` against metadata from previously + completed partial saves within the same checkpoint session. It identifies + which keys represent new additions and raises an error if any key attempts + to overwrite or conflict with a key already present in the merged metadata + from prior partial saves. + + Args: + directory: The temporary directory for the current partial save. + flat_item: A flattened dictionary of the PyTree being saved in the current + operation. + pytree_metadata_options: PyTreeMetadataOptions object. + + Returns: + A set of keys from `flat_item` that are considered new additions. + + Raises: + PartialSaveReplacementError: If any key in `flat_item` is found to be a + replacement or a conflicting entry (e.g., a prefix) of a key already + saved in a previous partial save within the same session. + ValueError: If the directory name does not end with the expected + TMP_DIR_SUFFIX. + """ + # Reconstruct the base partial path from the temporary directory. + # The temporary directory should be named + # `{checkpoint_name}.partial_save.orbax-checkpoint-tmp...` + # and we want to find all other pending saves for this checkpoint. + tmp_dir = directory.parent + if not tmp_dir.name.endswith(TMP_DIR_SUFFIX): + raise ValueError( + f'Expected temporary directory name to end with {TMP_DIR_SUFFIX}, ' + f'but got {tmp_dir.name}. Partial saving requires a TemporaryPath ' + 'class that supports snapshots.' + ) + base_name = tmp_dir.name[: -len(TMP_DIR_SUFFIX)] + partial_path = tmp_dir.parent / base_name + + # Glob for metadata files written by previous partial saves in this session. + pending_dirs = await snapshot.list_pending_dirs(partial_path) + pending_metadata_files = [] + for d in pending_dirs: + meta_file = d / directory.name / PYTREE_METADATA_FILE + if await async_path.exists(meta_file): + pending_metadata_files.append(meta_file) + + async def get_tree_metadata(meta_file: epath.Path): + return tree_metadata.InternalTreeMetadata.from_json( + json.loads(await async_path.read_text(meta_file)), + pytree_metadata_options=pytree_metadata_options, + ) + + internal_metas = await asyncio.gather( + *[get_tree_metadata(meta_file) for meta_file in pending_metadata_files] + ) + merge_trees = lambda a, b: a.merge(b, overwrite=True) + merged_metadata = ( + functools.reduce(merge_trees, internal_metas).tree_metadata_entries + if internal_metas + else [] + ) + merged_tuples_set = { + tree_utils.tuple_path_from_keypath(entry.jax_keypath()) + for entry in merged_metadata + } + + # Check for replacements vs. additions by comparing keys from the current + # save request against the merged metadata of previous pending saves. + def _validate_key(key, merged_tuples_set=merged_tuples_set): + is_exact_match = key in merged_tuples_set + has_prefix_conflict = isinstance(key, tuple) and any( + isinstance(mt, tuple) and (_is_prefix(key, mt) or _is_prefix(mt, key)) + for mt in merged_tuples_set + ) + if is_exact_match or has_prefix_conflict: + raise PartialSaveReplacementError( + f'Key "{key!r}" was found in a previous partial save in this session.' + ' Partial saving currently does not support REPLACEMENT.' + ) + return key + + additions = {_validate_key(key) for key in flat_item} + + logging.info( + '[process=%d] Found the following additions during partial save: %s', + multihost.process_index(), + additions, + ) + return additions + + +def _filter_batch_requests( + batch_requests: Sequence[_BatchRequest], + additions: Set[Any], +) -> list[_BatchRequest]: + """Filters batch requests to include only items matching the additions.""" + filtered_requests = [] + for request in batch_requests: + filtered_items = [] + for key, value, info, arg in zip( + request.keys, request.values, request.infos, request.args + ): + for add in additions: + # Additions may be a prefix/parent of the key. + if add == key[: len(add)]: + filtered_items.append((key, value, info, arg)) + if filtered_items: + keys, values, infos, args = zip(*filtered_items) + filtered_requests.append( + dataclasses.replace( + request, + keys=list(keys), + values=list(values), + infos=list(infos), + args=list(args), + ) + ) + return filtered_requests + + class BasePyTreeCheckpointHandler( async_checkpoint_handler.DeferredPathAsyncCheckpointHandler ): @@ -570,62 +705,17 @@ async def _async_partial_save( batch_requests: list[_BatchRequest], param_infos: PyTree, save_args: BasePyTreeSaveArgs, - ): - value_metadata_tree = ( - await self._read_metadata_file(directory) - ).as_nested_tree() - - tree_diff = tree_structure_utils.tree_difference(item, value_metadata_tree) - - additions = set() - - def _handle_diffs(keypath, diff): - keypath = tree_utils.tuple_path_from_keypath(keypath) - if diff.lhs is not None: # Leaf is present in the current item - if diff.rhs is None: # Leaf was not in the on-disk metadata - additions.add(keypath) - else: # Leaf was also in the on-disk metadata - raise PartialSaveReplacementError( - f'Key "{keypath}" was found in the on-disk PyTree metadata and' - ' supplied item. Partial saving currently does not support' - ' REPLACEMENT. Please reach out to the Orbax team if you need' - ' this feature.' - ) - - jax.tree.map_with_path( - _handle_diffs, - tree_diff, - is_leaf=lambda x: isinstance(x, tree_structure_utils.Diff), + ) -> Tuple[ + List[asyncio.Coroutine[Any, Any, Sequence[future.Future]]], + int, + PyTree, + BasePyTreeSaveArgs, + ]: + flat_item = tree_utils.to_flat_dict(item) + additions = await _get_partial_save_additions( + directory, flat_item, self._pytree_metadata_options ) - - logging.info( - '[process=%d] Found the following additions during partial save: %s', - multihost.process_index(), - additions, - ) - - # Filter out requests that don't have any additions. - filtered_requests = [] - for request in batch_requests: - filtered_items = [] - for key, value, info, arg in zip( - request.keys, request.values, request.infos, request.args - ): - for add in additions: - # Additions may be a prefix/parent of the key. - if add == key[: len(add)]: - filtered_items.append((key, value, info, arg)) - if filtered_items: - keys, values, infos, args = zip(*filtered_items) - filtered_requests.append( - dataclasses.replace( - request, - keys=list(keys), - values=list(values), - infos=list(infos), - args=list(args), - ) - ) + filtered_requests = _filter_batch_requests(batch_requests, additions) serialize_ops = [] tree_memory_size = 0 @@ -733,12 +823,9 @@ async def async_save( self._type_handler_registry, ) - is_partial_save = args.partial_save_mode and await async_path.exists( - directory / PYTREE_METADATA_FILE - ) batch_requests_ready_time = time.time() with _memory_profiler_context(): - if is_partial_save: + if args.partial_save_mode: serialize_ops, tree_memory_size, param_infos, save_args = ( await self._async_partial_save( directory, item, batch_requests, param_infos, save_args @@ -784,7 +871,7 @@ async def async_save( custom_metadata=custom_metadata, use_ocdbt=self._use_ocdbt, use_zarr3=self._use_zarr3, - partial_save=is_partial_save, + partial_save=args.partial_save_mode, ), name='write_metadata_after_commits', ) @@ -1246,6 +1333,7 @@ async def _write_metadata_file( use_zarr3: bool, partial_save: bool, ) -> None: + del partial_save # Unused if utils.is_primary_host(self._primary_host): metadata_write_start_time = time.time() path = directory / PYTREE_METADATA_FILE @@ -1258,12 +1346,6 @@ async def _write_metadata_file( pytree_metadata_options=self._pytree_metadata_options, ) - if partial_save: - old_metadata = await self._read_metadata_file(directory) - metadata_content = tree_metadata.InternalTreeMetadata.merge( - old_metadata, metadata_content, overwrite=True - ) - logging.vlog( 1, 'Writing pytree metadata file: %s with pytree_metadata_options: %s', diff --git a/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py b/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py index d22b9ba9b..9cd558170 100644 --- a/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py +++ b/checkpoint/orbax/checkpoint/_src/path/snapshot/snapshot.py @@ -26,12 +26,14 @@ from orbax.checkpoint._src.path import utils as ocp_path_utils -SNAPSHOTTING_TIME = "snapshotting_time" -PENDING_DIR_SUFFIX = ".pending_" + +PENDING_DIR_SUFFIX = ".pending" def get_pending_dir_name(source_name: str) -> str: - return f"{source_name}{PENDING_DIR_SUFFIX}{uuid.uuid4().hex}" + return ( + f"{source_name}{PENDING_DIR_SUFFIX}_{time.time_ns()}_{uuid.uuid4().hex}" + ) def get_uuid_from_pending_dir_name(pending_dir_name: str) -> str: @@ -169,8 +171,7 @@ async def replace_source(self) -> None: if not await async_path.exists(self._snapshot): raise FileNotFoundError(f"Snapshot does not exist: {self._snapshot}") - if not await async_path.exists(self._source): - await async_path.mkdir(self._source, parents=True, exist_ok=True) + await async_path.mkdir(self._source, parents=True, exist_ok=True) # Move files from inside the tmp snapshot into the original source # directory under a pending suffix. This is to avoid potentially wiping # out previous files. diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py index 703e38922..7ecc0bda6 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/handlers/pytree_handler.py @@ -319,6 +319,11 @@ def __init__( ) async def _finalize(self, directory: path_types.Path): + # Keep non-finalized checkpoint state during partial saves to be merged + # later during partial save finalization. + if self._partial_save_mode: + return + if multihost.is_primary_host(self._multiprocessing_options.primary_host): await self._handler_impl._finalize_async(directory) # pylint: disable=protected-access diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_multiprocess_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_multiprocess_test.py index 71a4e9ba3..4b9529d61 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_multiprocess_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/layout/orbax_layout_multiprocess_test.py @@ -23,6 +23,8 @@ from etils import epath from orbax.checkpoint import test_utils from orbax.checkpoint._src.metadata import step_metadata_serialization +from orbax.checkpoint._src.path import atomicity_types +from orbax.checkpoint._src.path.snapshot import snapshot from orbax.checkpoint._src.testing import multiprocess_test from orbax.checkpoint._src.tree import structure_utils as tree_structure_utils from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler @@ -91,11 +93,17 @@ def save( *, partial_save: bool = False, ): + if partial_save: + final_dir = directory + directory = ( + directory.parent / f'{directory.name}{atomicity_types.TMP_DIR_SUFFIX}' + ) + test_utils.sync_global_processes('CompositeHandlerTest:save:start') if multihost.is_primary_host(0): - directory.mkdir(parents=False, exist_ok=partial_save) + directory.mkdir(parents=True, exist_ok=True) for k in checkpointables: - (directory / k).mkdir(parents=False, exist_ok=partial_save) + (directory / k).mkdir(parents=True, exist_ok=True) test_utils.sync_global_processes('CompositeHandlerTest:save:mkdir') async def _save(): @@ -110,20 +118,8 @@ async def _save(): for name in checkpointables.keys() } - checkpoint_metadata_path = ( - metadata_serialization.checkpoint_metadata_file_path(directory) - ) - if partial_save and checkpoint_metadata_path.exists(): - checkpoint_metadata = await orbax_layout.read_checkpoint_metadata( - directory - ) - old_handler_typestrs = checkpoint_metadata.item_handlers - handler_typestrs = old_handler_typestrs | handler_typestrs - await multihost.sync_global_processes( - 'CompositeHandlerTest:save:checkpoint_metadata_read', - operation_id='op', - processes=None, - ) + # For partial save in this test, we skip reading existing global metadata + # here since it will be merged during finalize, just like real execution. # Metadata expected to be created outside the handler. if multihost.is_primary_host(0): @@ -148,6 +144,11 @@ async def _save(): ) await awaitable + if partial_save and multihost.is_primary_host(0): + final_dir.mkdir(parents=True, exist_ok=True) + pending_dir = final_dir / snapshot.get_pending_dir_name(final_dir.name) + directory.rename(pending_dir) + asyncio.run(_save()) test_utils.sync_global_processes('CompositeHandlerTest:save:complete') @@ -360,7 +361,6 @@ def test_partial_save_and_finalize(self, finalize_with_partial_path: bool): layout, partial_path, first_save_checkpointables, partial_save=True ) self.assertTrue(partial_path.exists()) - self.assertTrue((partial_path / ORBAX_CHECKPOINT_INDICATOR_FILE).exists()) self.save( layout, @@ -369,14 +369,6 @@ def test_partial_save_and_finalize(self, finalize_with_partial_path: bool): partial_save=True, ) self.assertTrue(partial_path.exists()) - self.assertTrue((partial_path / ORBAX_CHECKPOINT_INDICATOR_FILE).exists()) - - restored_checkpointables = self.load( - layout, partial_path, merged_checkpointables - ) - test_utils.assert_tree_equal( - self, restored_checkpointables, merged_checkpointables - ) partial_saving.finalize( partial_path if finalize_with_partial_path else final_path diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py index 3e705f5b8..e3e693529 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/partial/saving.py @@ -14,19 +14,30 @@ """Defines free-function interface for partial saving and finalizing.""" +import ast import asyncio import dataclasses -from typing import Awaitable +import itertools +import json +import logging +from typing import Any, Awaitable, Callable +from etils import epath from orbax.checkpoint._src import asyncio_utils +from orbax.checkpoint._src.handlers import base_pytree_checkpoint_handler +from orbax.checkpoint._src.metadata import tree as tree_metadata from orbax.checkpoint._src.path import async_path +from orbax.checkpoint._src.path import format_utils from orbax.checkpoint._src.path import utils as ocp_path_utils +from orbax.checkpoint._src.path.snapshot import snapshot +from orbax.checkpoint._src.tree import structure_utils as tree_structure_utils from orbax.checkpoint.experimental.v1._src.context import context as context_lib from orbax.checkpoint.experimental.v1._src.handlers import global_registration # pylint: disable=unused-import from orbax.checkpoint.experimental.v1._src.handlers import pytree_handler from orbax.checkpoint.experimental.v1._src.handlers import stateful_checkpointable_handler from orbax.checkpoint.experimental.v1._src.handlers import types as handler_types from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout +from orbax.checkpoint.experimental.v1._src.metadata import serialization as metadata_serialization from orbax.checkpoint.experimental.v1._src.partial import path as partial_path_lib from orbax.checkpoint.experimental.v1._src.path import types as path_types from orbax.checkpoint.experimental.v1._src.saving import execution @@ -36,6 +47,7 @@ STATE_CHECKPOINTABLE_KEY = checkpoint_layout.STATE_CHECKPOINTABLE_KEY +CHECKPOINT_METADATA_FILENAME = metadata_serialization._CHECKPOINT_METADATA_FILENAME # pylint: disable=protected-access StatefulCheckpointableHandler = ( stateful_checkpointable_handler.StatefulCheckpointableHandler @@ -229,6 +241,258 @@ def save_async( ) +async def _read_first_metadata( + pending_dirs: list[epath.Path], +) -> tree_metadata.InternalTreeMetadata | None: + """Reads metadata from the first pending directory.""" + if not pending_dirs: + return None + + for item in await async_path.iterdir(pending_dirs[0]): + if not await async_path.is_dir(item): + continue + first_meta_path = item / format_utils.PYTREE_METADATA_FILE + if await async_path.exists(first_meta_path): + try: + return tree_metadata.InternalTreeMetadata.from_json( + json.loads(await async_path.read_text(first_meta_path)), + pytree_metadata_options=tree_metadata.PYTREE_METADATA_OPTIONS, + ) + except json.JSONDecodeError as e: + raise ValueError( + 'Failed to read metadata from first metadata file' + f' {first_meta_path}.' + ) from e + return None + + +def _is_prefix(t1: tuple[str, ...], t2: tuple[str, ...]) -> bool: + return len(t1) < len(t2) and t2[: len(t1)] == t1 + + +def _filter_conflicting_keys(d: dict[str, Any]) -> dict[str, Any]: + """Filters metadata keys that conflict due to parent-child relationships. + + When merging metadata from multiple partial saves, we might encounter + conflicting entries. For example, one partial save might save 'a/b' as a + leaf, while another saves 'a/b/c' as a leaf. This is a conflict because + 'a/b' cannot be both a leaf and an intermediate node containing 'c'. This + function resolves the conflict by removing metadata for 'a/b', keeping + 'a/b/c', and implicitly treating 'a/b' as an intermediate node. + + Args: + d: A dictionary of metadata. + + Returns: + The filtered metadata dictionary. + """ + keys = list(d.keys()) + to_remove = set() + + parsed_keys = {} + for k in keys: + try: + parsed_keys[k] = ast.literal_eval(k) + except (ValueError, SyntaxError): + parsed_keys[k] = k + + for k1, k2 in itertools.permutations(keys, 2): + t1, t2 = parsed_keys[k1], parsed_keys[k2] + if isinstance(t1, tuple) and isinstance(t2, tuple): + if _is_prefix(t1, t2): + to_remove.add(k1) + elif isinstance(k1, str) and isinstance(k2, str): + if k2.startswith((k1 + '.', k1 + '/')): + to_remove.add(k1) + + for k in to_remove: + del d[k] + return d + + +async def _rename_or_merge_json( + src: epath.Path, dst: epath.Path, merge_fn: Callable[[Any, Any], Any] +): + """Tries to rename src to dst, otherwise merges them as JSONs using merge_fn.""" + try: + await async_path.rename(src, dst) + except FileExistsError: + pass + else: + return + + src_meta = json.loads(await async_path.read_text(src)) + dst_meta = json.loads(await async_path.read_text(dst)) + + merged_meta = merge_fn(src_meta, dst_meta) + + await async_path.write_text(dst, json.dumps(merged_meta)) + await async_path.unlink(src) + + +async def _merge_pytree_metadata(src_item: epath.Path, dst_item: epath.Path): + """Merges PyTree metadata files (_METADATA or _sharding).""" + + def _merge_fn(src_meta, dst_meta): + merged = tree_structure_utils.merge_trees( + dst_meta, src_meta, overwrite=True + ) + if 'tree_metadata' in merged: + merged['tree_metadata'] = _filter_conflicting_keys( + merged['tree_metadata'] + ) + return merged + + await _rename_or_merge_json(src_item, dst_item, _merge_fn) + + +async def _rename_ocdbt_process_dir( + item: epath.Path, pytree_dst: epath.Path, uuid_str: str +): + """Renames an ocdbt.process_ directory to avoid collisions across partial saves.""" + # To avoid collisions across different partial save pending directories, + # we append the pending dir's UUID to the original process ID. + # We must avoid using '_' in the new ID because `ocdbt_utils.py` splits + # the directory name by '_' to extract the process ID. + new_name = f'{item.name}{uuid_str.replace("-", "")}' + await async_path.rename(item, pytree_dst / new_name) + + +async def _merge_array_metadatas(src_dir: epath.Path, dst_dir: epath.Path): + """Merges array_metadatas JSON files (process_0, process_1, etc.).""" + await async_path.mkdir(dst_dir, parents=True, exist_ok=True) + + async def _process_child(src_child: epath.Path): + dst_child = dst_dir / src_child.name + + def _merge_fn(src_meta, dst_meta): + src_arr_meta = src_meta.get('array_metadatas', []) + dst_arr_meta = dst_meta.get('array_metadatas', []) + dst_arr_meta.extend(src_arr_meta) + dst_meta['array_metadatas'] = dst_arr_meta + return dst_meta + + await _rename_or_merge_json(src_child, dst_child, _merge_fn) + + await asyncio.gather(*[ + _process_child(src_child) + for src_child in await async_path.iterdir(src_dir) + ]) + + +async def _recursive_merge(src: epath.Path, dst: epath.Path): + """Recursively merges src into dst.""" + if not await async_path.exists(src): + return + + try: + await async_path.rename(src, dst) + except FileExistsError: + pass + else: + return + + if await async_path.is_dir(src): + items = await async_path.iterdir(src) + await asyncio.gather( + *[_recursive_merge(item, dst / item.name) for item in items] + ) + await async_path.rmtree(src) + return + + logging.warning( + 'File collision on %s during finalize. Overwriting destination file.', + src.name, + ) + if await async_path.is_dir(dst): + await async_path.rmtree(dst) + else: + await async_path.unlink(dst) + await async_path.rename(src, dst) + + +async def _merge_pytree_directory( + pytree_src: epath.Path, + pytree_dst: epath.Path, + uuid_str: str, +): + """Merges a single pending pytree directory into the destination.""" + if not await async_path.exists(pytree_src): + return + + async def _merge_item(item_path: epath.Path): + if item_path.name in [format_utils.PYTREE_METADATA_FILE, '_sharding']: + await _merge_pytree_metadata(item_path, pytree_dst / item_path.name) + elif item_path.name.startswith('ocdbt.process_'): + await _rename_ocdbt_process_dir(item_path, pytree_dst, uuid_str) + elif item_path.name == 'array_metadatas': + await _merge_array_metadatas(item_path, pytree_dst / item_path.name) + else: + await _recursive_merge(item_path, pytree_dst / item_path.name) + + await asyncio.gather( + *[_merge_item(item) for item in await async_path.iterdir(pytree_src)] + ) + + await async_path.rmtree(pytree_src) + + +async def _merge_checkpoint_metadata(src: epath.Path, dst: epath.Path): + """Merges checkpoint metadata.""" + + def _merge_fn(src_meta, dst_meta): + return tree_structure_utils.merge_trees(dst_meta, src_meta, overwrite=True) + + await _rename_or_merge_json(src, dst, _merge_fn) + + +async def _merge_all(partial_path: epath.Path): + """Merges all pending directories into the partial path.""" + + # Each partial save call results in a new pending directory containing unique + # PyTree keypaths and corresponding data. During finalization, all pending + # directories are merged to form the final checkpoint state. + # Ensure deterministic merge order (alphabetical glob + timestamp). + pending_dirs = sorted(await snapshot.list_pending_dirs(partial_path)) + + first_metadata = await _read_first_metadata(pending_dirs) + use_zarr3 = first_metadata.use_zarr3 if first_metadata is not None else False + + pytree_directories = set() + + for p_dir in pending_dirs: + uuid_str = snapshot.get_uuid_from_pending_dir_name(p_dir.name) + + async def _process_item(item: epath.Path, uuid_str: str): + if item.name == CHECKPOINT_METADATA_FILENAME: + await _merge_checkpoint_metadata(item, partial_path / item.name) + elif await async_path.is_dir(item) and await async_path.exists( + item / format_utils.PYTREE_METADATA_FILE + ): + pytree_directories.add(item.name) + pytree_dst = partial_path / item.name + await async_path.mkdir(pytree_dst, parents=True, exist_ok=True) + await _merge_pytree_directory(item, pytree_dst, uuid_str) + else: + await _recursive_merge(item, partial_path / item.name) + + await asyncio.gather(*[ + _process_item(item, uuid_str) + for item in await async_path.iterdir(p_dir) + ]) + + await async_path.rmtree(p_dir) + + # 3. Call PyTreeHandler.finalize to perform OCDBT merge. + # This merges the individual ocdbt.process_xxx directories into a single + # valid manifest for the final partial state. + handler = base_pytree_checkpoint_handler.BasePyTreeCheckpointHandler( + use_zarr3=use_zarr3 + ) + for pytree_dir_name in pytree_directories: + await asyncio.to_thread(handler.finalize, partial_path / pytree_dir_name) + + def finalize(path: path_types.PathLike) -> None: """Finalizes a partially-saved checkpoint, making it permanent and readable. @@ -302,17 +566,18 @@ async def _finalize_impl(): processes=context.multiprocessing_options.active_processes, ) - rename_failed = False - rename_error = None + finalize_failed = False + finalize_error = None if multihost.is_primary_host(context.multiprocessing_options.primary_host): try: + await _merge_all(partial_path) await async_path.rename(partial_path, final_path) - except OSError as e: - rename_failed = True - rename_error = e + except (ValueError, OSError) as e: + finalize_failed = True + finalize_error = e - rename_failed = multihost.broadcast_one_to_all( - rename_failed, + finalize_failed = multihost.broadcast_one_to_all( + finalize_failed, is_source=multihost.is_primary_host( context.multiprocessing_options.primary_host ), @@ -327,9 +592,7 @@ async def _finalize_impl(): processes=context.multiprocessing_options.active_processes, ) - if rename_failed: - if rename_error is not None: - raise rename_error - raise OSError('Partial checkpoint finalization failed during rename.') + if finalize_failed: + raise finalize_error or OSError('Partial checkpoint finalization failed.') asyncio_utils.run_sync(_finalize_impl()) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py index b5fcdea9e..46bdaf4b7 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/saving/execution.py @@ -85,7 +85,7 @@ def add_internal_checkpointables( class _SaveResponse(AsyncResponse[None]): - """An :py:class:`.AsyncResponse` representing the result of :py:func:`.save_async`.""" + """:py:class:`.AsyncResponse`, result of :py:func:`.save_async`.""" def __init__( self, @@ -389,17 +389,15 @@ def save_checkpointables_impl( path = context.file_options.path_class(path) _check_directory_consistency(path) - path_exists = path.exists() if partial_save else False # Prevent internal mutation from affecting the caller. checkpointables = dict(checkpointables) checkpointables = add_internal_checkpointables( checkpointables, context=context ) - subdirectories = [] if path_exists else checkpointables.keys() - snapshot_type = snapshot_lib.SnapshotType.IN_PLACE if path_exists else None + snapshot_type = snapshot_lib.SnapshotType.EMPTY if partial_save else None temporary_path = _TemporaryPathAwaitingCreation( path, - subdirectories=subdirectories, + subdirectories=checkpointables.keys(), snapshot_type=snapshot_type, ) background_awaitable = asyncio_utils.run_sync(