Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 219 additions & 0 deletions checkpoint/orbax/checkpoint/experimental/tiering_service/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"""

from collections.abc import Collection, Sequence
import dataclasses
import datetime

from absl import logging
Expand All @@ -35,6 +36,24 @@
from google.protobuf import timestamp_pb2


class DeletionPendingError(ValueError):
"""Raised when an operation is attempted on an asset/TierPath marked for deletion."""


@dataclasses.dataclass
class CreatePrefetchJobResult:
"""Result of creating a prefetch job.

Attributes:
asset: The updated asset, or None if not found.
created: True if a new job was successfully created, False if it failed due
to a concurrent insert or there is already an existing job queued.
"""

asset: db_schema.Asset | None
created: bool


def _proto_from_db_tier_path(
tier_path: db_schema.TierPath,
) -> tiering_service_pb2.TierPath:
Expand Down Expand Up @@ -92,6 +111,7 @@ def _get_location_kwargs(sb: db_schema.StorageBackend):
storage_backend=proto_storage_backend,
ready_at=ready_at_pb,
expires_at=expires_at_pb,
tier_path_uuid=tier_path.tier_path_uuid,
)


Expand Down Expand Up @@ -381,3 +401,202 @@ async def finalize_asset(
await session.commit()
await session.refresh(db_asset, attribute_names=["updated_at"])
return db_asset


async def is_delete_pending(session: AsyncSession, asset_uuid: str) -> bool:
"""Checks if there is a pending delete job for the asset."""
stmt = (
select(db_schema.AssetJob)
.filter_by(
asset_uuid=asset_uuid,
request_type=db_schema.RequestType.REQUEST_TYPE_DELETE_FROM_ALL_TIERS,
)
.where(
db_schema.AssetJob.status.in_([
db_schema.JobStatus.JOB_STATUS_QUEUED,
db_schema.JobStatus.JOB_STATUS_PROCESSING,
])
)
)
result = await session.execute(stmt)
return bool(result.scalars().first())


async def is_tier_path_delete_pending(
session: AsyncSession,
*,
asset_uuid: str,
tier_path_id: int,
) -> bool:
"""Checks if there is a pending delete job for the specific TierPath."""
stmt = (
select(db_schema.AssetJob)
.filter_by(
asset_uuid=asset_uuid,
target_tier_path_id=tier_path_id,
request_type=db_schema.RequestType.REQUEST_TYPE_DELETE_FROM_INSTANCE,
)
.where(
db_schema.AssetJob.status.in_([
db_schema.JobStatus.JOB_STATUS_QUEUED,
db_schema.JobStatus.JOB_STATUS_PROCESSING,
])
)
)
result = await session.execute(stmt)
return bool(result.scalars().first())


async def create_prefetch_job(
session: AsyncSession,
db_asset: db_schema.Asset,
*,
backend: db_schema.StorageBackend,
storage_path: str,
client_keep_alive_interval: datetime.timedelta,
) -> CreatePrefetchJobResult:
"""Queues a prefetch job for the given asset to the target backend.

This function executes atomically in a single transaction. It creates both
the `TierPath` and the corresponding `AssetJob` together, ensuring that we
never commit a "dangling" `TierPath` without an associated prefetch job. If
either operation fails (e.g. due to concurrent insertion conflicts), the
entire transaction is rolled back.

Args:
session: The database session (active transaction).
db_asset: The asset to prefetch.
backend: The target storage backend (level 0).
storage_path: The storage path to use for the new TierPath.
client_keep_alive_interval: The interval to set for the initial expires_at
of the TierPath.

Returns:
A CreatePrefetchJobResult containing the updated asset and a boolean
indicating whether a new job was created.

Raises:
DeletionPendingError: If the asset is already marked for deletion.
"""

# Check if there is already a preceding delete job
if await is_delete_pending(session, db_asset.asset_uuid):
raise DeletionPendingError(
f"Cannot prefetch asset {db_asset.asset_uuid} because it is marked for"
" deletion."
)

logging.info(
"Prefetch: Creating new pending TierPath and job for asset %s and"
" backend %s",
db_asset.asset_uuid,
backend.id,
)
new_tp = db_schema.TierPath(
storage_backend=backend,
path=storage_path,
expires_at=calculate_expires_at(client_keep_alive_interval),
)
db_asset.tier_paths.append(new_tp)

db_job = db_schema.AssetJob(
asset_uuid=db_asset.asset_uuid,
request_type=db_schema.RequestType.REQUEST_TYPE_COPY,
status=db_schema.JobStatus.JOB_STATUS_QUEUED,
target_tier_path=new_tp,
)
session.add(db_job)

asset_uuid = db_asset.asset_uuid
backend_id = backend.id
try:
await session.commit()
except IntegrityError:
await session.rollback()
logging.debug(
"Prefetch: Concurrent insert detected for asset %s and backend %s,"
" rolling back",
asset_uuid,
backend_id,
)
db_assets = await fetch_asset_by_uuid(session, asset_uuid)
return CreatePrefetchJobResult(
asset=(db_assets[0] if db_assets else None), created=False
)

await session.refresh(db_asset, attribute_names=["updated_at"])
return CreatePrefetchJobResult(asset=db_asset, created=True)


async def prefetch_keep_alive(
session: AsyncSession,
*,
tier_path_uuid: str,
interval: datetime.timedelta,
) -> db_schema.Asset | None:
"""Extend the TierPath's expiration timestamp.

Args:
session: The database session.
tier_path_uuid: The UUID of the TierPath to update.
interval: The new timeout interval.

Returns:
The updated Asset object, or None if the TierPath was not found.

Raises:
DeletionPendingError: If the asset associated with the TierPath is marked
for deletion, or if the specific TierPath instance is marked for
deletion.
"""
stmt = select(db_schema.TierPath).filter_by(tier_path_uuid=tier_path_uuid)
result = await session.execute(stmt)
tp = result.scalars().first()
if tp is None:
return None

if await is_delete_pending(session, tp.asset_uuid):
raise DeletionPendingError(f"Asset {tp.asset_uuid} is marked for deletion.")

if await is_tier_path_delete_pending(
session, asset_uuid=tp.asset_uuid, tier_path_id=tp.id
):
raise DeletionPendingError(
f"TierPath {tier_path_uuid} is marked for deletion."
)

if tp.expires_at is None:
logging.debug(
"TierPath %s has no expires_at (still copying or permanent), no-op",
tier_path_uuid,
)
db_assets = await fetch_asset_by_uuid(session, tp.asset_uuid)
return db_assets[0] if db_assets else None

new_expires_at = calculate_expires_at(interval)
existing_expires_at = tp.expires_at
compared_expires_at = (
existing_expires_at.replace(tzinfo=datetime.timezone.utc)
if existing_expires_at.tzinfo is None
else existing_expires_at
)
if new_expires_at > compared_expires_at:
logging.debug(
"Extending TierPath %s expires_at from %s to %s",
tier_path_uuid,
existing_expires_at,
new_expires_at,
)
tp.expires_at = new_expires_at
await session.commit()
else:
logging.debug(
"New expires_at %s is not longer than existing %s for TierPath %s,"
" no-op",
new_expires_at,
existing_expires_at,
tier_path_uuid,
)

db_assets = await fetch_asset_by_uuid(session, tp.asset_uuid)
return db_assets[0] if db_assets else None
Loading
Loading