diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test_base.py similarity index 80% rename from checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test.py rename to checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test_base.py index 41f64811a..54f601de9 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test_base.py @@ -26,6 +26,7 @@ from orbax.checkpoint.experimental.v1._src.context import options as options_lib from orbax.checkpoint.experimental.v1._src.handlers import registration from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout as checkpoint_layout_lib +from orbax.checkpoint.experimental.v1._src.synchronization import multihost from orbax.checkpoint.experimental.v1._src.testing.compatibility import test_utils as compatibility_test_utils @@ -36,7 +37,8 @@ _BASE_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints') -class CheckpointablesMetadataCompatibilityTest(parameterized.TestCase): +class CheckpointablesMetadataCompatibilityTestBase(parameterized.TestCase): + """Tests for V1 checkpointables_metadata API against generated Checkpoints.""" def setUp(self) -> None: super().setUp() @@ -94,6 +96,14 @@ def test_checkpointables_metadata_compatibility( is_direct_checkpoint: bool, is_pytree: bool, ) -> None: + """Tests checkpointables_metadata against various checkpoint formats. + + Args: + version: v0 or v1. + metadata_present: Whether the checkpoint has metadata files. + is_direct_checkpoint: Whether the checkpoint is a direct checkpoint. + is_pytree: Whether the checkpoint is a pytree checkpoint. + """ path = compatibility_test_utils.get_checkpoint_path( version, metadata_present, is_direct_checkpoint, is_pytree ) @@ -133,7 +143,11 @@ def test_checkpointables_metadata_compatibility( else: expected = self.expected_checkpointables_metadata - test_utils.assert_tree_equal(self, expected, loaded.metadata) + actual = loaded.metadata + if multihost.is_pathways_backend() or jax.process_count() > 1: + expected = compatibility_test_utils.strip_sharding_metadata(expected) + actual = compatibility_test_utils.strip_sharding_metadata(actual) + test_utils.assert_tree_equal(self, expected, actual) else: with self.assertRaisesRegex(error_type, expected_error_msg): ocp.checkpointables_metadata(path) @@ -153,6 +167,12 @@ def test_checkpointables_metadata_compatibility( def test_checkpointables_metadata_non_critical_corruptions( self, version: str, alteration: str ) -> None: + """Tests checkpointables_metadata against non-critical corruptions. + + Args: + version: The checkpoint version to test against. + alteration: The alteration to apply to the checkpoint. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', @@ -162,9 +182,12 @@ def test_checkpointables_metadata_non_critical_corruptions( # Missing sharding metadata results in a pytree identical to expected # values except sharding metadata is None. loaded = ocp.checkpointables_metadata(path) - test_utils.assert_tree_equal( - self, self.expected_checkpointables_metadata, loaded.metadata - ) + expected = self.expected_checkpointables_metadata + actual = loaded.metadata + if multihost.is_pathways_backend() or jax.process_count() > 1: + expected = compatibility_test_utils.strip_sharding_metadata(expected) + actual = compatibility_test_utils.strip_sharding_metadata(actual) + test_utils.assert_tree_equal(self, expected, actual) @parameterized.product( version=['v0', 'v1'], @@ -172,6 +195,11 @@ def test_checkpointables_metadata_non_critical_corruptions( def test_checkpointables_metadata_missing_sharding_corruption( self, version: str ) -> None: + """Tests checkpointables_metadata against missing sharding corruption. + + Args: + version: The checkpoint version to test against. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', @@ -193,6 +221,12 @@ def test_checkpointables_metadata_missing_sharding_corruption( def test_checkpointables_metadata_critical_corruptions( self, version: str, alteration: str ) -> None: + """Tests checkpointables_metadata against critical corruptions. + + Args: + version: The checkpoint version to test against. + alteration: The alteration to apply to the checkpoint. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/generate_v0_checkpoint_manager_checkpoints.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/generate_v0_checkpoint_manager_checkpoints.py new file mode 100644 index 000000000..9f5d9821e --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/generate_v0_checkpoint_manager_checkpoints.py @@ -0,0 +1,145 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generates V0 CheckpointManager checkpoints for compatibility testing. + +The checkpoints generated by this script are checked into the repository +statically to ensure long-term backward compatibility against runtime changes. +""" +import itertools +import os +from typing import Any + +from absl import app +from absl import flags +from etils import epath +import jax +import jax.numpy as jnp +from orbax.checkpoint import args +from orbax.checkpoint import checkpoint_manager +from orbax.checkpoint import handlers +from orbax.checkpoint import test_utils +from orbax.checkpoint.experimental.v1._src.testing import handler_utils + +FLAGS = flags.FLAGS + + +def _get_base_dir(): + if 'BUILD_WORKING_DIRECTORY' in os.environ: + return os.path.join( + os.environ['BUILD_WORKING_DIRECTORY'], + 'orbax/checkpoint/experimental/v1/_src/testing/compatibility/managed_checkpoints', + ) + return os.path.join( + os.path.dirname(__file__), + 'managed_checkpoints', + ) + + +_BASE_DIR = flags.DEFINE_string( + 'base_dir', + _get_base_dir(), + 'Base directory to save checkpoints.', +) +_OVERWRITE = flags.DEFINE_bool( + 'overwrite', + False, + 'Overwrite existing checkpoints.', +) + + +def create_pytree() -> dict[str, Any]: + return { + 'a': jax.device_put(jnp.array([0, 1, 2, 3, 4, 5, 6, 7], dtype=jnp.int32)), + 'b': {'c': jax.device_put(jnp.array([1, 2, 3], dtype=jnp.int32))}, + } + + +def generate_v0_checkpoint_manager_checkpoint( + path: epath.Path, + has_metrics: bool = True, + has_custom_metadata: bool = True, +) -> None: + """Saves a composite manager checkpoint using V0 CheckpointManager.""" + if _OVERWRITE.value: + path.rmtree(missing_ok=True) + path.mkdir(parents=True, exist_ok=True) + + pytree = create_pytree() + json_object = {'metadata': 'json_data'} + baz = handler_utils.Baz(123, 'hi') + registry = handlers.create_default_handler_registry( + state=handlers.PyTreeCheckpointHandler(), + metadata=handlers.JsonCheckpointHandler(), + baz=handler_utils.DataclassCheckpointHandler(), + ) + options = checkpoint_manager.CheckpointManagerOptions( + create=True, + enable_async_checkpointing=False, + prevent_write_metrics=not has_metrics, + best_fn=lambda metrics: metrics['loss'] if has_metrics else None, + ) + + manager = checkpoint_manager.CheckpointManager( + path, + options=options, + handler_registry=registry, + metadata={'foo': 'bar'} if has_custom_metadata else None, + ) + with manager: + manager.save( + 0, + args=args.Composite( + state=args.PyTreeSave(pytree), + metadata=args.JsonSave(json_object), + baz=handler_utils.DataclassSaveArgs(baz), + ), + metrics={'loss': 0.5} if has_metrics else None, + custom_metadata={'custom': 'meta'} if has_custom_metadata else None, + ) + manager.wait_until_finished() + + +def main(argv): + del argv + epath.Path(_BASE_DIR.value).mkdir(parents=True, exist_ok=True) + + test_utils.set_tensorstore_driver_for_test() + + print('Generating V0 CheckpointManager Checkpoints...') + + for has_metrics, has_custom_metadata in itertools.product( + [True, False], repeat=2 + ): + metrics_str = 'has_metrics' if has_metrics else 'no_metrics' + meta_str = ( + 'has_custom_metadata' if has_custom_metadata else 'no_custom_metadata' + ) + path = ( + epath.Path(_BASE_DIR.value) + / 'v0_managed' + / metrics_str + / meta_str + ) + generate_v0_checkpoint_manager_checkpoint( + path, + has_metrics=has_metrics, + has_custom_metadata=has_custom_metadata, + ) + + print(f'V0 CheckpointManager Checkpoints generated at {_BASE_DIR.value}') + + +if __name__ == '__main__': + app.run(main) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/generate_v1_checkpointer_checkpoints.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/generate_v1_checkpointer_checkpoints.py new file mode 100644 index 000000000..68a515b21 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/generate_v1_checkpointer_checkpoints.py @@ -0,0 +1,140 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generates V1 Checkpointer checkpoints for compatibility testing. + +The checkpoints generated by this script are checked into the repository +statically to ensure long-term backward compatibility against runtime changes. +""" +import itertools +import os +from typing import Any + +from absl import app +from absl import flags +from etils import epath +import jax +import jax.numpy as jnp +from orbax.checkpoint import test_utils +import orbax.checkpoint.experimental.v1 as ocp +from orbax.checkpoint.experimental.v1._src.testing import handler_utils + +FLAGS = flags.FLAGS + + +def _get_base_dir(): + if 'BUILD_WORKING_DIRECTORY' in os.environ: + return os.path.join( + os.environ['BUILD_WORKING_DIRECTORY'], + 'orbax/checkpoint/experimental/v1/_src/testing/compatibility/managed_checkpoints', + ) + return os.path.join( + os.path.dirname(__file__), + 'managed_checkpoints', + ) + + +_BASE_DIR = flags.DEFINE_string( + 'base_dir', + _get_base_dir(), + 'Base directory to save checkpoints.', +) +_OVERWRITE = flags.DEFINE_bool( + 'overwrite', + False, + 'Overwrite existing checkpoints.', +) + + +def create_pytree() -> dict[str, Any]: + return { + 'a': jax.device_put(jnp.array([0, 1, 2, 3, 4, 5, 6, 7], dtype=jnp.int32)), + 'b': {'c': jax.device_put(jnp.array([1, 2, 3], dtype=jnp.int32))}, + } + + +def generate_v1_checkpointer_checkpoint( + path: epath.Path, + has_metrics: bool = True, + has_custom_metadata: bool = True, +) -> None: + """Saves a composite checkpoint using V1 Checkpointer.""" + if _OVERWRITE.value: + path.rmtree(missing_ok=True) + path.mkdir(parents=True, exist_ok=True) + + pytree = create_pytree() + json_object = {'metadata': 'json_data'} + baz = handler_utils.Baz(123, 'hi') + + registry = ocp.handlers.local_registry() + registry.add(ocp.handlers.PyTreeHandler, checkpointable_name='state') + registry.add(ocp.handlers.JsonHandler, checkpointable_name='metadata') + registry.add(handler_utils.BazHandler, checkpointable_name='baz') + + checkpointables = { + 'state': pytree, + 'metadata': json_object, + 'baz': baz, + } + + with ocp.Context( + checkpointables_options=ocp.options.CheckpointablesOptions( + registry=registry + ) + ): + checkpointer = ocp.training.Checkpointer( + path, + custom_metadata={'foo': 'bar'} if has_custom_metadata else None, + ) + checkpointer.save_checkpointables( + 0, + checkpointables, + metrics={'loss': 0.5} if has_metrics else None, + custom_metadata={'custom': 'meta'} if has_custom_metadata else None, + ) + + +def main(argv): + del argv + epath.Path(_BASE_DIR.value).mkdir(parents=True, exist_ok=True) + + test_utils.set_tensorstore_driver_for_test() + + print('Generating V1 Checkpointer Checkpoints...') + + for has_metrics, has_custom_metadata in itertools.product( + [True, False], repeat=2 + ): + metrics_str = 'has_metrics' if has_metrics else 'no_metrics' + meta_str = ( + 'has_custom_metadata' if has_custom_metadata else 'no_custom_metadata' + ) + path = ( + epath.Path(_BASE_DIR.value) + / 'v1_managed' + / metrics_str + / meta_str + ) + generate_v1_checkpointer_checkpoint( + path, + has_metrics=has_metrics, + has_custom_metadata=has_custom_metadata, + ) + + print(f'V1 Checkpointer Checkpoints generated at {_BASE_DIR.value}') + + +if __name__ == '__main__': + app.run(main) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test_base.py similarity index 83% rename from checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test.py rename to checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test_base.py index 2868df6c8..ffe275ea0 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test_base.py @@ -20,12 +20,14 @@ from etils import epath import jax import jax.numpy as jnp +import numpy as np from orbax.checkpoint import test_utils from orbax.checkpoint._src.sharding_utils import make_single_device_sharding import orbax.checkpoint.experimental.v1 as ocp from orbax.checkpoint.experimental.v1._src.context import options as options_lib from orbax.checkpoint.experimental.v1._src.handlers import registration from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout as checkpoint_layout_lib +from orbax.checkpoint.experimental.v1._src.synchronization import multihost from orbax.checkpoint.experimental.v1._src.testing.compatibility import test_utils as compatibility_test_utils @@ -36,7 +38,8 @@ _BASE_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints') -class LoadCheckpointablesCompatibilityTest(parameterized.TestCase): +class LoadCheckpointablesCompatibilityTestBase(parameterized.TestCase): + """Tests for V1 load_checkpointables API against generated Checkpoints.""" def setUp(self) -> None: super().setUp() @@ -46,6 +49,12 @@ def setUp(self) -> None: 'b': {'c': jnp.array([1, 2, 3], dtype=jnp.int32)}, } sharding = make_single_device_sharding(jax.devices()[0]) + if multihost.is_pathways_backend() or jax.process_count() > 1: + self.expected_state = compatibility_test_utils.replicate_on_mesh( + self.expected_state + ) + mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',)) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) self.abstract_state = jax.tree.map( lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=sharding), self.expected_state @@ -177,12 +186,32 @@ def test_load_checkpointables_compatibility( has_pytree_metadata: bool, handler_registered: bool, ) -> None: + """Tests load_checkpointables compatibility with v0 and v1 checkpoints. + + Args: + version: The version of the checkpoint to load. + checkpointable_names_provided: Whether to provide checkpointable_names to + load_checkpointables. + abstract_checkpointables_provided: Whether to provide + abstract_checkpointables to ocp.load_checkpointables. + names_registered: Whether to register all checkpointable names to a + handler. + metadata_present: Whether the checkpoint has metadata files. + is_direct_checkpoint: Whether the checkpoint is a direct checkpoint. + has_pytree_metadata: Whether the checkpoint has pytree metadata files. + handler_registered: Whether to register all handlers. + """ path = compatibility_test_utils.get_checkpoint_path( version, metadata_present, is_direct_checkpoint, has_pytree_metadata ) if path is None or not path.exists(): self.skipTest('Checkpoint for combination does not exist.') + if ( + multihost.is_pathways_backend() or jax.process_count() > 1 + ) and not abstract_checkpointables_provided: + self.skipTest('Sharding metadata not matching in Pathways/Multiprocess.') + if not checkpointable_names_provided and abstract_checkpointables_provided: self.skipTest( 'Cannot provide abstract_checkpointables without' @@ -252,6 +281,12 @@ def test_load_checkpointables_compatibility( def test_load_checkpointables_non_critical_corruptions( self, version: str, alteration: str ) -> None: + """Tests load_checkpointables with non-critical corruptions. + + Args: + version: The checkpoint version to test against. + alteration: The alteration to apply to the checkpoint. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', @@ -261,7 +296,9 @@ def test_load_checkpointables_non_critical_corruptions( loaded = ocp.load_checkpointables( path, abstract_checkpointables=self.abstract_checkpointables ) - test_utils.assert_tree_equal(self, loaded, self.expected_checkpointables) + test_utils.assert_tree_equal( + self, loaded, self.expected_checkpointables + ) @parameterized.product( version=['v0', 'v1'], @@ -273,6 +310,12 @@ def test_load_checkpointables_non_critical_corruptions( def test_load_checkpointables_critical_corruptions( self, version: str, alteration: str ) -> None: + """Tests load_checkpointables with critical corruptions. + + Args: + version: The checkpoint version to test against. + alteration: The alteration to apply to the checkpoint. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test_base.py similarity index 83% rename from checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test.py rename to checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test_base.py index 5d82a33f5..c9a63def2 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test_base.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for V1 load_pytree API against generated V0 and V1 Checkpoints.""" + import os from typing import Tuple, Type @@ -21,12 +22,14 @@ from etils import epath import jax import jax.numpy as jnp +import numpy as np from orbax.checkpoint import test_utils from orbax.checkpoint._src.sharding_utils import make_single_device_sharding import orbax.checkpoint.experimental.v1 as ocp from orbax.checkpoint.experimental.v1._src.context import options as options_lib from orbax.checkpoint.experimental.v1._src.handlers import registration from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout as checkpoint_layout_lib +from orbax.checkpoint.experimental.v1._src.synchronization import multihost from orbax.checkpoint.experimental.v1._src.testing.compatibility import test_utils as compatibility_test_utils @@ -37,7 +40,8 @@ _BASE_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints') -class LoadPytreeCompatibilityTest(parameterized.TestCase): +class LoadPytreeCompatibilityTestBase(parameterized.TestCase): + """Tests for V1 load_pytree API against generated Checkpoints.""" def setUp(self) -> None: super().setUp() @@ -47,6 +51,12 @@ def setUp(self) -> None: 'b': {'c': jnp.array([1, 2, 3], dtype=jnp.int32)}, } sharding = make_single_device_sharding(jax.devices()[0]) + if multihost.is_pathways_backend() or jax.process_count() > 1: + self.expected_state = compatibility_test_utils.replicate_on_mesh( + self.expected_state + ) + mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',)) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) self.abstract_state = jax.tree.map( lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=sharding), self.expected_state @@ -194,12 +204,33 @@ def test_load_pytree_compatibility( handler_registered: bool, pytree_registered: bool, ) -> None: + """Tests load_pytree against various checkpoint configurations. + + Args: + version: The checkpoint version to test against. + checkpointable_name: The name of the checkpointable to load. + abstract_pytree_provided: Whether an abstract pytree is provided to + ocp.load_pytree. + name_registered: Whether a handler is registered for the + checkpointable_name. + metadata_present: Whether the checkpoint has metadata. + is_direct_checkpoint: Whether the checkpoint is a direct checkpoint. + is_pytree: Whether the checkpoint is a pytree checkpoint. + handler_registered: Whether a handler is registered for the handler + typestr derived from checkpoint metadata. + pytree_registered: Whether a handler is registered for the 'pytree' scope. + """ path = compatibility_test_utils.get_checkpoint_path( version, metadata_present, is_direct_checkpoint, is_pytree ) if path is None or not path.exists(): self.skipTest('Checkpoint for combination does not exist.') + if ( + multihost.is_pathways_backend() or jax.process_count() > 1 + ) and not abstract_pytree_provided: + self.skipTest('Sharding metadata not matching in Pathways/Multiprocess.') + registry = self.setup_registry( path, checkpointable_name, @@ -262,7 +293,12 @@ def test_load_pytree_compatibility( def test_load_pytree_non_critical_corruptions( self, version: str, alteration: str ) -> None: + """Tests load_pytree against checkpoints with non-critical corruptions. + Args: + version: The version of the checkpoint to load. + alteration: The non-critical corruption alteration to apply. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', @@ -284,6 +320,12 @@ def test_load_pytree_non_critical_corruptions( def test_load_pytree_critical_corruptions( self, version: str, alteration: str ) -> None: + """Tests load_pytree against checkpoints with critical corruptions. + + Args: + version: The version of the checkpoint to load. + alteration: The critical corruption alteration to apply. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', @@ -303,6 +345,11 @@ def test_load_pytree_critical_corruptions( version=['v0', 'v1'], ) def test_load_incorrect_path(self, version: str) -> None: + """Tests load_pytree against checkpoints with incorrect paths. + + Args: + version: The version of the checkpoint to test against. + """ checkpoint_path = ( self.base_dir / f'{version}_checkpoints' diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test_base.py similarity index 81% rename from checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test.py rename to checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test_base.py index 6cfa1f4f4..986c4633e 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test_base.py @@ -27,9 +27,9 @@ from orbax.checkpoint.experimental.v1._src.context import options as options_lib from orbax.checkpoint.experimental.v1._src.handlers import registration from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout as checkpoint_layout_lib +from orbax.checkpoint.experimental.v1._src.synchronization import multihost from orbax.checkpoint.experimental.v1._src.testing.compatibility import test_utils as compatibility_test_utils - CheckpointLayoutEnum = options_lib.CheckpointLayout InvalidLayoutError = checkpoint_layout_lib.InvalidLayoutError @@ -37,7 +37,8 @@ _BASE_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints') -class PytreeMetadataCompatibilityTest(parameterized.TestCase): +class PytreeMetadataCompatibilityTestBase(parameterized.TestCase): + """Tests for V1 pytree_metadata API against generated Checkpoints.""" def setUp(self) -> None: super().setUp() @@ -155,6 +156,17 @@ def test_pytree_metadata_compatibility( is_pytree: bool, handler_registered: bool, ) -> None: + """Tests pytree_metadata compatibility across V0 and V1 checkpoints. + + Args: + version: The checkpoint version to test against. + checkpointable_name: The name of the checkpointable to load. + name_registered: Whether the checkpointable name is registered. + metadata_present: Whether the checkpoint metadata file is present. + is_direct_checkpoint: Whether the checkpoint is a direct checkpoint. + is_pytree: Whether the checkpointable is a pytree. + handler_registered: Whether the handler is registered. + """ path = compatibility_test_utils.get_checkpoint_path( version, metadata_present, is_direct_checkpoint, is_pytree ) @@ -184,9 +196,12 @@ def test_pytree_metadata_compatibility( path, checkpointable_name=checkpointable_name, ) - test_utils.assert_tree_equal( - self, self.expected_state_metadata, loaded.metadata - ) + expected = self.expected_state_metadata + actual = loaded.metadata + if multihost.is_pathways_backend() or jax.process_count() > 1: + expected = compatibility_test_utils.strip_sharding_metadata(expected) + actual = compatibility_test_utils.strip_sharding_metadata(actual) + test_utils.assert_tree_equal(self, expected, actual) else: with self.assertRaisesRegex(error_type, error_msg): ocp.pytree_metadata( @@ -209,6 +224,12 @@ def test_pytree_metadata_compatibility( def test_pytree_metadata_non_critical_corruptions( self, version: str, alteration: str ) -> None: + """Tests pytree_metadata with non-critical corruptions. + + Args: + version: The checkpoint version to test against. + alteration: The alteration to apply to the checkpoint. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', @@ -216,9 +237,12 @@ def test_pytree_metadata_non_critical_corruptions( alteration, ) loaded = ocp.pytree_metadata(path, checkpointable_name='state') - test_utils.assert_tree_equal( - self, self.expected_state_metadata, loaded.metadata - ) + expected = self.expected_state_metadata + actual = loaded.metadata + if multihost.is_pathways_backend() or jax.process_count() > 1: + expected = compatibility_test_utils.strip_sharding_metadata(expected) + actual = compatibility_test_utils.strip_sharding_metadata(actual) + test_utils.assert_tree_equal(self, expected, actual) @parameterized.product( version=['v0', 'v1'], @@ -226,6 +250,11 @@ def test_pytree_metadata_non_critical_corruptions( def test_pytree_metadata_missing_sharding_corruption( self, version: str ) -> None: + """Tests pytree_metadata with missing sharding corruption. + + Args: + version: The checkpoint version to test against. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', @@ -247,6 +276,12 @@ def test_pytree_metadata_missing_sharding_corruption( def test_pytree_metadata_critical_corruptions( self, version: str, alteration: str ) -> None: + """Tests pytree_metadata with critical corruptions. + + Args: + version: The checkpoint version to test against. + alteration: The alteration to apply to the checkpoint. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/test_utils.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/test_utils.py index 9b3c7eba6..80d2d4581 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/test_utils.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/test_utils.py @@ -82,3 +82,31 @@ def create_value_metadata(value: Any) -> Any: return 0.0 else: raise TypeError(f'Unsupported type: {type(value)}') + + +def strip_sharding_metadata(tree: Any) -> Any: + """Strips concrete sharding_metadata from Metadata to decouple from topologies.""" + def _strip(x): + return array_leaf_handler.ArrayMetadata( + shape=x.shape, + dtype=x.dtype, + sharding_metadata=None, + storage_metadata=x.storage_metadata, + ) + return jax.tree.map( + _strip, + tree, + is_leaf=lambda leaf: hasattr(leaf, 'sharding_metadata'), + ) + + +def replicate_on_mesh(tree: Any) -> Any: + """Replicates a PyTree of arrays across all devices in the current mesh.""" + mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',)) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + return jax.tree.map( + lambda x: jax.device_put(x, sharding) + if isinstance(x, (jax.Array, np.ndarray)) + else x, + tree, + ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/v0v1_compatibility_save_load_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/v0v1_compatibility_save_load_test_base.py deleted file mode 100644 index 6ea0bfb66..000000000 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/v0v1_compatibility_save_load_test_base.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright 2026 The Orbax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Base class for v0/v1 compatibility save/load tests.""" - -# pylint: disable=missing-class-docstring,protected-access,missing-function-docstring - -from __future__ import annotations - -from absl.testing import parameterized -from etils import epath -from orbax.checkpoint import args as args_lib -from orbax.checkpoint import test_utils -from orbax.checkpoint._src.checkpointers import checkpointer as v0_checkpointer -from orbax.checkpoint._src.checkpointers import standard_checkpointer -from orbax.checkpoint._src.handlers import composite_checkpoint_handler -import orbax.checkpoint.experimental.v1 as ocp -from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout -from orbax.checkpoint.experimental.v1._src.path import types as path_types -from orbax.checkpoint.experimental.v1._src.synchronization import multihost -from orbax.checkpoint.experimental.v1._src.testing import array_utils as array_test_utils -from orbax.checkpoint.experimental.v1._src.tree import types as tree_types - -PyTree = tree_types.PyTree -Path = path_types.Path -InvalidLayoutError = checkpoint_layout.InvalidLayoutError - -create_sharded_pytree = array_test_utils.create_sharded_pytree - - -class CompatibilitySaveLoadTestBase: - - class Test(parameterized.TestCase): - - def setUp(self): - super().setUp() - - self.root_directory = epath.Path( - self.create_tempdir(name='root').full_path - ) - self.ckpt_directory = ( - epath.Path(self.create_tempdir(name='direct').full_path) / 'ckpt' - ) - self.pytree, self.abstract_pytree = create_sharded_pytree() - - test_utils.set_tensorstore_driver_for_test() - test_utils.sync_global_processes('CompatibilityTest:setup_complete') - - def tearDown(self): - super().tearDown() - test_utils.sync_global_processes('CompatibilityTest:teardown_complete') - - def save_v0_checkpoint(self, directory: Path): - with standard_checkpointer.StandardCheckpointer() as checkpointer: - checkpointer.save(directory, self.pytree) - - def save_v0_checkpoints( - self, base_dir: Path, *, checkpointable_names: list[str] - ): - args = args_lib.Composite(**{ - checkpointable_name: args_lib.StandardSave(self.pytree) - for checkpointable_name in checkpointable_names - }) - with v0_checkpointer.Checkpointer( - composite_checkpoint_handler.CompositeCheckpointHandler() - ) as checkpointer: - checkpointer.save(base_dir, args) - - def test_async_load(self): - with self.assertRaises(NotImplementedError): - ocp.load_pytree_async( - self.root_directory, - ) - with self.assertRaises(NotImplementedError): - ocp.load_checkpointables_async( - self.root_directory, - ) - - @parameterized.product( - with_abstract_pytree=[True, False], - ) - def test_load_v0_checkpoint_with_v1_load_pytree( - self, with_abstract_pytree: bool - ): - - checkpointable_names = ['default', 'state', 'pytree'] - step_dir = self.root_directory / 'load_pytree_0' - self.save_v0_checkpoints( - step_dir, - checkpointable_names=checkpointable_names, - ) - self.save_v0_checkpoint(self.ckpt_directory) - - with self.subTest('no_checkpointable_name'): - loaded = ocp.load_pytree( - step_dir, - self.abstract_pytree if with_abstract_pytree else None, - ) - test_utils.assert_tree_equal(self, self.pytree, loaded) - - with self.subTest('flat_layout_no_checkpointable_name'): - loaded = ocp.load_pytree( - self.ckpt_directory, - self.abstract_pytree if with_abstract_pytree else None, - ) - test_utils.assert_tree_equal(self, self.pytree, loaded) - - with self.subTest('root_path_no_checkpointable_name_error'): - with self.assertRaises(InvalidLayoutError): - ocp.load_pytree( - self.root_directory, - self.abstract_pytree if with_abstract_pytree else None, - ) - - for checkpointable_name in checkpointable_names: - with self.subTest(f'pass_{checkpointable_name}'): - loaded = ocp.load_pytree( - step_dir, - self.abstract_pytree if with_abstract_pytree else None, - checkpointable_name=checkpointable_name, - ) - test_utils.assert_tree_equal(self, self.pytree, loaded) - - with self.subTest(f'pass_{checkpointable_name}_error'): - with self.assertRaises(InvalidLayoutError): - ocp.load_pytree( - self.ckpt_directory, - self.abstract_pytree if with_abstract_pytree else None, - checkpointable_name=checkpointable_name, - ) - with self.assertRaises(InvalidLayoutError): - ocp.load_pytree( - self.root_directory, - self.abstract_pytree if with_abstract_pytree else None, - checkpointable_name=checkpointable_name, - ) - - with self.subTest('pass_none'): - loaded = ocp.load_pytree( - self.ckpt_directory, - self.abstract_pytree if with_abstract_pytree else None, - checkpointable_name=None, - ) - test_utils.assert_tree_equal(self, self.pytree, loaded) - with self.assertRaises(InvalidLayoutError): - ocp.load_pytree( - step_dir, - self.abstract_pytree if with_abstract_pytree else None, - checkpointable_name=None, - ) - with self.assertRaises(InvalidLayoutError): - ocp.load_pytree( - self.root_directory, - self.abstract_pytree if with_abstract_pytree else None, - checkpointable_name=None, - ) - - def test_load_v0_checkpoint_with_v1_pytree_metadata(self): - checkpointable_names = ['default', 'state', 'pytree'] - step_dir = self.root_directory / 'load_pytree_0' - self.save_v0_checkpoints( - step_dir, - checkpointable_names=checkpointable_names, - ) - self.save_v0_checkpoint(self.ckpt_directory) - - with self.subTest('no_checkpointable_name'): - loaded = ocp.pytree_metadata(step_dir) - test_utils.assert_tree_same_structure( - self, self.abstract_pytree, loaded.metadata - ) - - with self.subTest('flat_layout_no_checkpointable_name'): - metadata = ocp.pytree_metadata(self.ckpt_directory) - test_utils.assert_tree_same_structure( - self, self.abstract_pytree, metadata.metadata - ) - - with self.subTest('root_path_no_checkpointable_name_error'): - with self.assertRaises(InvalidLayoutError): - ocp.pytree_metadata(self.root_directory) - - for checkpointable_name in checkpointable_names: - with self.subTest(f'pass_{checkpointable_name}'): - loaded = ocp.pytree_metadata( - step_dir, - checkpointable_name=checkpointable_name, - ) - test_utils.assert_tree_same_structure( - self, self.abstract_pytree, loaded.metadata - ) - - with self.subTest(f'pass_{checkpointable_name}_error'): - with self.assertRaises(InvalidLayoutError): - ocp.pytree_metadata( - self.ckpt_directory, - checkpointable_name=checkpointable_name, - ) - with self.assertRaises(InvalidLayoutError): - ocp.pytree_metadata( - self.root_directory, - checkpointable_name=checkpointable_name, - ) - - with self.subTest('pass_none'): - loaded = ocp.pytree_metadata( - self.ckpt_directory, - checkpointable_name=None, - ) - test_utils.assert_tree_same_structure( - self, self.abstract_pytree, loaded.metadata - ) - with self.assertRaises(InvalidLayoutError): - ocp.pytree_metadata( - step_dir, - checkpointable_name=None, - ) - with self.assertRaises(InvalidLayoutError): - ocp.pytree_metadata( - self.root_directory, - checkpointable_name=None, - ) - - @parameterized.product( - checkpointable_name=['default', 'state'], - with_abstract_pytree=[True, False], - ) - def test_load_v0_checkpoint_with_v1_load_checkpointables( - self, - checkpointable_name: str, - with_abstract_pytree: bool, - ): - - checkpointable_names = ['default', 'state'] - step_dir = self.root_directory / 'load_checkpointables_0' - self.save_v0_checkpoints( - step_dir, - checkpointable_names=checkpointable_names, - ) - self.save_v0_checkpoint(self.ckpt_directory) - - abstract_checkpointables = ( - {checkpointable_name: self.abstract_pytree} - ) - - with self.subTest('with_context'): - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - **{checkpointable_name: ocp.handlers.PyTreeHandler} - ) - ) - with ocp.Context(checkpointables_options=checkpointables_options): - loaded = ocp.load_checkpointables(step_dir, abstract_checkpointables) - test_utils.assert_tree_equal( - self, self.pytree, loaded[checkpointable_name] - ) - - with self.subTest('without_context'): - loaded = ocp.load_checkpointables(step_dir, abstract_checkpointables) - test_utils.assert_tree_equal( - self, self.pytree, loaded[checkpointable_name] - ) - # TODO(b/484400394): Find a better way to inform the user that they need - # to use load_pytree(..., checkpointable_name=None) when item_handlers is - # a str. - with self.subTest('error_with_checkpoint_path'): - with self.assertRaisesRegex( - KeyError, 'Requested checkpointables:' - ): - ocp.load_checkpointables( - self.ckpt_directory, abstract_checkpointables - ) - with self.subTest('error_with_root_path'): - with self.assertRaises(InvalidLayoutError): - ocp.load_checkpointables( - self.root_directory, abstract_checkpointables - ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/v1_compatibility_load_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/v1_compatibility_load_test.py new file mode 100644 index 000000000..b4c2409e8 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/v1_compatibility_load_test.py @@ -0,0 +1,59 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl import flags +from absl.testing import absltest +from absl.testing import parameterized +import jax +from orbax.checkpoint.experimental.v1._src.testing.compatibility import checkpointables_metadata_compatibility_test_base +from orbax.checkpoint.experimental.v1._src.testing.compatibility import load_checkpointables_compatibility_test_base +from orbax.checkpoint.experimental.v1._src.testing.compatibility import load_pytree_compatibility_test_base +from orbax.checkpoint.experimental.v1._src.testing.compatibility import pytree_metadata_compatibility_test_base + + +FLAGS = flags.FLAGS + +jax.config.update('jax_enable_x64', True) + + +class CheckpointablesMetadataTest( + checkpointables_metadata_compatibility_test_base.CheckpointablesMetadataCompatibilityTestBase, + parameterized.TestCase, +): + pass + + +class LoadCheckpointablesTest( + load_checkpointables_compatibility_test_base.LoadCheckpointablesCompatibilityTestBase, + parameterized.TestCase, +): + pass + + +class LoadPytreeTest( + load_pytree_compatibility_test_base.LoadPytreeCompatibilityTestBase, + parameterized.TestCase, +): + pass + + +class PytreeMetadataTest( + pytree_metadata_compatibility_test_base.PytreeMetadataCompatibilityTestBase, + parameterized.TestCase, +): + pass + + +if __name__ == '__main__': + absltest.main()