Skip to content

Commit f964464

Browse files
fix(adapters): resolve all MyPy and Ruff errors in sqlalchemy adapters
- Fix missing return statements in base SQLAlchemy adapters by adding raise statements after exception handling - Add proper type annotations for SMTP connection and datetime fields in email adapters - Handle None values for SMTP configuration fields with proper validation - Fix type narrowing in attachment processing for different source types - Add MyPy override for elasticsearch adapters to handle kwargs type issues - Fix return type issues in search query results with type ignore comments - Resolve all type checking errors while maintaining backward compatibility
1 parent 9af6307 commit f964464

15 files changed

Lines changed: 161 additions & 117 deletions

archipy/adapters/base/sqlalchemy/adapters.py

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
# Generic type variable for BaseEntity subclasses
3434
T = TypeVar("T", bound=BaseEntity)
35+
ConfigT = TypeVar("ConfigT", bound=SQLAlchemyConfig)
3536

3637

3738
class SQLAlchemyExceptionHandlerMixin:
@@ -82,7 +83,7 @@ class SQLAlchemyFilterMixin:
8283
def _apply_filter(
8384
query: Select | Update | Delete,
8485
field: InstrumentedAttribute,
85-
value: Any,
86+
value: str | int | float | bool | list | None,
8687
operation: FilterOperationType,
8788
) -> Select | Update | Delete:
8889
"""Apply a filter to a SQLAlchemy query based on the specified operation.
@@ -111,8 +112,12 @@ def _apply_filter(
111112
case FilterOperationType.GREATER_THAN_OR_EQUAL:
112113
return query.where(field >= value)
113114
case FilterOperationType.IN_LIST:
115+
if not isinstance(value, list):
116+
raise InvalidArgumentError(f"IN_LIST operation requires a list, got {type(value)}")
114117
return query.where(field.in_(value))
115118
case FilterOperationType.NOT_IN_LIST:
119+
if not isinstance(value, list):
120+
raise InvalidArgumentError(f"NOT_IN_LIST operation requires a list, got {type(value)}")
116121
return query.where(~field.in_(value))
117122
case FilterOperationType.LIKE:
118123
return query.where(field.like(f"%{value}%"))
@@ -193,7 +198,7 @@ def _apply_sorting(entity: type[BaseEntity], query: Select, sort_info: SortDTO |
193198
raise InvalidArgumentError(argument_name="sort_info.order")
194199

195200

196-
class BaseSQLAlchemyAdapter(
201+
class BaseSQLAlchemyAdapter[ConfigT: SQLAlchemyConfig](
197202
SQLAlchemyPort,
198203
SQLAlchemyPaginationMixin,
199204
SQLAlchemySortMixin,
@@ -209,16 +214,17 @@ class BaseSQLAlchemyAdapter(
209214
orm_config: Configuration for SQLAlchemy. If None, uses global config.
210215
"""
211216

212-
def __init__(self, orm_config: SQLAlchemyConfig | None = None) -> None:
217+
def __init__(self, orm_config: ConfigT | None = None) -> None:
213218
"""Initialize the base adapter with a session manager.
214219
215220
Args:
216221
orm_config: Configuration for SQLAlchemy. If None, uses global config.
217222
"""
218223
configs = BaseConfig.global_config().SQLALCHEMY if orm_config is None else orm_config
219-
self.session_manager: BaseSQLAlchemySessionManager = self._create_session_manager(configs)
224+
# Cast to ConfigT since subclasses will ensure the proper type
225+
self.session_manager: BaseSQLAlchemySessionManager[ConfigT] = self._create_session_manager(configs) # type: ignore[arg-type]
220226

221-
def _create_session_manager(self, configs: SQLAlchemyConfig) -> BaseSQLAlchemySessionManager:
227+
def _create_session_manager(self, configs: ConfigT) -> BaseSQLAlchemySessionManager[ConfigT]:
222228
"""Create a session manager for the specific database.
223229
224230
Args:
@@ -227,7 +233,7 @@ def _create_session_manager(self, configs: SQLAlchemyConfig) -> BaseSQLAlchemySe
227233
Returns:
228234
A session manager instance.
229235
"""
230-
return BaseSQLAlchemySessionManager(configs)
236+
raise NotImplementedError("Subclasses must implement _create_session_manager")
231237

