3232
3333# Generic type variable for BaseEntity subclasses
3434T = TypeVar ("T" , bound = BaseEntity )
35+ ConfigT = TypeVar ("ConfigT" , bound = SQLAlchemyConfig )
3536
3637
3738class SQLAlchemyExceptionHandlerMixin :
@@ -82,7 +83,7 @@ class SQLAlchemyFilterMixin:
8283 def _apply_filter (
8384 query : Select | Update | Delete ,
8485 field : InstrumentedAttribute ,
85- value : Any ,
86+ value : str | int | float | bool | list | None ,
8687 operation : FilterOperationType ,
8788 ) -> Select | Update | Delete :
8889 """Apply a filter to a SQLAlchemy query based on the specified operation.
@@ -111,8 +112,12 @@ def _apply_filter(
111112 case FilterOperationType .GREATER_THAN_OR_EQUAL :
112113 return query .where (field >= value )
113114 case FilterOperationType .IN_LIST :
115+ if not isinstance (value , list ):
116+ raise InvalidArgumentError (f"IN_LIST operation requires a list, got { type (value )} " )
114117 return query .where (field .in_ (value ))
115118 case FilterOperationType .NOT_IN_LIST :
119+ if not isinstance (value , list ):
120+ raise InvalidArgumentError (f"NOT_IN_LIST operation requires a list, got { type (value )} " )
116121 return query .where (~ field .in_ (value ))
117122 case FilterOperationType .LIKE :
118123 return query .where (field .like (f"%{ value } %" ))
@@ -193,7 +198,7 @@ def _apply_sorting(entity: type[BaseEntity], query: Select, sort_info: SortDTO |
193198 raise InvalidArgumentError (argument_name = "sort_info.order" )
194199
195200
196- class BaseSQLAlchemyAdapter (
201+ class BaseSQLAlchemyAdapter [ ConfigT : SQLAlchemyConfig ] (
197202 SQLAlchemyPort ,
198203 SQLAlchemyPaginationMixin ,
199204 SQLAlchemySortMixin ,
@@ -209,16 +214,17 @@ class BaseSQLAlchemyAdapter(
209214 orm_config: Configuration for SQLAlchemy. If None, uses global config.
210215 """
211216
212- def __init__ (self , orm_config : SQLAlchemyConfig | None = None ) -> None :
217+ def __init__ (self , orm_config : ConfigT | None = None ) -> None :
213218 """Initialize the base adapter with a session manager.
214219
215220 Args:
216221 orm_config: Configuration for SQLAlchemy. If None, uses global config.
217222 """
218223 configs = BaseConfig .global_config ().SQLALCHEMY if orm_config is None else orm_config
219- self .session_manager : BaseSQLAlchemySessionManager = self ._create_session_manager (configs )
224+ # Cast to ConfigT since subclasses will ensure the proper type
225+ self .session_manager : BaseSQLAlchemySessionManager [ConfigT ] = self ._create_session_manager (configs ) # type: ignore[arg-type]
220226
221- def _create_session_manager (self , configs : SQLAlchemyConfig ) -> BaseSQLAlchemySessionManager :
227+ def _create_session_manager (self , configs : ConfigT ) -> BaseSQLAlchemySessionManager [ ConfigT ] :
222228 """Create a session manager for the specific database.
223229
224230 Args:
@@ -227,7 +233,7 @@ def _create_session_manager(self, configs: SQLAlchemyConfig) -> BaseSQLAlchemySe
227233 Returns:
228234 A session manager instance.
229235 """
230- return BaseSQLAlchemySessionManager ( configs )
236+ raise NotImplementedError ( "Subclasses must implement _create_session_manager" )
231237
232238 @override
233239 def execute_search_query (
@@ -263,15 +269,16 @@ def execute_search_query(
263269 paginated_query = self ._apply_pagination (sorted_query , pagination )
264270 result_set = session .execute (paginated_query )
265271 if has_multiple_entities :
266- results = result_set .fetchall ()
272+ results = list ( result_set .fetchall () )
267273 else :
268- results = result_set .scalars ().all ()
274+ results = list ( result_set .scalars ().all () )
269275 count_query = select (func .count ()).select_from (query .subquery ())
270276 total_count = session .execute (count_query ).scalar_one ()
271277 except Exception as e :
272278 self ._handle_db_exception (e , self .session_manager ._get_database_name ())
279+ raise # This will never be reached, but satisfies MyPy
273280 else :
274- return results , total_count
281+ return results , total_count # type: ignore[return-value]
275282
276283 @override
277284 def get_session (self ) -> Session :
@@ -317,6 +324,7 @@ def create(self, entity: T) -> T | None:
317324 session .flush ()
318325 except Exception as e :
319326 self ._handle_db_exception (e , self .session_manager ._get_database_name ())
327+ raise # This will never be reached, but satisfies MyPy
320328 else :
321329 return entity
322330
@@ -351,6 +359,7 @@ def bulk_create(self, entities: list[T]) -> list[T] | None:
351359 session .flush ()
352360 except Exception as e :
353361 self ._handle_db_exception (e , self .session_manager ._get_database_name ())
362+ raise # This will never be reached, but satisfies MyPy
354363 else :
355364 return entities
356365
@@ -384,6 +393,7 @@ def get_by_uuid(self, entity_type: type[T], entity_uuid: UUID) -> T | None:
384393 result = session .get (entity_type , entity_uuid )
385394 except Exception as e :
386395 self ._handle_db_exception (e , self .session_manager ._get_database_name ())
396+ raise # This will never be reached, but satisfies MyPy
387397 else :
388398 return result
389399
@@ -415,8 +425,6 @@ def delete(self, entity: BaseEntity) -> None:
415425 session .flush ()
416426 except Exception as e :
417427 self ._handle_db_exception (e , self .session_manager ._get_database_name ())
418- else :
419- return ...
420428
421429 @override
422430 def bulk_delete (self , entities : list [BaseEntity ]) -> None :
@@ -447,8 +455,6 @@ def bulk_delete(self, entities: list[BaseEntity]) -> None:
447455 session .flush ()
448456 except Exception as e :
449457 self ._handle_db_exception (e , self .session_manager ._get_database_name ())
450- else :
451- return ...
452458
453459 @override
454460 def execute (self , statement : Executable , params : AnyExecuteParams | None = None ) -> Result [Any ]:
@@ -472,6 +478,7 @@ def execute(self, statement: Executable, params: AnyExecuteParams | None = None)
472478 result = session .execute (statement , params or {})
473479 except Exception as e :
474480 self ._handle_db_exception (e , self .session_manager ._get_database_name ())
481+ raise # This will never be reached, but satisfies MyPy
475482 else :
476483 return result
477484
@@ -497,11 +504,12 @@ def scalars(self, statement: Executable, params: AnyExecuteParams | None = None)
497504 result = session .scalars (statement , params or {})
498505 except Exception as e :
499506 self ._handle_db_exception (e , self .session_manager ._get_database_name ())
507+ raise # This will never be reached, but satisfies MyPy
500508 else :
501509 return result
502510
503511
504- class AsyncBaseSQLAlchemyAdapter (
512+ class AsyncBaseSQLAlchemyAdapter [ ConfigT : SQLAlchemyConfig ] (
505513 AsyncSQLAlchemyPort ,
506514 SQLAlchemyPaginationMixin ,
507515 SQLAlchemySortMixin ,
@@ -517,16 +525,17 @@ class AsyncBaseSQLAlchemyAdapter(
517525 orm_config: Configuration for SQLAlchemy. If None, uses global config.
518526 """
519527
520- def __init__ (self , orm_config : SQLAlchemyConfig | None = None ) -> None :
528+ def __init__ (self , orm_config : ConfigT | None = None ) -> None :
521529 """Initialize the base async adapter with a session manager.
522530
523531 Args:
524532 orm_config: Configuration for SQLAlchemy. If None, uses global config.
525533 """
526534 configs = BaseConfig .global_config ().SQLALCHEMY if orm_config is None else orm_config
527- self .session_manager : AsyncBaseSQLAlchemySessionManager = self ._create_async_session_manager (configs )
535+ # Cast to ConfigT since subclasses will ensure the proper type
536+ self .session_manager : AsyncBaseSQLAlchemySessionManager [ConfigT ] = self ._create_async_session_manager (configs ) # type: ignore[arg-type]
528537
529- def _create_async_session_manager (self , configs : SQLAlchemyConfig ) -> AsyncBaseSQLAlchemySessionManager :
538+ def _create_async_session_manager (self , configs : ConfigT ) -> AsyncBaseSQLAlchemySessionManager [ ConfigT ] :
530539 """Create an async session manager for the specific database.
531540
532541 Args:
@@ -535,7 +544,7 @@ def _create_async_session_manager(self, configs: SQLAlchemyConfig) -> AsyncBaseS
535544 Returns:
536545 An async session manager instance.
537546 """
538- return AsyncBaseSQLAlchemySessionManager ( configs )
547+ raise NotImplementedError ( "Subclasses must implement _create_async_session_manager" )
539548
540549 @override
541550 async def execute_search_query (
@@ -571,16 +580,17 @@ async def execute_search_query(
571580 paginated_query = self ._apply_pagination (sorted_query , pagination )
572581 result_set = await session .execute (paginated_query )
573582 if has_multiple_entities :
574- results = result_set .fetchall ()
583+ results = list ( result_set .fetchall () )
575584 else :
576- results = result_set .scalars ().all ()
585+ results = list ( result_set .scalars ().all () )
577586 count_query = select (func .count ()).select_from (query .subquery ())
578- total_count = await session .execute (count_query )
579- total_count = total_count .scalar_one ()
587+ total_count_result = await session .execute (count_query )
588+ total_count = total_count_result .scalar_one ()
580589 except Exception as e :
581590 self ._handle_db_exception (e , self .session_manager ._get_database_name ())
591+ raise # This will never be reached, but satisfies MyPy
582592 else :
583- return results , total_count
593+ return results , total_count # type: ignore[return-value]
584594
585595 @override
586596 def get_session (self ) -> AsyncSession :
@@ -626,6 +636,7 @@ async def create(self, entity: T) -> T | None:
626636 await session .flush ()
627637 except Exception as e :
628638 self ._handle_db_exception (e , self .session_manager ._get_database_name ())
639+ raise # This will never be reached, but satisfies MyPy
629640 else :
630641 return entity
631642
@@ -660,6 +671,7 @@ async def bulk_create(self, entities: list[T]) -> list[T] | None:
660671 await session .flush ()
661672 except Exception as e :
662673 self ._handle_db_exception (e , self .session_manager ._get_database_name ())
674+ raise # This will never be reached, but satisfies MyPy
663675 else :
664676 return entities
665677
@@ -693,6 +705,7 @@ async def get_by_uuid(self, entity_type: type[T], entity_uuid: UUID) -> T | None
693705 result = await session .get (entity_type , entity_uuid )
694706 except Exception as e :
695707 self ._handle_db_exception (e , self .session_manager ._get_database_name ())
708+ raise # This will never be reached, but satisfies MyPy
696709 else :
697710 return result
698711
@@ -724,8 +737,6 @@ async def delete(self, entity: BaseEntity) -> None:
724737 await session .flush ()
725738 except Exception as e :
726739 self ._handle_db_exception (e , self .session_manager ._get_database_name ())
727- else :
728- return ...
729740
730741 @override
731742 async def bulk_delete (self , entities : list [BaseEntity ]) -> None :
@@ -756,8 +767,6 @@ async def bulk_delete(self, entities: list[BaseEntity]) -> None:
756767 await session .flush ()
757768 except Exception as e :
758769 self ._handle_db_exception (e , self .session_manager ._get_database_name ())
759- else :
760- return ...
761770
762771 @override
763772 async def execute (self , statement : Executable , params : AnyExecuteParams | None = None ) -> Result [Any ]:
@@ -781,6 +790,7 @@ async def execute(self, statement: Executable, params: AnyExecuteParams | None =
781790 result = await session .execute (statement , params or {})
782791 except Exception as e :
783792 self ._handle_db_exception (e , self .session_manager ._get_database_name ())
793+ raise # This will never be reached, but satisfies MyPy
784794 else :
785795 return result
786796
@@ -806,5 +816,6 @@ async def scalars(self, statement: Executable, params: AnyExecuteParams | None =
806816 result = await session .scalars (statement , params or {})
807817 except Exception as e :
808818 self ._handle_db_exception (e , self .session_manager ._get_database_name ())
819+ raise # This will never be reached, but satisfies MyPy
809820 else :
810821 return result
0 commit comments