diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/assets.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/assets.py index 9689ea06c..4e03b3e89 100644 --- a/checkpoint/orbax/checkpoint/experimental/tiering_service/assets.py +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/assets.py @@ -20,6 +20,7 @@ """ from collections.abc import Collection, Sequence +import dataclasses import datetime from absl import logging @@ -35,6 +36,24 @@ from google.protobuf import timestamp_pb2 +class DeletionPendingError(ValueError): + """Raised when an operation is attempted on an asset/TierPath marked for deletion.""" + + +@dataclasses.dataclass +class CreatePrefetchJobResult: + """Result of creating a prefetch job. + + Attributes: + asset: The updated asset, or None if not found. + created: True if a new job was successfully created, False if it failed due + to a concurrent insert or there is already an existing job queued. + """ + + asset: db_schema.Asset | None + created: bool + + def _proto_from_db_tier_path( tier_path: db_schema.TierPath, ) -> tiering_service_pb2.TierPath: @@ -92,6 +111,7 @@ def _get_location_kwargs(sb: db_schema.StorageBackend): storage_backend=proto_storage_backend, ready_at=ready_at_pb, expires_at=expires_at_pb, + tier_path_uuid=tier_path.tier_path_uuid, ) @@ -381,3 +401,202 @@ async def finalize_asset( await session.commit() await session.refresh(db_asset, attribute_names=["updated_at"]) return db_asset + + +async def is_delete_pending(session: AsyncSession, asset_uuid: str) -> bool: + """Checks if there is a pending delete job for the asset.""" + stmt = ( + select(db_schema.AssetJob) + .filter_by( + asset_uuid=asset_uuid, + request_type=db_schema.RequestType.REQUEST_TYPE_DELETE_FROM_ALL_TIERS, + ) + .where( + db_schema.AssetJob.status.in_([ + db_schema.JobStatus.JOB_STATUS_QUEUED, + db_schema.JobStatus.JOB_STATUS_PROCESSING, + ]) + ) + ) + result = await session.execute(stmt) + return bool(result.scalars().first()) + + +async def is_tier_path_delete_pending( + session: AsyncSession, + *, + asset_uuid: str, + tier_path_id: int, +) -> bool: + """Checks if there is a pending delete job for the specific TierPath.""" + stmt = ( + select(db_schema.AssetJob) + .filter_by( + asset_uuid=asset_uuid, + target_tier_path_id=tier_path_id, + request_type=db_schema.RequestType.REQUEST_TYPE_DELETE_FROM_INSTANCE, + ) + .where( + db_schema.AssetJob.status.in_([ + db_schema.JobStatus.JOB_STATUS_QUEUED, + db_schema.JobStatus.JOB_STATUS_PROCESSING, + ]) + ) + ) + result = await session.execute(stmt) + return bool(result.scalars().first()) + + +async def create_prefetch_job( + session: AsyncSession, + db_asset: db_schema.Asset, + *, + backend: db_schema.StorageBackend, + storage_path: str, + client_keep_alive_interval: datetime.timedelta, +) -> CreatePrefetchJobResult: + """Queues a prefetch job for the given asset to the target backend. + + This function executes atomically in a single transaction. It creates both + the `TierPath` and the corresponding `AssetJob` together, ensuring that we + never commit a "dangling" `TierPath` without an associated prefetch job. If + either operation fails (e.g. due to concurrent insertion conflicts), the + entire transaction is rolled back. + + Args: + session: The database session (active transaction). + db_asset: The asset to prefetch. + backend: The target storage backend (level 0). + storage_path: The storage path to use for the new TierPath. + client_keep_alive_interval: The interval to set for the initial expires_at + of the TierPath. + + Returns: + A CreatePrefetchJobResult containing the updated asset and a boolean + indicating whether a new job was created. + + Raises: + DeletionPendingError: If the asset is already marked for deletion. + """ + + # Check if there is already a preceding delete job + if await is_delete_pending(session, db_asset.asset_uuid): + raise DeletionPendingError( + f"Cannot prefetch asset {db_asset.asset_uuid} because it is marked for" + " deletion." + ) + + logging.info( + "Prefetch: Creating new pending TierPath and job for asset %s and" + " backend %s", + db_asset.asset_uuid, + backend.id, + ) + new_tp = db_schema.TierPath( + storage_backend=backend, + path=storage_path, + expires_at=calculate_expires_at(client_keep_alive_interval), + ) + db_asset.tier_paths.append(new_tp) + + db_job = db_schema.AssetJob( + asset_uuid=db_asset.asset_uuid, + request_type=db_schema.RequestType.REQUEST_TYPE_COPY, + status=db_schema.JobStatus.JOB_STATUS_QUEUED, + target_tier_path=new_tp, + ) + session.add(db_job) + + asset_uuid = db_asset.asset_uuid + backend_id = backend.id + try: + await session.commit() + except IntegrityError: + await session.rollback() + logging.debug( + "Prefetch: Concurrent insert detected for asset %s and backend %s," + " rolling back", + asset_uuid, + backend_id, + ) + db_assets = await fetch_asset_by_uuid(session, asset_uuid) + return CreatePrefetchJobResult( + asset=(db_assets[0] if db_assets else None), created=False + ) + + await session.refresh(db_asset, attribute_names=["updated_at"]) + return CreatePrefetchJobResult(asset=db_asset, created=True) + + +async def prefetch_keep_alive( + session: AsyncSession, + *, + tier_path_uuid: str, + interval: datetime.timedelta, +) -> db_schema.Asset | None: + """Extend the TierPath's expiration timestamp. + + Args: + session: The database session. + tier_path_uuid: The UUID of the TierPath to update. + interval: The new timeout interval. + + Returns: + The updated Asset object, or None if the TierPath was not found. + + Raises: + DeletionPendingError: If the asset associated with the TierPath is marked + for deletion, or if the specific TierPath instance is marked for + deletion. + """ + stmt = select(db_schema.TierPath).filter_by(tier_path_uuid=tier_path_uuid) + result = await session.execute(stmt) + tp = result.scalars().first() + if tp is None: + return None + + if await is_delete_pending(session, tp.asset_uuid): + raise DeletionPendingError(f"Asset {tp.asset_uuid} is marked for deletion.") + + if await is_tier_path_delete_pending( + session, asset_uuid=tp.asset_uuid, tier_path_id=tp.id + ): + raise DeletionPendingError( + f"TierPath {tier_path_uuid} is marked for deletion." + ) + + if tp.expires_at is None: + logging.debug( + "TierPath %s has no expires_at (still copying or permanent), no-op", + tier_path_uuid, + ) + db_assets = await fetch_asset_by_uuid(session, tp.asset_uuid) + return db_assets[0] if db_assets else None + + new_expires_at = calculate_expires_at(interval) + existing_expires_at = tp.expires_at + compared_expires_at = ( + existing_expires_at.replace(tzinfo=datetime.timezone.utc) + if existing_expires_at.tzinfo is None + else existing_expires_at + ) + if new_expires_at > compared_expires_at: + logging.debug( + "Extending TierPath %s expires_at from %s to %s", + tier_path_uuid, + existing_expires_at, + new_expires_at, + ) + tp.expires_at = new_expires_at + await session.commit() + else: + logging.debug( + "New expires_at %s is not longer than existing %s for TierPath %s," + " no-op", + new_expires_at, + existing_expires_at, + tier_path_uuid, + ) + + db_assets = await fetch_asset_by_uuid(session, tp.asset_uuid) + return db_assets[0] if db_assets else None diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/assets_test.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/assets_test.py index 8c0fbb76d..a1cc0fb87 100644 --- a/checkpoint/orbax/checkpoint/experimental/tiering_service/assets_test.py +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/assets_test.py @@ -16,11 +16,16 @@ import unittest from absl.testing import absltest +from absl.testing import parameterized +import aiosqlite # pylint: disable=unused-import +import greenlet # pylint: disable=unused-import from orbax.checkpoint.experimental.tiering_service import assets from orbax.checkpoint.experimental.tiering_service import db_schema +from orbax.checkpoint.experimental.tiering_service import storage_backend as storage_backend_lib from orbax.checkpoint.experimental.tiering_service.proto import tiering_service_pb2 from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.future import select from sqlalchemy.orm import sessionmaker from google.protobuf import timestamp_pb2 @@ -199,7 +204,7 @@ def test_proto_from_db_asset_backend_with_multi_regions(self): self.assertFalse(sb_proto.HasField("region")) -class AssetsDbTest(absltest.TestCase, unittest.IsolatedAsyncioTestCase): +class AssetsDbTest(parameterized.TestCase, unittest.IsolatedAsyncioTestCase): def _assert_date_time_equal(self, dt1, dt2): if dt1 is None or dt2 is None: @@ -401,6 +406,511 @@ async def test_queries_filtering(self): self.assertLen(fetched_b_by_uuid, 1) self.assertEqual(fetched_b_by_uuid[0].path, "path/B") + async def _set_a_finalized_asset( + self, session: AsyncSession + ) -> tuple[ + db_schema.Asset, db_schema.StorageBackend, db_schema.StorageBackend + ]: + """Sets up a finalized asset in the database. + + Creates two storage backends and one asset. The asset is initially reserved + against one backend and then immediately finalized. + + Args: + session: The SQLAlchemy AsyncSession to use for database operations. + + Returns: + A tuple (asset, b1, b2), where asset is the finalized db_schema.Asset, + b1 is the first db_schema.StorageBackend used for the initial reservation, + and b2 is the second db_schema.StorageBackend. + """ + b1 = db_schema.StorageBackend( + level=0, + backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE, + prefix="/mnt/lustre-a", + zone="us-central1-a", + ) + b2 = db_schema.StorageBackend( + level=0, + backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE, + prefix="/mnt/lustre-b", + zone="us-central1-b", + ) + session.add_all([b1, b2]) + await session.commit() + + request = tiering_service_pb2.ReserveRequest( + path="test/path/finalized_asset", + user="test-user", + zone="us-central1-a", + ) + reserved_asset = await assets.create_or_fetch_asset( + session, + request, + b1, + tiering_service_pb2.ServerConfig( + client_keep_alive_interval_seconds=600 + ), + ) + finalized_asset = await assets.finalize_asset(session, reserved_asset) + return finalized_asset, b1, b2 + + async def test_create_prefetch_job_returns_created_and_updated_asset(self): + async with self.session_maker() as session: + asset, _, b2 = await self._set_a_finalized_asset(session) + storage_path = storage_backend_lib.get_storage_path(b2, asset.path) + result = await assets.create_prefetch_job( + session, + asset, + backend=b2, + storage_path=storage_path, + client_keep_alive_interval=datetime.timedelta(seconds=600), + ) + self.assertTrue(result.created) + self.assertIsNotNone(result.asset) + + async def test_create_prefetch_job_updates_tier_paths(self): + async with self.session_maker() as session: + asset, b1, b2 = await self._set_a_finalized_asset(session) + storage_path = storage_backend_lib.get_storage_path(b2, asset.path) + result = await assets.create_prefetch_job( + session, + asset, + backend=b2, + storage_path=storage_path, + client_keep_alive_interval=datetime.timedelta(seconds=600), + ) + updated_asset = result.asset + self.assertIsNotNone(updated_asset) + paths = [tp.path for tp in updated_asset.tier_paths] + self.assertCountEqual( + paths, + [ + storage_backend_lib.get_storage_path(b1, asset.path), + storage_path, + ], + ) + + async def test_create_prefetch_job_db_tier_path_not_ready(self): + async with self.session_maker() as session: + asset, _, b2 = await self._set_a_finalized_asset(session) + storage_path = storage_backend_lib.get_storage_path(b2, asset.path) + await assets.create_prefetch_job( + session, + asset, + backend=b2, + storage_path=storage_path, + client_keep_alive_interval=datetime.timedelta(seconds=600), + ) + + stmt_tp = select(db_schema.TierPath).filter_by( + asset_uuid=asset.asset_uuid, storage_backend_id=b2.id + ) + result_tp = await session.execute(stmt_tp) + tp_b = result_tp.scalars().first() + self.assertIsNotNone(tp_b) + self.assertEqual(tp_b.path, storage_path) + self.assertIsNone(tp_b.ready_at) + + async def test_create_prefetch_job_db_queues_copy_job(self): + async with self.session_maker() as session: + asset, _, b2 = await self._set_a_finalized_asset(session) + storage_path = storage_backend_lib.get_storage_path(b2, asset.path) + await assets.create_prefetch_job( + session, + asset, + backend=b2, + storage_path=storage_path, + client_keep_alive_interval=datetime.timedelta(seconds=600), + ) + + stmt_tp = select(db_schema.TierPath).filter_by( + asset_uuid=asset.asset_uuid, storage_backend_id=b2.id + ) + result_tp = await session.execute(stmt_tp) + tp_b = result_tp.scalars().first() + self.assertIsNotNone(tp_b) + + stmt_job = select(db_schema.AssetJob).filter_by( + asset_uuid=asset.asset_uuid, + target_tier_path_id=tp_b.id, + request_type=db_schema.RequestType.REQUEST_TYPE_COPY, + ) + result_job = await session.execute(stmt_job) + job = result_job.scalars().first() + self.assertIsNotNone(job) + self.assertEqual(job.status, db_schema.JobStatus.JOB_STATUS_QUEUED) + + @parameterized.named_parameters( + dict( + testcase_name="same_path", + concurrent_path="/mnt/lustre-b/test/path", + attempted_path="/mnt/lustre-b/test/path", + ), + dict( + testcase_name="different_path", + concurrent_path="/mnt/lustre-b/concurrent/path", + attempted_path="/mnt/lustre-b/test/path", + ), + ) + async def test_create_prefetch_job_concurrent_fails_gracefully( + self, concurrent_path: str, attempted_path: str + ): + async with self.session_maker() as session1: + asset1, _, sb2 = await self._set_a_finalized_asset(session1) + asset_uuid = asset1.asset_uuid + b2_id = sb2.id + + async with self.session_maker() as session2: + tp_b = db_schema.TierPath( + asset_uuid=asset_uuid, + storage_backend_id=b2_id, + path=concurrent_path, + ) + session2.add(tp_b) + await session2.flush() + + job_b = db_schema.AssetJob( + asset_uuid=asset_uuid, + request_type=db_schema.RequestType.REQUEST_TYPE_COPY, + status=db_schema.JobStatus.JOB_STATUS_QUEUED, + target_tier_path_id=tp_b.id, + ) + session2.add(job_b) + await session2.commit() + + result = await assets.create_prefetch_job( + session1, + asset1, + backend=sb2, + storage_path=attempted_path, + client_keep_alive_interval=datetime.timedelta(seconds=600), + ) + + self.assertFalse(result.created) + + # Verify DB has only the concurrent TierPath (no new one created) + stmt_tp = select(db_schema.TierPath).filter_by( + asset_uuid=asset_uuid, storage_backend_id=b2_id + ) + result_tp = await session1.execute(stmt_tp) + tps = result_tp.scalars().all() + self.assertLen(tps, 1) + self.assertEqual(tps[0].path, concurrent_path) + + async def test_create_prefetch_job_sets_expires_at_and_uuid(self): + async with self.session_maker() as session: + asset, _, b2 = await self._set_a_finalized_asset(session) + storage_path = storage_backend_lib.get_storage_path(b2, asset.path) + result = await assets.create_prefetch_job( + session, + asset, + backend=b2, + storage_path=storage_path, + client_keep_alive_interval=datetime.timedelta(seconds=600), + ) + updated_asset = result.asset + self.assertIsNotNone(updated_asset) + + tp_b = next( + tp + for tp in updated_asset.tier_paths + if tp.storage_backend_id == b2.id + ) + self.assertIsNotNone(tp_b.tier_path_uuid) + self.assertIsNotNone(tp_b.expires_at) + + async def test_prefetch_keep_alive_success(self): + async with self.session_maker() as session: + asset, _, b2 = await self._set_a_finalized_asset(session) + storage_path = storage_backend_lib.get_storage_path(b2, asset.path) + result = await assets.create_prefetch_job( + session, + asset, + backend=b2, + storage_path=storage_path, + client_keep_alive_interval=datetime.timedelta(seconds=600), + ) + updated_asset = result.asset + self.assertIsNotNone(updated_asset) + tp_b = next( + tp + for tp in updated_asset.tier_paths + if tp.storage_backend_id == b2.id + ) + initial_expires_at = tp_b.expires_at + + extended_asset = await assets.prefetch_keep_alive( + session, + tier_path_uuid=tp_b.tier_path_uuid, + interval=datetime.timedelta(seconds=1200), + ) + self.assertIsNotNone(extended_asset) + tp_b_extended = next( + tp + for tp in extended_asset.tier_paths + if tp.storage_backend_id == b2.id + ) + self.assertGreater(tp_b_extended.expires_at, initial_expires_at) + + async def test_prefetch_keep_alive_not_found(self): + async with self.session_maker() as session: + result = await assets.prefetch_keep_alive( + session, + tier_path_uuid="non-existent-uuid", + interval=datetime.timedelta(seconds=1200), + ) + self.assertIsNone(result) + + async def test_prefetch_keep_alive_no_op_when_permanent(self): + async with self.session_maker() as session: + asset, _, b2 = await self._set_a_finalized_asset(session) + storage_path = storage_backend_lib.get_storage_path(b2, asset.path) + result = await assets.create_prefetch_job( + session, + asset, + backend=b2, + storage_path=storage_path, + client_keep_alive_interval=datetime.timedelta(seconds=600), + ) + updated_asset = result.asset + self.assertIsNotNone(updated_asset) + tp_b = next( + tp + for tp in updated_asset.tier_paths + if tp.storage_backend_id == b2.id + ) + tp_b.expires_at = None + await session.commit() + + extended_asset = await assets.prefetch_keep_alive( + session, + tier_path_uuid=tp_b.tier_path_uuid, + interval=datetime.timedelta(seconds=1200), + ) + self.assertIsNotNone(extended_asset) + tp_b_extended = next( + tp + for tp in extended_asset.tier_paths + if tp.storage_backend_id == b2.id + ) + self.assertIsNone(tp_b_extended.expires_at) + + async def test_prefetch_keep_alive_only_extends(self): + async with self.session_maker() as session: + asset, _, b2 = await self._set_a_finalized_asset(session) + storage_path = storage_backend_lib.get_storage_path(b2, asset.path) + result = await assets.create_prefetch_job( + session, + asset, + backend=b2, + storage_path=storage_path, + client_keep_alive_interval=datetime.timedelta(seconds=600), + ) + updated_asset = result.asset + self.assertIsNotNone(updated_asset) + tp_b = next( + tp + for tp in updated_asset.tier_paths + if tp.storage_backend_id == b2.id + ) + initial_expires_at = tp_b.expires_at + + extended_asset = await assets.prefetch_keep_alive( + session, + tier_path_uuid=tp_b.tier_path_uuid, + interval=datetime.timedelta(seconds=10), + ) + self.assertIsNotNone(extended_asset) + tp_b_extended = next( + tp + for tp in extended_asset.tier_paths + if tp.storage_backend_id == b2.id + ) + self.assertEqual(tp_b_extended.expires_at, initial_expires_at) + + async def test_prefetch_keep_alive_fails_if_deleting(self): + async with self.session_maker() as session: + asset, _, b2 = await self._set_a_finalized_asset(session) + storage_path = storage_backend_lib.get_storage_path(b2, asset.path) + result = await assets.create_prefetch_job( + session, + asset, + backend=b2, + storage_path=storage_path, + client_keep_alive_interval=datetime.timedelta(seconds=600), + ) + updated_asset = result.asset + self.assertIsNotNone(updated_asset) + tp_b = next( + tp + for tp in updated_asset.tier_paths + if tp.storage_backend_id == b2.id + ) + + db_job = db_schema.AssetJob( + asset_uuid=asset.asset_uuid, + request_type=db_schema.RequestType.REQUEST_TYPE_DELETE_FROM_ALL_TIERS, + status=db_schema.JobStatus.JOB_STATUS_QUEUED, + ) + session.add(db_job) + await session.commit() + + with self.assertRaisesRegex( + assets.DeletionPendingError, "marked for deletion" + ): + await assets.prefetch_keep_alive( + session, + tier_path_uuid=tp_b.tier_path_uuid, + interval=datetime.timedelta(seconds=1200), + ) + + async def test_prefetch_keep_alive_fails_if_instance_deleting(self): + async with self.session_maker() as session: + asset, _, b2 = await self._set_a_finalized_asset(session) + storage_path = storage_backend_lib.get_storage_path(b2, asset.path) + result = await assets.create_prefetch_job( + session, + asset, + backend=b2, + storage_path=storage_path, + client_keep_alive_interval=datetime.timedelta(seconds=600), + ) + updated_asset = result.asset + self.assertIsNotNone(updated_asset) + tp_b = next( + tp + for tp in updated_asset.tier_paths + if tp.storage_backend_id == b2.id + ) + + db_job = db_schema.AssetJob( + asset_uuid=asset.asset_uuid, + request_type=db_schema.RequestType.REQUEST_TYPE_DELETE_FROM_INSTANCE, + status=db_schema.JobStatus.JOB_STATUS_QUEUED, + target_tier_path_id=tp_b.id, + ) + session.add(db_job) + await session.commit() + + with self.assertRaisesRegex( + assets.DeletionPendingError, "marked for deletion" + ): + await assets.prefetch_keep_alive( + session, + tier_path_uuid=tp_b.tier_path_uuid, + interval=datetime.timedelta(seconds=1200), + ) + + async def test_create_prefetch_job_fails_if_deleting(self): + async with self.session_maker() as session: + asset, _, b2 = await self._set_a_finalized_asset(session) + storage_path = storage_backend_lib.get_storage_path(b2, asset.path) + + db_job = db_schema.AssetJob( + asset_uuid=asset.asset_uuid, + request_type=db_schema.RequestType.REQUEST_TYPE_DELETE_FROM_ALL_TIERS, + status=db_schema.JobStatus.JOB_STATUS_QUEUED, + ) + session.add(db_job) + await session.commit() + + with self.assertRaisesRegex( + assets.DeletionPendingError, "marked for deletion" + ): + await assets.create_prefetch_job( + session, + asset, + backend=b2, + storage_path=storage_path, + client_keep_alive_interval=datetime.timedelta(seconds=600), + ) + + async def test_is_delete_pending_true(self): + async with self.session_maker() as session: + asset, _, _ = await self._set_a_finalized_asset(session) + + # Insert a pending delete job + db_job = db_schema.AssetJob( + asset_uuid=asset.asset_uuid, + request_type=db_schema.RequestType.REQUEST_TYPE_DELETE_FROM_ALL_TIERS, + status=db_schema.JobStatus.JOB_STATUS_QUEUED, + ) + session.add(db_job) + await session.commit() + + self.assertTrue(await assets.is_delete_pending(session, asset.asset_uuid)) + + async def test_is_delete_pending_false(self): + async with self.session_maker() as session: + asset, _, _ = await self._set_a_finalized_asset(session) + # No job inserted + self.assertFalse( + await assets.is_delete_pending(session, asset.asset_uuid) + ) + + async def test_is_tier_path_delete_pending_true(self): + async with self.session_maker() as session: + asset, _, b2 = await self._set_a_finalized_asset(session) + storage_path = storage_backend_lib.get_storage_path(b2, asset.path) + result = await assets.create_prefetch_job( + session, + asset, + backend=b2, + storage_path=storage_path, + client_keep_alive_interval=datetime.timedelta(seconds=600), + ) + updated_asset = result.asset + self.assertIsNotNone(updated_asset) + assert updated_asset is not None + tp_b = next( + tp + for tp in updated_asset.tier_paths + if tp.storage_backend_id == b2.id + ) + + # Insert a pending delete from instance job + db_job = db_schema.AssetJob( + asset_uuid=asset.asset_uuid, + request_type=db_schema.RequestType.REQUEST_TYPE_DELETE_FROM_INSTANCE, + status=db_schema.JobStatus.JOB_STATUS_QUEUED, + target_tier_path_id=tp_b.id, + ) + session.add(db_job) + await session.commit() + + self.assertTrue( + await assets.is_tier_path_delete_pending( + session, asset_uuid=asset.asset_uuid, tier_path_id=tp_b.id + ) + ) + + async def test_is_tier_path_delete_pending_false(self): + async with self.session_maker() as session: + asset, _, b2 = await self._set_a_finalized_asset(session) + storage_path = storage_backend_lib.get_storage_path(b2, asset.path) + result = await assets.create_prefetch_job( + session, + asset, + backend=b2, + storage_path=storage_path, + client_keep_alive_interval=datetime.timedelta(seconds=600), + ) + updated_asset = result.asset + self.assertIsNotNone(updated_asset) + assert updated_asset is not None + tp_b = next( + tp + for tp in updated_asset.tier_paths + if tp.storage_backend_id == b2.id + ) + # No job inserted + self.assertFalse( + await assets.is_tier_path_delete_pending( + session, asset_uuid=asset.asset_uuid, tier_path_id=tp_b.id + ) + ) + if __name__ == "__main__": absltest.main() diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/auth.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/auth.py index 18a8a3f92..ac79d4198 100644 --- a/checkpoint/orbax/checkpoint/experimental/tiering_service/auth.py +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/auth.py @@ -93,8 +93,34 @@ async def verify_lustre_permissions(token: str | None, path: str) -> bool: return token is not None +async def has_read_permission( + token: str | None, + *, + backend: db_schema.StorageBackend, + path: str, +) -> bool: + """Checks whether bearer token possesses permission scopes for storage read. + + Args: + token: The OAuth token of the caller. + backend: The StorageBackend target. + path: The destination path on the backend. + + Returns: + True if read permission is granted, False otherwise. + """ + if backend.backend_type == db_schema.BackendType.BACKEND_TYPE_GCS: + return await verify_gcs_permissions(token, path, ["storage.objects.get"]) + elif backend.backend_type == db_schema.BackendType.BACKEND_TYPE_LUSTRE: + return await verify_lustre_permissions(token, path) + else: + logging.warning("Unknown backend type: %s", backend.backend_type) + return False + + async def has_write_permission( token: str | None, + *, backend: db_schema.StorageBackend, path: str, ) -> bool: diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/auth_test.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/auth_test.py index 2f8404fef..d66f93299 100644 --- a/checkpoint/orbax/checkpoint/experimental/tiering_service/auth_test.py +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/auth_test.py @@ -53,47 +53,95 @@ async def test_get_oauth_token_malformed_header(self): self.assertIsNone(token) @parameterized.named_parameters( - ( - "gcs_with_token", - db_schema.BackendType.BACKEND_TYPE_GCS, - "valid-token", - True, + dict( + testcase_name="gcs_with_token", + backend_type=db_schema.BackendType.BACKEND_TYPE_GCS, + prefix="gs://bucket", + token="valid-token", + expected=True, ), - ( - "gcs_no_token", - db_schema.BackendType.BACKEND_TYPE_GCS, - None, - False, + dict( + testcase_name="gcs_no_token", + backend_type=db_schema.BackendType.BACKEND_TYPE_GCS, + prefix="gs://bucket", + token=None, + expected=False, ), - ( - "lustre_with_token", - db_schema.BackendType.BACKEND_TYPE_LUSTRE, - "valid-token", - True, + dict( + testcase_name="lustre_with_token", + backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE, + prefix="/mnt/lustre", + token="valid-token", + expected=True, ), - ( - "lustre_no_token", - db_schema.BackendType.BACKEND_TYPE_LUSTRE, - None, - False, + dict( + testcase_name="lustre_no_token", + backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE, + prefix="/mnt/lustre", + token=None, + expected=False, ), ) - async def test_has_write_permission(self, backend_type, token, expected): - backend = db_schema.StorageBackend( - backend_type=backend_type, - prefix=( - "gs://bucket" - if backend_type == db_schema.BackendType.BACKEND_TYPE_GCS - else "/mnt/lustre" - ), + async def test_has_write_permission( + self, backend_type, prefix, token, expected + ): + backend = db_schema.StorageBackend(backend_type=backend_type, prefix=prefix) + result = await auth.has_write_permission( + token, backend=backend, path="path" ) - result = await auth.has_write_permission(token, backend, "path") self.assertEqual(result, expected) async def test_has_write_permission_unknown_backend(self): backend = mock.create_autospec(db_schema.StorageBackend, instance=True) backend.backend_type = "UNKNOWN" - result = await auth.has_write_permission("token", backend, "path") + result = await auth.has_write_permission( + "token", backend=backend, path="path" + ) + self.assertFalse(result) + + @parameterized.named_parameters( + dict( + testcase_name="gcs_with_token", + backend_type=db_schema.BackendType.BACKEND_TYPE_GCS, + prefix="gs://bucket", + token="valid-token", + expected=True, + ), + dict( + testcase_name="gcs_no_token", + backend_type=db_schema.BackendType.BACKEND_TYPE_GCS, + prefix="gs://bucket", + token=None, + expected=False, + ), + dict( + testcase_name="lustre_with_token", + backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE, + prefix="/mnt/lustre", + token="valid-token", + expected=True, + ), + dict( + testcase_name="lustre_no_token", + backend_type=db_schema.BackendType.BACKEND_TYPE_LUSTRE, + prefix="/mnt/lustre", + token=None, + expected=False, + ), + ) + async def test_has_read_permission( + self, backend_type, prefix, token, expected + ): + backend = db_schema.StorageBackend(backend_type=backend_type, prefix=prefix) + result = await auth.has_read_permission(token, backend=backend, path="path") + self.assertEqual(result, expected) + + async def test_has_read_permission_unknown_backend(self): + backend = mock.create_autospec(db_schema.StorageBackend, instance=True) + backend.backend_type = "UNKNOWN" + result = await auth.has_read_permission( + "token", backend=backend, path="path" + ) self.assertFalse(result) diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema.py index 7be47da76..9f0e75c1c 100644 --- a/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema.py +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema.py @@ -298,6 +298,7 @@ class TierPath(Base): ready_at: Timestamp when the asset became available at this tier path. expires_at: Timestamp when the asset is scheduled to expire from this tier path. + tier_path_uuid: A unique identifier for this tier path. asset: SQLAlchemy relationship to the `Asset` object. storage_backend: SQLAlchemy relationship to the `StorageBackend` object. """ @@ -324,6 +325,12 @@ class TierPath(Base): expires_at = sqlalchemy.Column( sqlalchemy.DateTime(timezone=True), nullable=True ) + tier_path_uuid = sqlalchemy.Column( + sqlalchemy.String, + unique=True, + nullable=False, + default=lambda: str(uuid.uuid4()), + ) asset = sqlalchemy.orm.relationship("Asset", back_populates="tier_paths") storage_backend = sqlalchemy.orm.relationship( diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema_test.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema_test.py index cd06151c5..9fba14663 100644 --- a/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema_test.py +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/db_schema_test.py @@ -164,6 +164,11 @@ async def test_add_tier_path(self) -> None: tp1.storage_backend.multi_regions, ["us-central1", "us-east1"], ) + self.assertIsNotNone(tp0.tier_path_uuid) + self.assertIsNotNone(tp1.tier_path_uuid) + self.assertLen(tp0.tier_path_uuid, 36) + self.assertLen(tp1.tier_path_uuid, 36) + self.assertNotEqual(tp0.tier_path_uuid, tp1.tier_path_uuid) async def test_add_tier_path_fails_multiple_locations(self) -> None: async with self.session_maker() as session: diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/proto/tiering_service.proto b/checkpoint/orbax/checkpoint/experimental/tiering_service/proto/tiering_service.proto index e8c1aaf00..eab98a32b 100644 --- a/checkpoint/orbax/checkpoint/experimental/tiering_service/proto/tiering_service.proto +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/proto/tiering_service.proto @@ -55,6 +55,9 @@ message TierPath { // call prefetch again. When this field is absent, indicating it won't expire // (e.g., in GCS, the asset path won't expire). google.protobuf.Timestamp expires_at = 5; + + // A unique identifier for this tier path. + string tier_path_uuid = 6; } message Asset { @@ -115,7 +118,7 @@ message PrefetchResponse { } message PrefetchKeepAliveRequest { - string uuid = 1; + string tier_path_uuid = 1; } message PrefetchKeepAliveResponse { diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/server.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/server.py index ced97904b..ff445494c 100644 --- a/checkpoint/orbax/checkpoint/experimental/tiering_service/server.py +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/server.py @@ -62,7 +62,7 @@ def __init__(self, config: tiering_service_pb2.ServerConfig): self._engine = db_lib.get_async_engine(self._config) self._session_maker = sessionmaker( self._engine, - # Required for async session usage + # Required for async session usage. expire_on_commit=False, class_=AsyncSession, ) @@ -107,7 +107,7 @@ async def Reserve( ) return tiering_service_pb2.ReserveResponse() - # Find the closest level 0 backend + # Find the closest level 0 backend. if self._level0_backends is None: await context.abort( grpc.StatusCode.FAILED_PRECONDITION, "Servicer not initialized" @@ -116,7 +116,7 @@ async def Reserve( backend = storage_backend.locate_closest_backend( self._level0_backends, request.zone, request.region ) - if not backend: + if backend is None: zone_val = request.zone region_val = request.region await context.abort( @@ -129,7 +129,9 @@ async def Reserve( # Calculate resolved path and check GCS permission. storage_path = storage_backend.get_storage_path(backend, request.path) token = await auth.get_oauth_token(context) - if not await auth.has_write_permission(token, backend, storage_path): + if not await auth.has_write_permission( + token, backend=backend, path=storage_path + ): logging.warning( "Permission denied for Reserve on storage path: %s", storage_path ) @@ -146,8 +148,10 @@ async def Reserve( session, request, backend, self._config ) except ValueError as e: + logging.exception("Failed to reserve asset for path: %s", request.path) + error_msg = str(e) await context.abort( - grpc.StatusCode.INTERNAL, f"Failed to reserve asset: {e}" + grpc.StatusCode.INTERNAL, f"Failed to reserve asset: {error_msg}" ) return tiering_service_pb2.ReserveResponse() @@ -172,7 +176,7 @@ async def ReserveKeepAlive( seconds=self._config.client_keep_alive_interval_seconds ), ) - if not db_asset: + if db_asset is None: logging.warning("ReserveKeepAlive: Asset not found: %s", request.uuid) await context.abort(grpc.StatusCode.NOT_FOUND, "Asset not found") return tiering_service_pb2.ReserveKeepAliveResponse() @@ -193,7 +197,7 @@ async def Finalize( async with self._session_scope() as session: db_assets = await assets.fetch_asset_by_uuid(session, request.uuid) db_asset = db_assets[0] if db_assets else None - if not db_asset: + if db_asset is None: logging.warning("Finalize: Asset not found: %s", request.uuid) await context.abort(grpc.StatusCode.NOT_FOUND, "Asset not found") return tiering_service_pb2.FinalizeResponse() @@ -212,7 +216,6 @@ async def Finalize( ) return tiering_service_pb2.FinalizeResponse() - # Verify write permission before finalizing. tier_path = next(iter(db_asset.tier_paths), None) if tier_path is None: logging.warning( @@ -227,7 +230,7 @@ async def Finalize( return tiering_service_pb2.FinalizeResponse() if not await auth.has_write_permission( - token, tier_path.storage_backend, tier_path.path + token, backend=tier_path.storage_backend, path=tier_path.path ): logging.warning( "Permission denied for Finalize on path: %s", @@ -244,14 +247,15 @@ async def Finalize( try: db_asset = await assets.finalize_asset(session, db_asset) - if not db_asset: - # This is unlikely to happen since we just finalized the asset + if db_asset is None: + # This is unlikely to happen since we just finalized the asset. raise ValueError("Asset not found after finalize") except ValueError as e: logging.exception("Finalize failed for UUID: %s", request.uuid) + error_msg = str(e) await context.abort( grpc.StatusCode.FAILED_PRECONDITION, - f"Failed to finalize asset: {e}", + f"Failed to finalize asset: {error_msg}", ) return tiering_service_pb2.FinalizeResponse() @@ -270,12 +274,142 @@ async def Prefetch( grpc.StatusCode.INVALID_ARGUMENT, "No location specified" ) return tiering_service_pb2.PrefetchResponse() - # TODO: b/503445654 - Trigger async copy to closest storage tier to user. - await context.abort( - grpc.StatusCode.UNIMPLEMENTED, "Prefetch Not Implemented" + if self._level0_backends is None: + await context.abort( + grpc.StatusCode.FAILED_PRECONDITION, "Servicer not initialized" + ) + return tiering_service_pb2.PrefetchResponse() + closest_backend = storage_backend.locate_closest_backend( + self._level0_backends, request.zone, request.region ) - return tiering_service_pb2.PrefetchResponse() + if closest_backend is None: + # No closest backend available to requestor. + zone_val = request.zone or None + region_val = request.region or None + await context.abort( + grpc.StatusCode.NOT_FOUND, + f"No level 0 storage backend found for zone:{zone_val} /" + f" region:{region_val}", + ) + return tiering_service_pb2.PrefetchResponse() + + async with self._session_scope() as session: + + db_assets = await assets.fetch_asset_by_identifier( + session, + asset_uuid=request.uuid if request.HasField("uuid") else None, + path=request.path if request.HasField("path") else None, + inclusive_filter=[ + db_schema.AssetState.ASSET_STATE_STORED, + ], + ) + db_asset = db_assets[0] if db_assets else None + if db_asset is None: + identifier = request.uuid if request.HasField("uuid") else request.path + logging.warning( + "Prefetch: Asset not found or not STORED: %s", identifier + ) + await context.abort(grpc.StatusCode.NOT_FOUND, "Asset not found") + return tiering_service_pb2.PrefetchResponse() + + for tp in db_asset.tier_paths: + if tp.storage_backend_id == closest_backend.id: + # TODO: b/503445463 - Extend the expiration of the existing TierPath + # if needed. + logging.info( + "Prefetch: Asset %s already has a TierPath on backend %s", + db_asset.asset_uuid, + closest_backend.id, + ) + return tiering_service_pb2.PrefetchResponse( + asset=assets.proto_from_db_asset(db_asset), + keep_alive_interval_seconds=( + self._config.client_keep_alive_interval_seconds + ), + ) + + # No existing TierPath, we need to prefetch + storage_path = storage_backend.get_storage_path( + closest_backend, db_asset.path + ) + + token = await auth.get_oauth_token(context) + + # Check read permissions on all existing tier paths. + if db_asset.tier_paths: + first_tp = db_asset.tier_paths[0] + if not await auth.has_read_permission( + token, backend=first_tp.storage_backend, path=first_tp.path + ): + logging.warning( + "Permission denied for Prefetch on source path: %s", first_tp.path + ) + backend_name = storage_backend.get_backend_name( + first_tp.storage_backend + ) + await context.abort( + grpc.StatusCode.PERMISSION_DENIED, + f"Insufficient read permissions on source {backend_name}", + ) + return tiering_service_pb2.PrefetchResponse() + + # Check read permission on the target level 0 backend. + if not await auth.has_read_permission( + token, backend=closest_backend, path=storage_path + ): + logging.warning( + "Permission denied for Prefetch on storage path: %s", storage_path + ) + backend_name = storage_backend.get_backend_name(closest_backend) + await context.abort( + grpc.StatusCode.PERMISSION_DENIED, + f"Insufficient read permissions on target {backend_name}", + ) + return tiering_service_pb2.PrefetchResponse() + + try: + result = await assets.create_prefetch_job( + session, + db_asset, + backend=closest_backend, + storage_path=storage_path, + client_keep_alive_interval=datetime.timedelta( + seconds=self._config.client_keep_alive_interval_seconds + ), + ) + db_asset = result.asset + except assets.DeletionPendingError: + identifier = request.uuid if request.HasField("uuid") else request.path + error_msg = f"Prefetch: Deletion is pending for asset {identifier}" + logging.exception(error_msg) + await context.abort(grpc.StatusCode.FAILED_PRECONDITION, error_msg) + return tiering_service_pb2.PrefetchResponse() + except ValueError: + logging.exception( + "Failed to create prefetch job for identifier: %s", + request.uuid if request.HasField("uuid") else request.path, + ) + await context.abort( + grpc.StatusCode.INTERNAL, + "Failed to create prefetch job", + ) + return tiering_service_pb2.PrefetchResponse() + + if db_asset is None: + identifier = request.uuid if request.HasField("uuid") else request.path + logging.warning( + "Prefetch: Asset not found after create_prefetch_job for" + " identifier: %s", + identifier, + ) + await context.abort(grpc.StatusCode.NOT_FOUND, "Asset not found") + return tiering_service_pb2.PrefetchResponse() + + return tiering_service_pb2.PrefetchResponse( + asset=assets.proto_from_db_asset(db_asset), + keep_alive_interval_seconds=self._config.client_keep_alive_interval_seconds, + ) async def PrefetchKeepAlive( self, @@ -283,12 +417,50 @@ async def PrefetchKeepAlive( context: grpc.aio.ServicerContext, ) -> tiering_service_pb2.PrefetchKeepAliveResponse: """Signals that the client is still reading/waiting for promotion.""" - logging.info("PrefetchKeepAlive requested for UUID: %s", request.uuid) - - await context.abort( - grpc.StatusCode.UNIMPLEMENTED, "PrefetchKeepAlive Not Implemented" + logging.info( + "PrefetchKeepAlive requested for tier_path_uuid: %s", + request.tier_path_uuid, ) - return tiering_service_pb2.PrefetchKeepAliveResponse() + + async with self._session_scope() as session: + try: + db_asset = await assets.prefetch_keep_alive( + session, + tier_path_uuid=request.tier_path_uuid, + interval=datetime.timedelta( + seconds=self._config.client_keep_alive_interval_seconds + ), + ) + except assets.DeletionPendingError: + error_msg = ( + "PrefetchKeepAlive: Deletion is pending for TierPath " + f"{request.tier_path_uuid}" + ) + logging.warning(error_msg) + await context.abort(grpc.StatusCode.FAILED_PRECONDITION, error_msg) + return tiering_service_pb2.PrefetchKeepAliveResponse() + except Exception: # pylint: disable=broad-except + logging.exception( + "Failed to keep alive prefetch for TierPath: %s", + request.tier_path_uuid, + ) + await context.abort( + grpc.StatusCode.INTERNAL, + "Failed to keep alive prefetch", + ) + return tiering_service_pb2.PrefetchKeepAliveResponse() + if db_asset is None: + logging.warning( + "PrefetchKeepAlive: TierPath not found: %s", + request.tier_path_uuid, + ) + await context.abort(grpc.StatusCode.NOT_FOUND, "TierPath not found") + return tiering_service_pb2.PrefetchKeepAliveResponse() + + return tiering_service_pb2.PrefetchKeepAliveResponse( + asset=assets.proto_from_db_asset(db_asset), + keep_alive_interval_seconds=self._config.client_keep_alive_interval_seconds, + ) async def Delete( self, diff --git a/checkpoint/orbax/checkpoint/experimental/tiering_service/server_test.py b/checkpoint/orbax/checkpoint/experimental/tiering_service/server_test.py index 605f9717a..a40ad4d84 100644 --- a/checkpoint/orbax/checkpoint/experimental/tiering_service/server_test.py +++ b/checkpoint/orbax/checkpoint/experimental/tiering_service/server_test.py @@ -20,10 +20,13 @@ import aiosqlite # pylint: disable=unused-import import greenlet # pylint: disable=unused-import import grpc +from orbax.checkpoint.experimental.tiering_service import auth from orbax.checkpoint.experimental.tiering_service import db_lib +from orbax.checkpoint.experimental.tiering_service import db_schema from orbax.checkpoint.experimental.tiering_service import server from orbax.checkpoint.experimental.tiering_service import server_config from orbax.checkpoint.experimental.tiering_service.proto import tiering_service_pb2 +from sqlalchemy.future import select from google.protobuf import timestamp_pb2 @@ -82,6 +85,49 @@ def _setup_config(self, config_dict): config.db_connection_str = f"sqlite+aiosqlite:///{tmp_file.full_path}" return config + def _get_multi_lustre_config(self, zone_prefixes): + """Sets up config with multiple Lustre backends and a default GCS backend.""" + storage_backends_config = [] + for zone, suffix in zone_prefixes: + storage_backends_config.append({ + "level": 0, + "backend_type": "BACKEND_TYPE_LUSTRE", + "prefix": f"/mnt/lustre-{suffix}", + "zone": zone, + }) + # Add the default GCS backend. + storage_backends_config.append({ + "level": 1, + "backend_type": "BACKEND_TYPE_GCS", + "prefix": "gs://my-bucket", + "region": "us-central1", + }) + return self._setup_config({"storage_backends": storage_backends_config}) + + async def _setup_servicer_and_asset(self): + """Sets up a servicer with 2 Lustre backends and reserves/finalizes an asset.""" + config = self._get_multi_lustre_config([ + ("us-central1-a", "a"), + ("us-central1-b", "b"), + ]) + servicer = server.TieringServiceServicer(config) + await server.setup_storage_backends(config) + await servicer.initialize() + self.addAsyncCleanup(servicer.close) + + # Reserve and Finalize on A. + reserve_res = await servicer.Reserve( + tiering_service_pb2.ReserveRequest( + path="test/path", user="test-user", zone="us-central1-a" + ), + self.context, + ) + asset_uuid = reserve_res.asset.uuid + await servicer.Finalize( + tiering_service_pb2.FinalizeRequest(uuid=asset_uuid), self.context + ) + return servicer, asset_uuid + async def test_reserve_success(self): request = tiering_service_pb2.ReserveRequest( path="test/path", @@ -148,7 +194,7 @@ async def test_finalize_success(self): async def test_finalize_permission_denied(self): asset_uuid = await self._reserve_asset() - # Remove token from context to simulate missing auth + # Remove token from context to simulate missing auth. self.context.invocation_metadata.return_value = () finalize_req = tiering_service_pb2.FinalizeRequest(uuid=asset_uuid) @@ -164,7 +210,7 @@ async def test_finalize_already_finalized_raises_failed_precondition(self): tiering_service_pb2.FinalizeRequest(uuid=asset_uuid), self.context ) - # Try to finalize again + # Try to finalize again. finalize_req = tiering_service_pb2.FinalizeRequest(uuid=asset_uuid) await self.servicer.Finalize(finalize_req, self.context) @@ -174,15 +220,6 @@ async def test_finalize_already_finalized_raises_failed_precondition(self): " ASSET_STATE_ACTIVE_WRITE", ) - async def test_delete_unimplemented(self): - asset_uuid = await self._reserve_asset() - await self.servicer.Delete( - tiering_service_pb2.DeleteRequest(uuid=asset_uuid), self.context - ) - self.context.abort.assert_called_once_with( - grpc.StatusCode.UNIMPLEMENTED, "Delete Not Implemented" - ) - async def test_info_success(self): asset_uuid = await self._reserve_asset() response = await self.servicer.Info( @@ -190,11 +227,161 @@ async def test_info_success(self): ) self.assertEqual(response.assets[0].uuid, asset_uuid) - async def test_prefetch_unimplemented(self): + async def test_prefetch_success_rpc_response(self): + servicer, asset_uuid = await self._setup_servicer_and_asset() + + prefetch_req = tiering_service_pb2.PrefetchRequest( + uuid=asset_uuid, zone="us-central1-b" + ) + prefetch_res = await servicer.Prefetch(prefetch_req, self.context) + + self.assertEqual(prefetch_res.asset.uuid, asset_uuid) + paths = [tp.path for tp in prefetch_res.asset.tier_paths] + self.assertCountEqual( + paths, + ["/mnt/lustre-a/test/path", "/mnt/lustre-b/test/path"], + ) + + async def test_prefetch_success_db_job_creation(self): + servicer, asset_uuid = await self._setup_servicer_and_asset() + + prefetch_req = tiering_service_pb2.PrefetchRequest( + uuid=asset_uuid, zone="us-central1-b" + ) + await servicer.Prefetch(prefetch_req, self.context) + + async with servicer._session_scope() as session: + stmt = select(db_schema.AssetJob).filter_by( + asset_uuid=asset_uuid, + request_type=db_schema.RequestType.REQUEST_TYPE_COPY, + ) + result = await session.execute(stmt) + jobs = result.scalars().all() + self.assertLen(jobs, 1) + self.assertEqual(jobs[0].status, db_schema.JobStatus.JOB_STATUS_QUEUED) + target_tp_id = jobs[0].target_tier_path_id + + stmt_tp = select(db_schema.TierPath).filter_by( + asset_uuid=asset_uuid, path="/mnt/lustre-b/test/path" + ) + result_tp = await session.execute(stmt_tp) + tp_b = result_tp.scalars().first() + self.assertIsNotNone(tp_b) + self.assertEqual(target_tp_id, tp_b.id) + self.assertIsNone(tp_b.ready_at) + + async def test_prefetch_idempotent(self): + servicer, asset_uuid = await self._setup_servicer_and_asset() + + # 1. Prefetch from B (first time). + await servicer.Prefetch( + tiering_service_pb2.PrefetchRequest( + uuid=asset_uuid, zone="us-central1-b" + ), + self.context, + ) + + # 2. Prefetch from B (second time). + await servicer.Prefetch( + tiering_service_pb2.PrefetchRequest( + uuid=asset_uuid, zone="us-central1-b" + ), + self.context, + ) + + # Verify only ONE job was created. + async with servicer._session_scope() as session: + stmt = select(db_schema.AssetJob).filter_by( + asset_uuid=asset_uuid, + request_type=db_schema.RequestType.REQUEST_TYPE_COPY, + ) + result = await session.execute(stmt) + jobs = result.scalars().all() + self.assertLen(jobs, 1) + + async def test_prefetch_already_ready(self): + # If we prefetch to the same zone where it was reserved and finalized, + # it should be already ready, so no job should be created. asset_uuid = await self._reserve_asset() - # Finalize the asset to STORED state so it is eligible for prefetching - finalize_req = tiering_service_pb2.FinalizeRequest(uuid=asset_uuid) - await self.servicer.Finalize(finalize_req, self.context) + await self.servicer.Finalize( + tiering_service_pb2.FinalizeRequest(uuid=asset_uuid), self.context + ) + + # Prefetch to the same zone "us-central1-a" + prefetch_req = tiering_service_pb2.PrefetchRequest( + uuid=asset_uuid, zone="us-central1-a" + ) + response = await self.servicer.Prefetch(prefetch_req, self.context) + + # Verify response + self.assertEqual(response.asset.uuid, asset_uuid) + self.assertLen(response.asset.tier_paths, 1) + self.assertIsNotNone( + response.asset.tier_paths[0].ready_at.ToDatetime() + ) + + # Verify NO job was created + async with self.servicer._session_scope() as session: + stmt = select(db_schema.AssetJob).filter_by( + asset_uuid=asset_uuid, + request_type=db_schema.RequestType.REQUEST_TYPE_COPY, + ) + result = await session.execute(stmt) + jobs = result.scalars().all() + self.assertEmpty(jobs) + + async def test_prefetch_permission_denied(self): + servicer, asset_uuid = await self._setup_servicer_and_asset() + + # Remove token from context to simulate missing auth. + self.context.invocation_metadata.return_value = () + + # Prefetch from B (should fail). + prefetch_req = tiering_service_pb2.PrefetchRequest( + uuid=asset_uuid, zone="us-central1-b" + ) + await servicer.Prefetch(prefetch_req, self.context) + + self.context.abort.assert_called_once_with( + grpc.StatusCode.PERMISSION_DENIED, + "Insufficient read permissions on source Lustre", + ) + + async def test_prefetch_permission_denied_on_target(self): + servicer, asset_uuid = await self._setup_servicer_and_asset() + + # Mock has_read_permission to succeed for source (lustre-a) but fail for + # target (lustre-b). + async def mock_has_read_permission(unused_token, *, backend, path): + del backend # Unused. + if "lustre-b" in path: + return False + return True + + with mock.patch.object( + auth, + "has_read_permission", + autospec=True, + side_effect=mock_has_read_permission, + ): + prefetch_req = tiering_service_pb2.PrefetchRequest( + uuid=asset_uuid, zone="us-central1-b" + ) + await servicer.Prefetch(prefetch_req, self.context) + + self.context.abort.assert_called_once_with( + grpc.StatusCode.PERMISSION_DENIED, + "Insufficient read permissions on target Lustre", + ) + + @parameterized.named_parameters( + dict(testcase_name="asset_not_finalized", reserve_asset=True), + dict(testcase_name="asset_does_not_exist", reserve_asset=False), + ) + async def test_prefetch_not_found(self, reserve_asset): + asset_uuid = "invalid-uuid" + if reserve_asset: + asset_uuid = await self._reserve_asset() prefetch_req = tiering_service_pb2.PrefetchRequest( uuid=asset_uuid, zone="us-central1-a" @@ -202,17 +389,131 @@ async def test_prefetch_unimplemented(self): await self.servicer.Prefetch(prefetch_req, self.context) self.context.abort.assert_called_once_with( - grpc.StatusCode.UNIMPLEMENTED, "Prefetch Not Implemented" + grpc.StatusCode.NOT_FOUND, "Asset not found" ) - async def test_prefetch_keep_alive_unimplemented(self): + async def test_prefetch_backend_not_found(self): asset_uuid = await self._reserve_asset() - req = tiering_service_pb2.PrefetchKeepAliveRequest(uuid=asset_uuid) - await self.servicer.PrefetchKeepAlive(req, self.context) + prefetch_req = tiering_service_pb2.PrefetchRequest( + uuid=asset_uuid, zone="us-central1-b" + ) + await self.servicer.Prefetch(prefetch_req, self.context) self.context.abort.assert_called_once_with( - grpc.StatusCode.UNIMPLEMENTED, "PrefetchKeepAlive Not Implemented" + grpc.StatusCode.NOT_FOUND, + "No level 0 storage backend found for zone:us-central1-b / region:None", + ) + + async def test_prefetch_keep_alive_grpc_success(self): + servicer, asset_uuid = await self._setup_servicer_and_asset() + + # 1. Prefetch on B (triggers job & creates TierPath on B) + prefetch_res = await servicer.Prefetch( + tiering_service_pb2.PrefetchRequest( + uuid=asset_uuid, zone="us-central1-b" + ), + self.context, + ) + self.assertLen(prefetch_res.asset.tier_paths, 2) + tp_b = next( + tp for tp in prefetch_res.asset.tier_paths if "lustre-b" in tp.path + ) + self.assertTrue(tp_b.HasField("expires_at")) + initial_expires_at = tp_b.expires_at.ToDatetime() + + # 2. Call PrefetchKeepAlive + keep_alive_req = tiering_service_pb2.PrefetchKeepAliveRequest( + tier_path_uuid=tp_b.tier_path_uuid + ) + keep_alive_res = await servicer.PrefetchKeepAlive( + keep_alive_req, self.context + ) + + # Verify TTL is extended + tp_b_extended = next( + tp for tp in keep_alive_res.asset.tier_paths if "lustre-b" in tp.path + ) + self.assertGreater( + tp_b_extended.expires_at.ToDatetime(), initial_expires_at + ) + + async def test_prefetch_keep_alive_not_found_fails(self): + req = tiering_service_pb2.PrefetchKeepAliveRequest( + tier_path_uuid="non-existent-uuid" + ) + await self.servicer.PrefetchKeepAlive(req, self.context) + self.context.abort.assert_called_once_with( + grpc.StatusCode.NOT_FOUND, "TierPath not found" + ) + + async def test_prefetch_keep_alive_multi_zone_isolation(self): + config = self._get_multi_lustre_config([ + ("us-central1-a", "a"), + ("us-central1-b", "b"), + ("us-central1-c", "c"), + ]) + servicer = server.TieringServiceServicer(config) + await server.setup_storage_backends(config) + await servicer.initialize() + self.addAsyncCleanup(servicer.close) + + # 1. Reserve and Finalize on C + reserve_res = await servicer.Reserve( + tiering_service_pb2.ReserveRequest( + path="test/path", user="test-user", zone="us-central1-c" + ), + self.context, + ) + asset_uuid = reserve_res.asset.uuid + await servicer.Finalize( + tiering_service_pb2.FinalizeRequest(uuid=asset_uuid), self.context + ) + + # 2. Prefetch to A (Zone A) + prefetch_res_a = await servicer.Prefetch( + tiering_service_pb2.PrefetchRequest( + uuid=asset_uuid, zone="us-central1-a" + ), + self.context, + ) + tp_a = next( + tp for tp in prefetch_res_a.asset.tier_paths if "lustre-a" in tp.path + ) + expires_at_a_initial = tp_a.expires_at.ToDatetime() + + # 3. Prefetch to B (Zone B) + prefetch_res_b = await servicer.Prefetch( + tiering_service_pb2.PrefetchRequest( + uuid=asset_uuid, zone="us-central1-b" + ), + self.context, + ) + tp_b = next( + tp for tp in prefetch_res_b.asset.tier_paths if "lustre-b" in tp.path + ) + expires_at_b_initial = tp_b.expires_at.ToDatetime() + + # 4. Extend Zone A's TTL + keep_alive_req = tiering_service_pb2.PrefetchKeepAliveRequest( + tier_path_uuid=tp_a.tier_path_uuid + ) + keep_alive_res = await servicer.PrefetchKeepAlive( + keep_alive_req, self.context + ) + + # Verify Zone A's TTL is extended + tp_a_extended = next( + tp for tp in keep_alive_res.asset.tier_paths if "lustre-a" in tp.path + ) + self.assertGreater( + tp_a_extended.expires_at.ToDatetime(), expires_at_a_initial + ) + + # Verify Zone B's TTL remains strictly unchanged (isolation) + tp_b_post = next( + tp for tp in keep_alive_res.asset.tier_paths if "lustre-b" in tp.path ) + self.assertEqual(tp_b_post.expires_at.ToDatetime(), expires_at_b_initial) async def test_prefetch_invalid_argument(self): asset_uuid = await self._reserve_asset() @@ -237,7 +538,7 @@ def test_tier_path_presence(self): self.assertTrue(tp.HasField("expires_at")) async def test_reserve_permission_denied(self): - # Remove token from context to simulate missing auth + # Remove token from context to simulate missing auth. self.context.invocation_metadata.return_value = () request = tiering_service_pb2.ReserveRequest(