Skip to content

Commit 61c1c85

Browse files
Merge pull request #62 from younesious/feat/61-typing-sqlalchemy-adapter
Add Generic TypeVar Support to SQLAlchemy Adapters
2 parents 8c56081 + b55d2db commit 61c1c85

1 file changed

Lines changed: 14 additions & 11 deletions

File tree

archipy/adapters/base/sqlalchemy/adapters.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import Enum
2-
from typing import Any, override
2+
from typing import Any, override, TypeVar
33
from uuid import UUID
44

55
from sqlalchemy import Delete, Executable, Result, ScalarResult, Update, func, select
@@ -30,6 +30,9 @@
3030
from archipy.models.types.base_types import FilterOperationType
3131
from archipy.models.types.sort_order_type import SortOrderType
3232

33+
# Generic type variable for BaseEntity subclasses
34+
T = TypeVar('T', bound=BaseEntity)
35+
3336

3437
class SQLAlchemyExceptionHandlerMixin:
3538
"""Mixin providing centralized exception handling for SQLAlchemy operations.
@@ -284,14 +287,14 @@ def get_session(self) -> Session:
284287
return self.session_manager.get_session()
285288

286289
@override
287-
def create(self, entity: BaseEntity) -> BaseEntity | None:
290+
def create(self, entity: T) -> T | None:
288291
"""Create a new entity in the database.
289292
290293
Args:
291294
entity: The entity to create.
292295
293296
Returns:
294-
The created entity with updated attributes.
297+
The created entity with updated attributes, preserving the original type.
295298
296299
Raises:
297300
InvalidEntityTypeError: If the entity type is not a valid SQLAlchemy model.
@@ -318,14 +321,14 @@ def create(self, entity: BaseEntity) -> BaseEntity | None:
318321
return entity
319322

320323
@override
321-
def bulk_create(self, entities: list[BaseEntity]) -> list[BaseEntity] | None:
324+
def bulk_create(self, entities: list[T]) -> list[T] | None:
322325
"""Creates multiple entities in a single database operation.
323326
324327
Args:
325328
entities: List of entities to create.
326329
327330
Returns:
328-
List of created entities with updated attributes.
331+
List of created entities with updated attributes, preserving original types.
329332
330333
Raises:
331334
InvalidEntityTypeError: If any entity is not a valid SQLAlchemy model.
@@ -352,7 +355,7 @@ def bulk_create(self, entities: list[BaseEntity]) -> list[BaseEntity] | None:
352355
return entities
353356

354357
@override
355-
def get_by_uuid(self, entity_type: type, entity_uuid: UUID) -> BaseEntity | None:
358+
def get_by_uuid(self, entity_type: type[T], entity_uuid: UUID) -> T | None:
356359
"""Retrieve an entity by its UUID.
357360
358361
Args:
@@ -593,14 +596,14 @@ def get_session(self) -> AsyncSession:
593596
return self.session_manager.get_session()
594597

595598
@override
596-
async def create(self, entity: BaseEntity) -> BaseEntity | None:
599+
async def create(self, entity: T) -> T | None:
597600
"""Create a new entity in the database.
598601
599602
Args:
600603
entity: The entity to create.
601604
602605
Returns:
603-
The created entity with updated attributes.
606+
The created entity with updated attributes, preserving the original type.
604607
605608
Raises:
606609
InvalidEntityTypeError: If the entity type is not a valid SQLAlchemy model.
@@ -627,14 +630,14 @@ async def create(self, entity: BaseEntity) -> BaseEntity | None:
627630
return entity
628631

629632
@override
630-
async def bulk_create(self, entities: list[BaseEntity]) -> list[BaseEntity] | None:
633+
async def bulk_create(self, entities: list[T]) -> list[T] | None:
631634
"""Creates multiple entities in a single database operation.
632635
633636
Args:
634637
entities: List of entities to create.
635638
636639
Returns:
637-
List of created entities with updated attributes.
640+
List of created entities with updated attributes, preserving original types.
638641
639642
Raises:
640643
InvalidEntityTypeError: If any entity is not a valid SQLAlchemy model.
@@ -661,7 +664,7 @@ async def bulk_create(self, entities: list[BaseEntity]) -> list[BaseEntity] | No
661664
return entities
662665

663666
@override
664-
async def get_by_uuid(self, entity_type: type, entity_uuid: UUID) -> Any | None:
667+
async def get_by_uuid(self, entity_type: type[T], entity_uuid: UUID) -> T | None:
665668
"""Retrieve an entity by its UUID.
666669
667670
Args:

0 commit comments

Comments
 (0)