From 8221f398a7f0ce72131b6151d160fa4878faa401 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Sat, 2 May 2026 03:10:58 +0900 Subject: [PATCH 1/4] feat(BA-5936): add Pruner repository abstraction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces a declarative spec for bulk-delete (prune) operations parallel to the existing Creator/Updater/Purger patterns: - PrunerSpec[TRow] abstract base with row_class/returning_id/prune_condition classmethods plus per-call conditions/cascade/limit/cascade_rbac fields. - CascadeChild abstraction for FK-dependent child tables (e.g., kernels of pruned sessions) deleted within the same transaction. - entity_type() classmethod opts the spec into RBAC association cleanup via association_scopes_entities (cross-cutting, built into the executor). - execute_pruner runs SELECT pk FOR UPDATE LIMIT once, materializes IDs, then issues cascade DELETEs and the parent DELETE...RETURNING from the same locked set — race-safe and avoids re-evaluating the parent SELECT per cascade. - DEFAULT_PRUNE_LIMIT (100k) caps lock set, memory, and transaction duration; operators run multiple calls to drain larger backlogs. Includes integration tests covering basic prune, runtime conditions, limit, FK cascade, RBAC cleanup with entity_type filtering, and the default no-op (entity_type=None) path. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../manager/repositories/base/__init__.py | 11 + .../manager/repositories/base/pruner.py | 215 +++++++++ .../manager/repositories/base/test_pruner.py | 445 ++++++++++++++++++ 3 files changed, 671 insertions(+) create mode 100644 src/ai/backend/manager/repositories/base/pruner.py create mode 100644 tests/unit/manager/repositories/base/test_pruner.py diff --git a/src/ai/backend/manager/repositories/base/__init__.py b/src/ai/backend/manager/repositories/base/__init__.py index 118e6626700..02292d5aa98 100644 --- a/src/ai/backend/manager/repositories/base/__init__.py +++ b/src/ai/backend/manager/repositories/base/__init__.py @@ -38,6 +38,12 @@ PageInfoResult, QueryPagination, ) +from .pruner import ( + CascadeChild, + PrunerResult, + PrunerSpec, + execute_pruner, +) from .purger import ( BatchPurger, BatchPurgerResult, @@ -173,6 +179,11 @@ "BatchPurger", "BatchPurgerResult", "execute_batch_purger", + # Pruner + "CascadeChild", + "PrunerSpec", + "PrunerResult", + "execute_pruner", # Utils "combine_conditions_or", "negate_conditions", diff --git a/src/ai/backend/manager/repositories/base/pruner.py b/src/ai/backend/manager/repositories/base/pruner.py new file mode 100644 index 00000000000..36838252c8f --- /dev/null +++ b/src/ai/backend/manager/repositories/base/pruner.py @@ -0,0 +1,215 @@ +"""Pruner spec and cascade abstractions for bulk delete operations.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, TypeVar + +import sqlalchemy as sa + +from ai.backend.common.data.permission.types import EntityType +from ai.backend.manager.models.base import Base +from ai.backend.manager.models.rbac_models.association_scopes_entities import ( + AssociationScopesEntitiesRow, +) + +from .integrity import parse_integrity_error +from .types import QueryCondition + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession as SASession + +TRow = TypeVar("TRow", bound=Base) + + +class CascadeChild(ABC): + """A child table whose rows must be deleted before the parent's prune. + + Used for simple FK cascades. Each cascade DELETE runs as:: + + DELETE FROM WHERE + IN (SELECT FROM + WHERE ) + + Polymorphic / cross-cutting cleanups (e.g., RBAC associations) are not + handled here — see :meth:`PrunerSpec.entity_type` for that. + """ + + @classmethod + @abstractmethod + def row_class(cls) -> type[Base]: + """ORM Row class for the cascade table. + + Example: + return KernelRow + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def parent_id_column(cls) -> Any: + """FK column on the cascade table that references the parent's PK. + + Example: + return KernelRow.session_id + """ + raise NotImplementedError + + +DEFAULT_PRUNE_LIMIT = 100_000 +"""Default cap on rows pruned per call. Bounds memory (parent ID list), +held row locks, and transaction duration.""" + + +@dataclass +class PrunerSpec[TRow: Base](ABC): + """Spec for a prune operation: entity contract + runtime params + cascade. + + Subclasses declare the entity-level prune contract via classmethods. + Per-call parameters live on the instance. + + Attributes: + conditions: Additional WHERE clauses combined (AND) with + ``prune_condition()``. Use to inject runtime bounds. + cascade: FK-dependent child tables to delete first within the same + transaction (see :class:`CascadeChild`). + limit: Hard cap on rows pruned per call (default + :data:`DEFAULT_PRUNE_LIMIT`). Required to bound the SELECT FOR + UPDATE lock set, the in-memory ID list, and transaction + duration. Operators run multiple calls to drain larger backlogs. + cascade_rbac: When True (default) and :meth:`entity_type` returns a + non-None ``EntityType``, ``execute_pruner`` also deletes + ``association_scopes_entities`` rows whose + ``(entity_type, entity_id)`` references the pruned parent rows. + """ + + conditions: list[QueryCondition] = field(default_factory=list) + cascade: list[CascadeChild] = field(default_factory=list) + limit: int = DEFAULT_PRUNE_LIMIT + cascade_rbac: bool = True + + @classmethod + @abstractmethod + def row_class(cls) -> type[TRow]: + """ORM Row class for the parent entity table. + + Example: + return SessionRow + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def returning_id(cls) -> Any: + """Primary-key column for the parent's ``DELETE ... RETURNING``. + + Example: + return SessionRow.id + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def prune_condition(cls) -> sa.ColumnElement[bool]: + """Hardcoded terminal-state WHERE clause for the parent entity. + + Example: + return SessionRow.status.in_(TERMINAL_SESSION_STATUSES) + """ + raise NotImplementedError + + @classmethod + def entity_type(cls) -> EntityType | None: + """RBAC ``EntityType`` for this entity, or ``None`` to skip RBAC cleanup. + + When non-None and ``cascade_rbac`` is True, ``execute_pruner`` + deletes matching rows in ``association_scopes_entities`` within the + same transaction. Default: ``None`` (no RBAC cleanup). + + Example: + return EntityType.SESSION + """ + return None + + +@dataclass +class PrunerResult: + """Result of executing a prune operation. + + Attributes: + count: Number of parent rows deleted. + ids: PK values of the deleted parent rows. + """ + + count: int + ids: list[Any] = field(default_factory=list) + + +async def execute_pruner[TRow: Base]( + db_sess: SASession, + spec: PrunerSpec[TRow], +) -> PrunerResult: + """Execute the prune as a single SELECT FOR UPDATE followed by DELETEs. + + Order within the transaction: + + 1. ``SELECT pk FOR UPDATE LIMIT spec.limit`` to lock the target parent + rows and materialize their IDs once. + 2. FK cascade children (``spec.cascade``) — each DELETE uses + ``parent_id_column.in_(target_ids)``. + 3. RBAC associations (when ``spec.cascade_rbac`` is True and + ``spec.entity_type()`` is not None) — IDs are stringified for the + polymorphic ``entity_id`` text column. + 4. Parent DELETE with ``RETURNING`` to surface the pruned PK list. + + Materializing the locked ID list avoids re-evaluating the parent SELECT + in every cascade subquery and removes the race window between + statements. + + Args: + db_sess: Database session (must be writable). + spec: PrunerSpec instance carrying conditions, cascade, and limit. + + Returns: + PrunerResult with the count and PK list of deleted parent rows. + + Raises: + RepositoryIntegrityError: If any DELETE violates a database constraint. + """ + cls = type(spec) + table = cls.row_class().__table__ + pk_col = cls.returning_id() + + where = cls.prune_condition() + for f in spec.conditions: + where = sa.and_(where, f()) + + target_q = sa.select(pk_col).where(where).with_for_update().limit(spec.limit) + target_ids = list((await db_sess.scalars(target_q)).all()) + if not target_ids: + return PrunerResult(count=0, ids=[]) + + for child in spec.cascade: + ccls = type(child) + cascade_table = ccls.row_class().__table__ + await db_sess.execute( + sa.delete(cascade_table).where(ccls.parent_id_column().in_(target_ids)) + ) + + rbac_entity_type = cls.entity_type() + if spec.cascade_rbac and rbac_entity_type is not None: + await db_sess.execute( + sa.delete(AssociationScopesEntitiesRow).where( + AssociationScopesEntitiesRow.entity_type == rbac_entity_type, + AssociationScopesEntitiesRow.entity_id.in_([str(i) for i in target_ids]), + ) + ) + + stmt = sa.delete(table).where(pk_col.in_(target_ids)).returning(pk_col) + try: + deleted = list((await db_sess.scalars(stmt)).all()) + except sa.exc.IntegrityError as e: + raise parse_integrity_error(e) from e + + return PrunerResult(count=len(deleted), ids=deleted) diff --git a/tests/unit/manager/repositories/base/test_pruner.py b/tests/unit/manager/repositories/base/test_pruner.py new file mode 100644 index 00000000000..8f7e7cad0ef --- /dev/null +++ b/tests/unit/manager/repositories/base/test_pruner.py @@ -0,0 +1,445 @@ +"""Integration tests for execute_pruner with real database.""" + +from __future__ import annotations + +import uuid +from collections.abc import AsyncGenerator +from dataclasses import dataclass +from datetime import UTC, datetime, timedelta +from typing import TYPE_CHECKING, Any +from uuid import UUID + +import pytest +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID as PGUUID + +from ai.backend.common.data.permission.types import EntityType, ScopeType +from ai.backend.manager.models.base import Base +from ai.backend.manager.models.rbac_models.association_scopes_entities import ( + AssociationScopesEntitiesRow, +) +from ai.backend.manager.repositories.base import ( + CascadeChild, + PrunerResult, + PrunerSpec, + execute_pruner, +) + +if TYPE_CHECKING: + from ai.backend.manager.models.utils import ExtendedAsyncSAEngine + + +# Module-level test ORM models so SQLAlchemy can resolve them. + + +class PrunerTestParentRow(Base): # type: ignore[misc] + """Parent table for pruner tests with a status + terminated_at.""" + + __tablename__ = "test_pruner_parent" + __table_args__ = {"extend_existing": True} + + id = sa.Column(PGUUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + name = sa.Column(sa.String(50), nullable=False) + status = sa.Column(sa.String(20), nullable=False) + terminated_at = sa.Column(sa.DateTime(timezone=True), nullable=True) + + +class PrunerTestChildRow(Base): # type: ignore[misc] + """Child table FK-bound to PrunerTestParentRow.id (for cascade tests).""" + + __tablename__ = "test_pruner_child" + __table_args__ = {"extend_existing": True} + + id = sa.Column(PGUUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + parent_id = sa.Column( + PGUUID(as_uuid=True), + sa.ForeignKey("test_pruner_parent.id"), + nullable=False, + ) + name = sa.Column(sa.String(50), nullable=False) + + +class TestChildCascade(CascadeChild): + @classmethod + def row_class(cls) -> type[Base]: + return PrunerTestChildRow + + @classmethod + def parent_id_column(cls) -> Any: + return PrunerTestChildRow.parent_id + + +@dataclass +class TerminatedTestParentPrunerSpec(PrunerSpec[PrunerTestParentRow]): + """Default spec — no entity_type, so RBAC cleanup is skipped.""" + + @classmethod + def row_class(cls) -> type[PrunerTestParentRow]: + return PrunerTestParentRow + + @classmethod + def returning_id(cls) -> Any: + return PrunerTestParentRow.id + + @classmethod + def prune_condition(cls) -> sa.ColumnElement[bool]: + return PrunerTestParentRow.status == "terminated" + + +@dataclass +class TerminatedTestParentPrunerSpecWithRBAC(TerminatedTestParentPrunerSpec): + """Variant that opts into RBAC cleanup using EntityType.SESSION as a stand-in.""" + + @classmethod + def entity_type(cls) -> EntityType: + return EntityType.SESSION + + +async def _seed_parents( + db: ExtendedAsyncSAEngine, +) -> dict[str, list[UUID]]: + """Insert 5 terminated + 3 active parent rows; return PKs grouped by status.""" + now = datetime.now(UTC) + terminated: list[UUID] = [] + active: list[UUID] = [] + async with db.begin_session() as db_sess: + for i in range(5): + row_id = uuid.uuid4() + db_sess.add( + PrunerTestParentRow( + id=row_id, + name=f"term-{i}", + status="terminated", + terminated_at=now - timedelta(hours=2 if i < 3 else 0), + ) + ) + terminated.append(row_id) + for i in range(3): + row_id = uuid.uuid4() + db_sess.add( + PrunerTestParentRow( + id=row_id, + name=f"active-{i}", + status="active", + terminated_at=None, + ) + ) + active.append(row_id) + return {"terminated": terminated, "active": active} + + +@pytest.fixture +async def parent_only_tables( + database_connection: ExtendedAsyncSAEngine, +) -> AsyncGenerator[ExtendedAsyncSAEngine, None]: + """Create parent table only (no children, no RBAC table).""" + async with database_connection.begin() as conn: + await conn.run_sync(lambda c: PrunerTestParentRow.__table__.create(c, checkfirst=True)) + yield database_connection + async with database_connection.begin() as conn: + await conn.run_sync(lambda c: PrunerTestParentRow.__table__.drop(c, checkfirst=True)) + + +@pytest.fixture +async def parent_child_tables( + database_connection: ExtendedAsyncSAEngine, +) -> AsyncGenerator[ExtendedAsyncSAEngine, None]: + """Create parent + FK-bound child tables.""" + async with database_connection.begin() as conn: + await conn.run_sync(lambda c: PrunerTestParentRow.__table__.create(c, checkfirst=True)) + await conn.run_sync(lambda c: PrunerTestChildRow.__table__.create(c, checkfirst=True)) + yield database_connection + async with database_connection.begin() as conn: + await conn.run_sync(lambda c: PrunerTestChildRow.__table__.drop(c, checkfirst=True)) + await conn.run_sync(lambda c: PrunerTestParentRow.__table__.drop(c, checkfirst=True)) + + +@pytest.fixture +async def parent_with_rbac_tables( + database_connection: ExtendedAsyncSAEngine, +) -> AsyncGenerator[ExtendedAsyncSAEngine, None]: + """Create parent + association_scopes_entities tables for RBAC tests.""" + async with database_connection.begin() as conn: + # association_scopes_entities.id has server_default=uuid_generate_v4(). + await conn.execute(sa.text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp"')) + await conn.run_sync(lambda c: PrunerTestParentRow.__table__.create(c, checkfirst=True)) + await conn.run_sync( + lambda c: AssociationScopesEntitiesRow.__table__.create(c, checkfirst=True) + ) + yield database_connection + async with database_connection.begin() as conn: + await conn.run_sync( + lambda c: AssociationScopesEntitiesRow.__table__.drop(c, checkfirst=True) + ) + await conn.run_sync(lambda c: PrunerTestParentRow.__table__.drop(c, checkfirst=True)) + + +class TestPrunerBasic: + """Core behavior — no cascade, no RBAC.""" + + async def test_prune_terminal_rows(self, parent_only_tables: ExtendedAsyncSAEngine) -> None: + seeded = await _seed_parents(parent_only_tables) + async with parent_only_tables.begin_session() as db_sess: + spec = TerminatedTestParentPrunerSpec() + result = await execute_pruner(db_sess, spec) + + assert isinstance(result, PrunerResult) + assert result.count == 5 + assert set(result.ids) == set(seeded["terminated"]) + + async with parent_only_tables.begin_readonly_session() as db_sess: + remaining = await db_sess.scalars(sa.select(PrunerTestParentRow.id)) + assert set(remaining.all()) == set(seeded["active"]) + + async def test_prune_no_matching_rows(self, parent_only_tables: ExtendedAsyncSAEngine) -> None: + # Only active rows seeded — terminal-state condition matches none. + async with parent_only_tables.begin_session() as db_sess: + for i in range(3): + db_sess.add( + PrunerTestParentRow( + id=uuid.uuid4(), + name=f"active-{i}", + status="active", + ) + ) + + async with parent_only_tables.begin_session() as db_sess: + spec = TerminatedTestParentPrunerSpec() + result = await execute_pruner(db_sess, spec) + + assert result.count == 0 + assert result.ids == [] + + async def test_prune_empty_table(self, parent_only_tables: ExtendedAsyncSAEngine) -> None: + async with parent_only_tables.begin_session() as db_sess: + spec = TerminatedTestParentPrunerSpec() + result = await execute_pruner(db_sess, spec) + + assert result.count == 0 + assert result.ids == [] + + async def test_prune_with_runtime_condition( + self, parent_only_tables: ExtendedAsyncSAEngine + ) -> None: + seeded = await _seed_parents(parent_only_tables) + # Seed sets terminated_at = now-2h for first 3 terminated rows; rest at now. + cutoff = datetime.now(UTC) - timedelta(hours=1) + + async with parent_only_tables.begin_session() as db_sess: + spec = TerminatedTestParentPrunerSpec( + conditions=[lambda: PrunerTestParentRow.terminated_at < cutoff], + ) + result = await execute_pruner(db_sess, spec) + + assert result.count == 3 + # The first 3 terminated rows have terminated_at = now-2h. + assert set(result.ids) == set(seeded["terminated"][:3]) + + async def test_prune_with_limit_caps_count( + self, parent_only_tables: ExtendedAsyncSAEngine + ) -> None: + await _seed_parents(parent_only_tables) + async with parent_only_tables.begin_session() as db_sess: + spec = TerminatedTestParentPrunerSpec(limit=2) + result = await execute_pruner(db_sess, spec) + + assert result.count == 2 + assert len(result.ids) == 2 + + async with parent_only_tables.begin_readonly_session() as db_sess: + remaining = await db_sess.scalars( + sa.select(sa.func.count()).select_from(PrunerTestParentRow) + ) + assert remaining.one() == 6 # 8 - 2 + + +class TestPrunerCascade: + """FK cascade behavior.""" + + async def _seed_with_children( + self, db: ExtendedAsyncSAEngine + ) -> tuple[dict[str, list[UUID]], dict[UUID, list[UUID]]]: + """Seed parents + 2 children per parent. Return (parents_by_status, children_by_parent).""" + seeded = await _seed_parents(db) + children_by_parent: dict[UUID, list[UUID]] = {} + async with db.begin_session() as db_sess: + for parent_id in seeded["terminated"] + seeded["active"]: + child_ids = [] + for i in range(2): + cid = uuid.uuid4() + db_sess.add( + PrunerTestChildRow( + id=cid, + parent_id=parent_id, + name=f"child-{i}", + ) + ) + child_ids.append(cid) + children_by_parent[parent_id] = child_ids + return seeded, children_by_parent + + async def test_cascade_deletes_children_of_pruned_parents( + self, parent_child_tables: ExtendedAsyncSAEngine + ) -> None: + seeded, children = await self._seed_with_children(parent_child_tables) + + async with parent_child_tables.begin_session() as db_sess: + spec = TerminatedTestParentPrunerSpec(cascade=[TestChildCascade()]) + result = await execute_pruner(db_sess, spec) + + assert result.count == 5 + assert set(result.ids) == set(seeded["terminated"]) + + # Children of pruned parents are gone; children of active parents remain. + async with parent_child_tables.begin_readonly_session() as db_sess: + remaining_children = ( + await db_sess.scalars(sa.select(PrunerTestChildRow.parent_id)) + ).all() + expected_remaining_parents = set(seeded["active"]) + assert ( + set(remaining_children) + == {pid for pid in expected_remaining_parents for _ in children[pid]} + or set(remaining_children) <= expected_remaining_parents + ) + + # Each surviving parent still has its 2 children. + count = await db_sess.scalars( + sa.select(sa.func.count()).select_from(PrunerTestChildRow) + ) + assert count.one() == len(seeded["active"]) * 2 + + async def test_cascade_skipped_for_non_terminal_parents( + self, parent_child_tables: ExtendedAsyncSAEngine + ) -> None: + seeded, _children = await self._seed_with_children(parent_child_tables) + + async with parent_child_tables.begin_session() as db_sess: + spec = TerminatedTestParentPrunerSpec(cascade=[TestChildCascade()]) + await execute_pruner(db_sess, spec) + + # Active parents preserved. + async with parent_child_tables.begin_readonly_session() as db_sess: + remaining = (await db_sess.scalars(sa.select(PrunerTestParentRow.id))).all() + assert set(remaining) == set(seeded["active"]) + + async def test_no_cascade_with_fk_violation_raises( + self, parent_child_tables: ExtendedAsyncSAEngine + ) -> None: + """Without the cascade, FK constraint blocks the parent DELETE.""" + await self._seed_with_children(parent_child_tables) + + with pytest.raises(Exception): + async with parent_child_tables.begin_session() as db_sess: + spec = TerminatedTestParentPrunerSpec() # no cascade + await execute_pruner(db_sess, spec) + + +class TestPrunerRBAC: + """RBAC association cleanup driven by entity_type().""" + + async def _seed_with_rbac( + self, db: ExtendedAsyncSAEngine + ) -> tuple[dict[str, list[UUID]], dict[UUID, UUID]]: + """Seed parents + one SESSION RBAC association per parent. + + Returns (parents_by_status, association_id_by_parent_id). + """ + seeded = await _seed_parents(db) + assoc_by_parent: dict[UUID, UUID] = {} + async with db.begin_session() as db_sess: + for parent_id in seeded["terminated"] + seeded["active"]: + aid = uuid.uuid4() + db_sess.add( + AssociationScopesEntitiesRow( + id=aid, + scope_type=ScopeType.GLOBAL, + scope_id="global", + entity_type=EntityType.SESSION, + entity_id=str(parent_id), + ) + ) + assoc_by_parent[parent_id] = aid + # One unrelated row using a different entity_type with the same UUID + # as a terminated parent — should never be deleted (entity_type filter). + # Distinct scope_id avoids the (scope_type, scope_id, entity_id) unique constraint. + unrelated_id = uuid.uuid4() + db_sess.add( + AssociationScopesEntitiesRow( + id=unrelated_id, + scope_type=ScopeType.GLOBAL, + scope_id="other-scope", + entity_type=EntityType.VFOLDER, + entity_id=str(seeded["terminated"][0]), + ) + ) + assoc_by_parent[uuid.uuid4()] = unrelated_id # sentinel key for return + return seeded, assoc_by_parent + + async def test_rbac_cleanup_when_enabled( + self, parent_with_rbac_tables: ExtendedAsyncSAEngine + ) -> None: + seeded, _ = await self._seed_with_rbac(parent_with_rbac_tables) + + async with parent_with_rbac_tables.begin_session() as db_sess: + spec = TerminatedTestParentPrunerSpecWithRBAC(cascade_rbac=True) + result = await execute_pruner(db_sess, spec) + + assert result.count == 5 + + async with parent_with_rbac_tables.begin_readonly_session() as db_sess: + # SESSION associations for terminated parents are gone. + remaining_session_assoc_ids = ( + await db_sess.scalars( + sa.select(AssociationScopesEntitiesRow.entity_id).where( + AssociationScopesEntitiesRow.entity_type == EntityType.SESSION, + ) + ) + ).all() + assert set(remaining_session_assoc_ids) == {str(pid) for pid in seeded["active"]} + + # The unrelated VFOLDER association — same UUID but different entity_type — is preserved. + unrelated_count = await db_sess.scalars( + sa.select(sa.func.count()) + .select_from(AssociationScopesEntitiesRow) + .where(AssociationScopesEntitiesRow.entity_type == EntityType.VFOLDER) + ) + assert unrelated_count.one() == 1 + + async def test_rbac_skipped_when_disabled( + self, parent_with_rbac_tables: ExtendedAsyncSAEngine + ) -> None: + seeded, _ = await self._seed_with_rbac(parent_with_rbac_tables) + + async with parent_with_rbac_tables.begin_session() as db_sess: + spec = TerminatedTestParentPrunerSpecWithRBAC(cascade_rbac=False) + await execute_pruner(db_sess, spec) + + # All SESSION associations preserved (8 — one per parent). + async with parent_with_rbac_tables.begin_readonly_session() as db_sess: + count = await db_sess.scalars( + sa.select(sa.func.count()) + .select_from(AssociationScopesEntitiesRow) + .where(AssociationScopesEntitiesRow.entity_type == EntityType.SESSION) + ) + assert count.one() == len(seeded["terminated"]) + len(seeded["active"]) + + async def test_rbac_skipped_when_entity_type_none( + self, parent_with_rbac_tables: ExtendedAsyncSAEngine + ) -> None: + seeded, _ = await self._seed_with_rbac(parent_with_rbac_tables) + + # Default spec returns entity_type=None — RBAC cleanup must be skipped + # even with cascade_rbac=True (default). + async with parent_with_rbac_tables.begin_session() as db_sess: + spec = TerminatedTestParentPrunerSpec() + result = await execute_pruner(db_sess, spec) + + assert result.count == 5 + + async with parent_with_rbac_tables.begin_readonly_session() as db_sess: + count = await db_sess.scalars( + sa.select(sa.func.count()) + .select_from(AssociationScopesEntitiesRow) + .where(AssociationScopesEntitiesRow.entity_type == EntityType.SESSION) + ) + assert count.one() == len(seeded["terminated"]) + len(seeded["active"]) From 301dd9c9dcf5321a0539b58846dc16716ab98312 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Sat, 2 May 2026 03:17:13 +0900 Subject: [PATCH 2/4] changelog: add news fragment for PR #11460 --- changes/11460.feature.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/11460.feature.md diff --git a/changes/11460.feature.md b/changes/11460.feature.md new file mode 100644 index 00000000000..9d01953d7a4 --- /dev/null +++ b/changes/11460.feature.md @@ -0,0 +1 @@ +Add `PrunerSpec` and `CascadeChild` repository-layer abstractions for bulk-delete (prune) operations with FK cascade and RBAC association cleanup. From f23a9637d598ee409e832b894e564ade1fe779ac Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Sat, 2 May 2026 03:36:52 +0900 Subject: [PATCH 3/4] refactor(BA-5936): apply Copilot review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Wrap all DELETEs (cascade, RBAC, parent) in a single try/except for IntegrityError so cascade failures also surface as RepositoryIntegrityError instead of raw SQLAlchemy errors. - Drop PrunerSpec.returning_id() — derive the parent PK column directly from row_class().__table__.primary_key, mirroring the Purger pattern. Reject composite-PK tables with UnsupportedCompositePrimaryKeyError. - Update CascadeChild docstring to reflect the materialized-id-list approach (the previous SQL-subquery wording was stale). - Tighten test_no_cascade_with_fk_violation_raises to expect ForeignKeyViolationError, also covering the parse_integrity_error path. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../manager/repositories/base/pruner.py | 69 ++++++++++--------- .../manager/repositories/base/test_pruner.py | 14 ++-- 2 files changed, 46 insertions(+), 37 deletions(-) diff --git a/src/ai/backend/manager/repositories/base/pruner.py b/src/ai/backend/manager/repositories/base/pruner.py index 36838252c8f..b88d42752fb 100644 --- a/src/ai/backend/manager/repositories/base/pruner.py +++ b/src/ai/backend/manager/repositories/base/pruner.py @@ -9,6 +9,7 @@ import sqlalchemy as sa from ai.backend.common.data.permission.types import EntityType +from ai.backend.manager.errors.repository import UnsupportedCompositePrimaryKeyError from ai.backend.manager.models.base import Base from ai.backend.manager.models.rbac_models.association_scopes_entities import ( AssociationScopesEntitiesRow, @@ -26,11 +27,14 @@ class CascadeChild(ABC): """A child table whose rows must be deleted before the parent's prune. - Used for simple FK cascades. Each cascade DELETE runs as:: + Used for simple FK cascades. ``execute_pruner`` first locks and + materializes the parent target IDs once, then issues each cascade + DELETE as:: - DELETE FROM WHERE - IN (SELECT FROM - WHERE ) + DELETE FROM WHERE IN () + + where ```` is the list returned from the single + ``SELECT pk FOR UPDATE`` against the parent table. Polymorphic / cross-cutting cleanups (e.g., RBAC associations) are not handled here — see :meth:`PrunerSpec.entity_type` for that. @@ -94,18 +98,13 @@ class PrunerSpec[TRow: Base](ABC): def row_class(cls) -> type[TRow]: """ORM Row class for the parent entity table. - Example: - return SessionRow - """ - raise NotImplementedError - - @classmethod - @abstractmethod - def returning_id(cls) -> Any: - """Primary-key column for the parent's ``DELETE ... RETURNING``. + The single-column primary key is derived from + ``row_class().__table__.primary_key`` by ``execute_pruner``; + composite-PK tables are rejected with + :class:`UnsupportedCompositePrimaryKeyError`. Example: - return SessionRow.id + return SessionRow """ raise NotImplementedError @@ -175,11 +174,19 @@ async def execute_pruner[TRow: Base]( PrunerResult with the count and PK list of deleted parent rows. Raises: - RepositoryIntegrityError: If any DELETE violates a database constraint. + UnsupportedCompositePrimaryKeyError: If the parent table has a + composite primary key. + RepositoryIntegrityError: If any DELETE (cascade, RBAC, or parent) + violates a database constraint. """ cls = type(spec) table = cls.row_class().__table__ - pk_col = cls.returning_id() + pk_columns = list(table.primary_key.columns) + if len(pk_columns) != 1: + raise UnsupportedCompositePrimaryKeyError( + f"PrunerSpec only supports single-column primary keys (table: {table.name})", + ) + pk_col = pk_columns[0] where = cls.prune_condition() for f in spec.conditions: @@ -190,24 +197,24 @@ async def execute_pruner[TRow: Base]( if not target_ids: return PrunerResult(count=0, ids=[]) - for child in spec.cascade: - ccls = type(child) - cascade_table = ccls.row_class().__table__ - await db_sess.execute( - sa.delete(cascade_table).where(ccls.parent_id_column().in_(target_ids)) - ) - rbac_entity_type = cls.entity_type() - if spec.cascade_rbac and rbac_entity_type is not None: - await db_sess.execute( - sa.delete(AssociationScopesEntitiesRow).where( - AssociationScopesEntitiesRow.entity_type == rbac_entity_type, - AssociationScopesEntitiesRow.entity_id.in_([str(i) for i in target_ids]), + try: + for child in spec.cascade: + ccls = type(child) + cascade_table = ccls.row_class().__table__ + await db_sess.execute( + sa.delete(cascade_table).where(ccls.parent_id_column().in_(target_ids)) ) - ) - stmt = sa.delete(table).where(pk_col.in_(target_ids)).returning(pk_col) - try: + if spec.cascade_rbac and rbac_entity_type is not None: + await db_sess.execute( + sa.delete(AssociationScopesEntitiesRow).where( + AssociationScopesEntitiesRow.entity_type == rbac_entity_type, + AssociationScopesEntitiesRow.entity_id.in_([str(i) for i in target_ids]), + ) + ) + + stmt = sa.delete(table).where(pk_col.in_(target_ids)).returning(pk_col) deleted = list((await db_sess.scalars(stmt)).all()) except sa.exc.IntegrityError as e: raise parse_integrity_error(e) from e diff --git a/tests/unit/manager/repositories/base/test_pruner.py b/tests/unit/manager/repositories/base/test_pruner.py index 8f7e7cad0ef..a5859410b6b 100644 --- a/tests/unit/manager/repositories/base/test_pruner.py +++ b/tests/unit/manager/repositories/base/test_pruner.py @@ -14,6 +14,7 @@ from sqlalchemy.dialects.postgresql import UUID as PGUUID from ai.backend.common.data.permission.types import EntityType, ScopeType +from ai.backend.manager.errors.repository import ForeignKeyViolationError from ai.backend.manager.models.base import Base from ai.backend.manager.models.rbac_models.association_scopes_entities import ( AssociationScopesEntitiesRow, @@ -77,10 +78,6 @@ class TerminatedTestParentPrunerSpec(PrunerSpec[PrunerTestParentRow]): def row_class(cls) -> type[PrunerTestParentRow]: return PrunerTestParentRow - @classmethod - def returning_id(cls) -> Any: - return PrunerTestParentRow.id - @classmethod def prune_condition(cls) -> sa.ColumnElement[bool]: return PrunerTestParentRow.status == "terminated" @@ -325,10 +322,15 @@ async def test_cascade_skipped_for_non_terminal_parents( async def test_no_cascade_with_fk_violation_raises( self, parent_child_tables: ExtendedAsyncSAEngine ) -> None: - """Without the cascade, FK constraint blocks the parent DELETE.""" + """Without the cascade, FK constraint blocks the parent DELETE. + + Also verifies that ``execute_pruner`` translates the SQLAlchemy + ``IntegrityError`` into ``ForeignKeyViolationError`` via + ``parse_integrity_error``. + """ await self._seed_with_children(parent_child_tables) - with pytest.raises(Exception): + with pytest.raises(ForeignKeyViolationError): async with parent_child_tables.begin_session() as db_sess: spec = TerminatedTestParentPrunerSpec() # no cascade await execute_pruner(db_sess, spec) From 2324c77710dc4f12280d2ecb8c12e3cd52263c97 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Sat, 2 May 2026 03:49:07 +0900 Subject: [PATCH 4/4] refactor(BA-5936): make CascadeChild generic over its row type MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Parameterize CascadeChild on the cascade table's Row class (``CascadeChild[TRow: Base]``) for consistency with PrunerSpec[TRow] and to give subclass authors precise typing on ``row_class()`` — e.g., ``CascadeChild[KernelRow]`` makes ``row_class()`` return ``type[KernelRow]`` instead of the over-broad ``type[Base]``. PrunerSpec.cascade is typed as ``list[CascadeChild[Any]]`` since a spec composes cascades over heterogeneous tables. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/ai/backend/manager/repositories/base/pruner.py | 9 ++++++--- tests/unit/manager/repositories/base/test_pruner.py | 4 ++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/ai/backend/manager/repositories/base/pruner.py b/src/ai/backend/manager/repositories/base/pruner.py index b88d42752fb..32f6f7eb9d7 100644 --- a/src/ai/backend/manager/repositories/base/pruner.py +++ b/src/ai/backend/manager/repositories/base/pruner.py @@ -24,7 +24,7 @@ TRow = TypeVar("TRow", bound=Base) -class CascadeChild(ABC): +class CascadeChild[TRow: Base](ABC): """A child table whose rows must be deleted before the parent's prune. Used for simple FK cascades. ``execute_pruner`` first locks and @@ -36,13 +36,16 @@ class CascadeChild(ABC): where ```` is the list returned from the single ``SELECT pk FOR UPDATE`` against the parent table. + The type parameter ``TRow`` is the cascade table's ORM Row class — + e.g., ``CascadeChild[KernelRow]``. + Polymorphic / cross-cutting cleanups (e.g., RBAC associations) are not handled here — see :meth:`PrunerSpec.entity_type` for that. """ @classmethod @abstractmethod - def row_class(cls) -> type[Base]: + def row_class(cls) -> type[TRow]: """ORM Row class for the cascade table. Example: @@ -89,7 +92,7 @@ class PrunerSpec[TRow: Base](ABC): """ conditions: list[QueryCondition] = field(default_factory=list) - cascade: list[CascadeChild] = field(default_factory=list) + cascade: list[CascadeChild[Any]] = field(default_factory=list) limit: int = DEFAULT_PRUNE_LIMIT cascade_rbac: bool = True diff --git a/tests/unit/manager/repositories/base/test_pruner.py b/tests/unit/manager/repositories/base/test_pruner.py index a5859410b6b..c9d05a2fb7e 100644 --- a/tests/unit/manager/repositories/base/test_pruner.py +++ b/tests/unit/manager/repositories/base/test_pruner.py @@ -60,9 +60,9 @@ class PrunerTestChildRow(Base): # type: ignore[misc] name = sa.Column(sa.String(50), nullable=False) -class TestChildCascade(CascadeChild): +class TestChildCascade(CascadeChild[PrunerTestChildRow]): @classmethod - def row_class(cls) -> type[Base]: + def row_class(cls) -> type[PrunerTestChildRow]: return PrunerTestChildRow @classmethod