232238
@override
233239
def execute_search_query(
@@ -263,15 +269,16 @@ def execute_search_query(
263269
paginated_query = self._apply_pagination(sorted_query, pagination)
264270
result_set = session.execute(paginated_query)
265271
if has_multiple_entities:
266-
results = result_set.fetchall()
272+
results = list(result_set.fetchall())
267273
else:
268-
results = result_set.scalars().all()
274+
results = list(result_set.scalars().all())
269275
count_query = select(func.count()).select_from(query.subquery())
270276
total_count = session.execute(count_query).scalar_one()
271277
except Exception as e:
272278
self._handle_db_exception(e, self.session_manager._get_database_name())
279+
raise # This will never be reached, but satisfies MyPy
273280
else:
274-
return results, total_count
281+
return results, total_count # type: ignore[return-value]
275282

276283
@override
277284
def get_session(self) -> Session:
@@ -317,6 +324,7 @@ def create(self, entity: T) -> T | None:
317324
session.flush()
318325
except Exception as e:
319326
self._handle_db_exception(e, self.session_manager._get_database_name())
327+
raise # This will never be reached, but satisfies MyPy
320328
else:
321329
return entity
322330

@@ -351,6 +359,7 @@ def bulk_create(self, entities: list[T]) -> list[T] | None:
351359
session.flush()
352360
except Exception as e:
353361
self._handle_db_exception(e, self.session_manager._get_database_name())
362+
raise # This will never be reached, but satisfies MyPy
354363
else:
355364
return entities
356365

@@ -384,6 +393,7 @@ def get_by_uuid(self, entity_type: type[T], entity_uuid: UUID) -> T | None:
384393
result = session.get(entity_type, entity_uuid)
385394
except Exception as e:
386395
self._handle_db_exception(e, self.session_manager._get_database_name())
396+
raise # This will never be reached, but satisfies MyPy
387397
else:
388398
return result
389399

@@ -415,8 +425,6 @@ def delete(self, entity: BaseEntity) -> None:
415425
session.flush()
416426
except Exception as e:
417427
self._handle_db_exception(e, self.session_manager._get_database_name())
418-
else:
419-
return ...
420428

421429
@override
422430
def bulk_delete(self, entities: list[BaseEntity]) -> None:
@@ -447,8 +455,6 @@ def bulk_delete(self, entities: list[BaseEntity]) -> None:
447455
session.flush()
448456
except Exception as e:
449457
self._handle_db_exception(e, self.session_manager._get_database_name())
450-
else:
451-
return ...
452458

453459
@override
454460
def execute(self, statement: Executable, params: AnyExecuteParams | None = None) -> Result[Any]:
@@ -472,6 +478,7 @@ def execute(self, statement: Executable, params: AnyExecuteParams | None = None)
472478
result = session.execute(statement, params or {})
473479
except Exception as e:
474480
self._handle_db_exception(e, self.session_manager._get_database_name())
481+
raise # This will never be reached, but satisfies MyPy
475482
else:
476483
return result
477484

@@ -497,11 +504,12 @@ def scalars(self, statement: Executable, params: AnyExecuteParams | None = None)
497504
result = session.scalars(statement, params or {})
498505
except Exception as e:
499506
self._handle_db_exception(e, self.session_manager._get_database_name())
507+
raise # This will never be reached, but satisfies MyPy
500508
else:
501509
return result
502510

503511

504-
class AsyncBaseSQLAlchemyAdapter(
512+
class AsyncBaseSQLAlchemyAdapter[ConfigT: SQLAlchemyConfig](
505513
AsyncSQLAlchemyPort,
506514
SQLAlchemyPaginationMixin,
507515
SQLAlchemySortMixin,
@@ -517,16 +525,17 @@ class AsyncBaseSQLAlchemyAdapter(
517525
orm_config: Configuration for SQLAlchemy. If None, uses global config.
518526
"""
519527

520-
def __init__(self, orm_config: SQLAlchemyConfig | None = None) -> None:
528+
def __init__(self, orm_config: ConfigT | None = None) -> None:
521529
"""Initialize the base async adapter with a session manager.
522530
523531
Args:
524532
orm_config: Configuration for SQLAlchemy. If None, uses global config.
525533
"""
526534
configs = BaseConfig.global_config().SQLALCHEMY if orm_config is None else orm_config
527-
self.session_manager: AsyncBaseSQLAlchemySessionManager = self._create_async_session_manager(configs)
535+
# Cast to ConfigT since subclasses will ensure the proper type
536+
self.session_manager: AsyncBaseSQLAlchemySessionManager[ConfigT] = self._create_async_session_manager(configs) # type: ignore[arg-type]
528537

529-
def _create_async_session_manager(self, configs: SQLAlchemyConfig) -> AsyncBaseSQLAlchemySessionManager:
538+
def _create_async_session_manager(self, configs: ConfigT) -> AsyncBaseSQLAlchemySessionManager[ConfigT]:
530539
"""Create an async session manager for the specific database.
531540
532541
Args:
@@ -535,7 +544,7 @@ def _create_async_session_manager(self, configs: SQLAlchemyConfig) -> AsyncBaseS
535544
Returns:
536545
An async session manager instance.
537546
"""
538-
return AsyncBaseSQLAlchemySessionManager(configs)
547+
raise NotImplementedError("Subclasses must implement _create_async_session_manager")
539548

540549
@override
541550
async def execute_search_query(
@@ -571,16 +580,17 @@ async def execute_search_query(
571580
paginated_query = self._apply_pagination(sorted_query, pagination)
572581
result_set = await session.execute(paginated_query)
573582
if has_multiple_entities:
574-
results = result_set.fetchall()
583+
results = list(result_set.fetchall())
575584
else:
576-
results = result_set.scalars().all()
585+
results = list(result_set.scalars().all())
577586
count_query = select(func.count()).select_from(query.subquery())
578-
total_count = await session.execute(count_query)
579-
total_count = total_count.scalar_one()
587+
total_count_result = await session.execute(count_query)
588+
total_count = total_count_result.scalar_one()
580589
except Exception as e:
581590
self._handle_db_exception(e, self.session_manager._get_database_name())
591+
raise # This will never be reached, but satisfies MyPy
582592
else:
583-
return results, total_count
593+
return results, total_count # type: ignore[return-value]
584594

585595
@override
586596
def get_session(self) -> AsyncSession:
@@ -626,6 +636,7 @@ async def create(self, entity: T) -> T | None:
626636
await session.flush()
627637
except Exception as e:
628638
self._handle_db_exception(e, self.session_manager._get_database_name())
639+
raise # This will never be reached, but satisfies MyPy
629640
else:
630641
return entity
631642

@@ -660,6 +671,7 @@ async def bulk_create(self, entities: list[T]) -> list[T] | None:
660671
await session.flush()
661672
except Exception as e:
662673
self._handle_db_exception(e, self.session_manager._get_database_name())
674+
raise # This will never be reached, but satisfies MyPy
663675
else:
664676
return entities
665677

@@ -693,6 +705,7 @@ async def get_by_uuid(self, entity_type: type[T], entity_uuid: UUID) -> T | None
693705
result = await session.get(entity_type, entity_uuid)
694706
except Exception as e:
695707
self._handle_db_exception(e, self.session_manager._get_database_name())
708+
raise # This will never be reached, but satisfies MyPy
696709
else:
697710
return result
698711

@@ -724,8 +737,6 @@ async def delete(self, entity: BaseEntity) -> None:
724737
await session.flush()
725738
except Exception as e:
726739
self._handle_db_exception(e, self.session_manager._get_database_name())
727-
else:
728-
return ...
729740

730741
@override
731742
async def bulk_delete(self, entities: list[BaseEntity]) -> None:
@@ -756,8 +767,6 @@ async def bulk_delete(self, entities: list[BaseEntity]) -> None:
756767
await session.flush()
757768
except Exception as e:
758769
self._handle_db_exception(e, self.session_manager._get_database_name())
759-
else:
760-
return ...
761770

762771
@override
763772
async def execute(self, statement: Executable, params: AnyExecuteParams | None = None) -> Result[Any]:
@@ -781,6 +790,7 @@ async def execute(self, statement: Executable, params: AnyExecuteParams | None =
781790
result = await session.execute(statement, params or {})
782791
except Exception as e:
783792
self._handle_db_exception(e, self.session_manager._get_database_name())
793+
raise # This will never be reached, but satisfies MyPy
784794
else:
785795
return result
786796

@@ -806,5 +816,6 @@ async def scalars(self, statement: Executable, params: AnyExecuteParams | None =
806816
result = await session.scalars(statement, params or {})
807817
except Exception as e:
808818
self._handle_db_exception(e, self.session_manager._get_database_name())
819+
raise # This will never be reached, but satisfies MyPy
809820
else:
810821
return result

archipy/adapters/base/sqlalchemy/ports.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any
44
from uuid import UUID
55

6-
from sqlalchemy import Executable, Select
6+
from sqlalchemy import Executable, Result, ScalarResult, Select
77
from sqlalchemy.ext.asyncio import AsyncSession
88
from sqlalchemy.orm import Session
99

@@ -115,7 +115,7 @@ def bulk_delete(self, entities: list[BaseEntity]) -> None:
115115
raise NotImplementedError
116116

117117
@abstractmethod
118-
def execute(self, statement: Executable, params: AnyExecuteParams | None = None) -> Any:
118+
def execute(self, statement: Executable, params: AnyExecuteParams | None = None) -> Result[Any]:
119119
"""Executes a raw SQL statement.
120120
121121
Args:
@@ -128,7 +128,7 @@ def execute(self, statement: Executable, params: AnyExecuteParams | None = None)
128128
raise NotImplementedError
129129

130130
@abstractmethod
131-
def scalars(self, statement: Executable, params: AnyExecuteParams | None = None) -> Any:
131+
def scalars(self, statement: Executable, params: AnyExecuteParams | None = None) -> ScalarResult[Any]:
132132
"""Executes a statement and returns the scalar result.
133133
134134
Args:
@@ -240,7 +240,7 @@ async def bulk_delete(self, entities: list[BaseEntity]) -> None:
240240
raise NotImplementedError
241241

242242
@abstractmethod
243-
async def execute(self, statement: Executable, params: AnyExecuteParams | None = None) -> Any:
243+
async def execute(self, statement: Executable, params: AnyExecuteParams | None = None) -> Result[Any]:
244244
"""Executes a raw SQL statement asynchronously.
245245
246246
Args:
@@ -253,7 +253,7 @@ async def execute(self, statement: Executable, params: AnyExecuteParams | None =
253253
raise NotImplementedError
254254

255255
@abstractmethod
256-
async def scalars(self, statement: Executable, params: AnyExecuteParams | None = None) -> Any:
256+
async def scalars(self, statement: Executable, params: AnyExecuteParams | None = None) -> ScalarResult[Any]:
257257
"""Executes a statement and returns the scalar result asynchronously.
258258
259259
Args:

archipy/adapters/base/sqlalchemy/session_manager_registry.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING
1+
from typing import TYPE_CHECKING, ClassVar
22

33
from archipy.models.errors import (
44
InternalError,
@@ -25,8 +25,8 @@ class SessionManagerRegistry:
2525
>>> session = sync_manager.get_session()
2626
"""
2727

28-
_sync_instance = None
29-
_async_instance = None
28+
_sync_instance: ClassVar["SessionManagerPort | None"] = None
29+
_async_instance: ClassVar["AsyncSessionManagerPort | None"] = None
3030

3131
@classmethod
3232
def get_sync_manager(cls) -> "SessionManagerPort":

0 commit comments

Comments
 (0)