Skip to content

Commit 93525bb

Browse files
Merge branch master of github.com:SyntaxArc/ArchiPy
# Conflicts: # poetry.lock
2 parents 1ed10e2 + 61c1c85 commit 93525bb

10 files changed

Lines changed: 456 additions & 847 deletions

File tree

archipy/adapters/base/sqlalchemy/adapters.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import Enum
2-
from typing import Any, override
2+
from typing import Any, override, TypeVar
33
from uuid import UUID
44

55
from sqlalchemy import Delete, Executable, Result, ScalarResult, Update, func, select
@@ -30,6 +30,9 @@
3030
from archipy.models.types.base_types import FilterOperationType
3131
from archipy.models.types.sort_order_type import SortOrderType
3232

33+
# Generic type variable for BaseEntity subclasses
34+
T = TypeVar('T', bound=BaseEntity)
35+
3336

3437
class SQLAlchemyExceptionHandlerMixin:
3538
"""Mixin providing centralized exception handling for SQLAlchemy operations.
@@ -284,14 +287,14 @@ def get_session(self) -> Session:
284287
return self.session_manager.get_session()
285288

286289
@override
287-
def create(self, entity: BaseEntity) -> BaseEntity | None:
290+
def create(self, entity: T) -> T | None:
288291
"""Create a new entity in the database.
289292
290293
Args:
291294
entity: The entity to create.
292295
293296
Returns:
294-
The created entity with updated attributes.
297+
The created entity with updated attributes, preserving the original type.
295298
296299
Raises:
297300
InvalidEntityTypeError: If the entity type is not a valid SQLAlchemy model.
@@ -318,14 +321,14 @@ def create(self, entity: BaseEntity) -> BaseEntity | None:
318321
return entity
319322

320323
@override
321-
def bulk_create(self, entities: list[BaseEntity]) -> list[BaseEntity] | None:
324+
def bulk_create(self, entities: list[T]) -> list[T] | None:
322325
"""Creates multiple entities in a single database operation.
323326
324327
Args:
325328
entities: List of entities to create.
326329
327330
Returns:
328-
List of created entities with updated attributes.
331+
List of created entities with updated attributes, preserving original types.
329332
330333
Raises:
331334
InvalidEntityTypeError: If any entity is not a valid SQLAlchemy model.
@@ -352,7 +355,7 @@ def bulk_create(self, entities: list[BaseEntity]) -> list[BaseEntity] | None:
352355
return entities
353356

354357
@override
355-
def get_by_uuid(self, entity_type: type, entity_uuid: UUID) -> BaseEntity | None:
358+
def get_by_uuid(self, entity_type: type[T], entity_uuid: UUID) -> T | None:
356359
"""Retrieve an entity by its UUID.
357360
358361
Args:
@@ -593,14 +596,14 @@ def get_session(self) -> AsyncSession:
593596
return self.session_manager.get_session()
594597

595598
@override
596-
async def create(self, entity: BaseEntity) -> BaseEntity | None:
599+
async def create(self, entity: T) -> T | None:
597600
"""Create a new entity in the database.
598601
599602
Args:
600603
entity: The entity to create.
601604
602605
Returns:
603-
The created entity with updated attributes.
606+
The created entity with updated attributes, preserving the original type.
604607
605608
Raises:
606609
InvalidEntityTypeError: If the entity type is not a valid SQLAlchemy model.
@@ -627,14 +630,14 @@ async def create(self, entity: BaseEntity) -> BaseEntity | None:
627630
return entity
628631

629632
@override
630-
async def bulk_create(self, entities: list[BaseEntity]) -> list[BaseEntity] | None:
633+
async def bulk_create(self, entities: list[T]) -> list[T] | None:
631634
"""Creates multiple entities in a single database operation.
632635
633636
Args:
634637
entities: List of entities to create.
635638
636639
Returns:
637-
List of created entities with updated attributes.
640+
List of created entities with updated attributes, preserving original types.
638641
639642
Raises:
640643
InvalidEntityTypeError: If any entity is not a valid SQLAlchemy model.
@@ -661,7 +664,7 @@ async def bulk_create(self, entities: list[BaseEntity]) -> list[BaseEntity] | No
661664
return entities
662665

663666
@override
664-
async def get_by_uuid(self, entity_type: type, entity_uuid: UUID) -> Any | None:
667+
async def get_by_uuid(self, entity_type: type[T], entity_uuid: UUID) -> T | None:
665668
"""Retrieve an entity by its UUID.
666669
667670
Args:

features/environment.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,16 +65,6 @@ def before_scenario(context: Context, scenario: Scenario):
6565
except Exception as e:
6666
logger.exception(f"Error setting test config: {e}")
6767

68-
# Set up async test environment if needed
69-
if "async" in scenario.name.lower() or any("async" in tag.lower() for tag in scenario.tags):
70-
logger.info("Setting up async test environment")
71-
try:
72-
# Create a new event loop for this scenario
73-
loop = asyncio.new_event_loop()
74-
asyncio.set_event_loop(loop)
75-
scenario_context.store("_async_test_loop", loop)
76-
except Exception as e:
77-
logger.exception(f"Error setting up async environment: {e}")
7868

7969

8070
def after_scenario(context: Context, scenario: Scenario):

features/scenario_context.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -43,30 +43,14 @@ def cleanup(self):
4343
# Clean up async adapter
4444
if self.async_adapter:
4545
try:
46-
if hasattr(self.async_adapter, "session_manager") and hasattr(
47-
self.async_adapter.session_manager,
48-
"engine",
49-
):
50-
# For async, we need to get a loop and run the coroutine
51-
loop = asyncio.get_event_loop()
52-
if not loop.is_closed():
53-
try:
54-
# Run the session removal coroutine
55-
loop.run_until_complete(self.async_adapter.session_manager.remove_session())
56-
57-
# Run the engine disposal coroutine
58-
loop.run_until_complete(self.async_adapter.session_manager.engine.dispose())
59-
except Exception as e:
60-
print(f"Error in async cleanup: {e}")
61-
else:
62-
# If the loop is closed, create a new one temporarily
63-
temp_loop = asyncio.new_event_loop()
64-
try:
65-
asyncio.set_event_loop(temp_loop)
66-
temp_loop.run_until_complete(self.async_adapter.session_manager.remove_session())
67-
temp_loop.run_until_complete(self.async_adapter.session_manager.engine.dispose())
68-
finally:
69-
temp_loop.close()
46+
# Try to run async cleanup if we're in an async context
47+
try:
48+
loop = asyncio.get_running_loop()
49+
# If we have a running loop, create a task
50+
asyncio.create_task(self.async_cleanup())
51+
except RuntimeError:
52+
# No running loop, run in new loop
53+
asyncio.run(self.async_cleanup())
7054
except Exception as e:
7155
print(f"Error in async cleanup: {e}")
7256

@@ -80,3 +64,16 @@ def cleanup(self):
8064
os.remove(self.db_file)
8165
except Exception as e:
8266
print(f"Error removing database file: {e}")
67+
68+
async def async_cleanup(self):
69+
"""Clean up async resources associated with this scenario."""
70+
if self.async_adapter:
71+
try:
72+
if hasattr(self.async_adapter, "session_manager") and hasattr(
73+
self.async_adapter.session_manager, "engine"
74+
):
75+
# Clean up async sessions and engine
76+
await self.async_adapter.session_manager.remove_session()
77+
await self.async_adapter.session_manager.engine.dispose()
78+
except Exception as e:
79+
print(f"Error in async cleanup: {e}")

features/steps/atomic_transaction_steps.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
This module contains step definitions for both synchronous and asynchronous
44
atomic transaction scenarios.
55
"""
6-
6+
import asyncio
77
import logging
88
import os
99
import tempfile
@@ -16,12 +16,10 @@
1616
from features.test_entity import RelatedTestEntity, TestAdminEntity, TestEntity, TestManagerEntity
1717
from features.test_entity_factory import TestEntityFactory
1818
from features.test_helpers import (
19-
SafeAsyncContextManager,
2019
async_schema_setup,
2120
get_adapter,
2221
get_async_adapter,
2322
get_current_scenario_context,
24-
safe_run_async,
2523
)
2624
from sqlalchemy import select
2725

@@ -127,8 +125,7 @@ def step_given_database_initialized(context):
127125

128126
# Create schema with async adapter
129127
logger.info("Creating database schema with async adapter")
130-
with SafeAsyncContextManager(context) as ctx:
131-
ctx.run(async_schema_setup(async_adapter))
128+
asyncio.run(async_schema_setup(async_adapter))
132129

133130
logger.info("Async adapter and schema setup completed")
134131
except Exception as e:
@@ -965,7 +962,6 @@ def verify_consistency():
965962

966963

967964
@when("a new entity is created in an async atomic transaction")
968-
@safe_run_async
969965
async def step_when_entity_created_in_async_atomic(context):
970966
"""Create a new entity within an async atomic transaction."""
971967
logger = getattr(context, "logger", logging.getLogger("behave.steps"))
@@ -993,7 +989,6 @@ async def create_entity_async_atomic():
993989

994990

995991
@then("the async entity should be retrievable")
996-
@safe_run_async
997992
async def step_then_async_entity_should_be_retrievable(context):
998993
"""Verify the entity exists after async atomic transaction."""
999994
logger = getattr(context, "logger", logging.getLogger("behave.steps"))
@@ -1016,7 +1011,6 @@ async def get_entity():
10161011

10171012

10181013
@when("a new async entity creation fails within an atomic transaction")
1019-
@safe_run_async
10201014
async def step_when_async_entity_creation_fails(context):
10211015
"""Attempt to create an async entity with a failure that causes rollback."""
10221016
logger = getattr(context, "logger", logging.getLogger("behave.steps"))
@@ -1053,7 +1047,6 @@ async def create_entity_with_failure():
10531047

10541048

10551049
@then("no async entity should exist in the database")
1056-
@safe_run_async
10571050
async def step_then_no_async_entity_should_exist(context):
10581051
"""Verify the entity doesn't exist after failed async atomic transaction."""
10591052
logger = getattr(context, "logger", logging.getLogger("behave.steps"))
@@ -1075,7 +1068,6 @@ async def check_entity_absence():
10751068

10761069

10771070
@then("the async database session should remain usable")
1078-
@safe_run_async
10791071
async def step_then_async_session_should_remain_usable(context):
10801072
"""Verify the async session is still usable after a failed transaction."""
10811073
logger = getattr(context, "logger", logging.getLogger("behave.steps"))
@@ -1107,7 +1099,6 @@ async def verify_session_usable():
11071099

11081100

11091101
@when("multiple entities are created in an async atomic transaction")
1110-
@safe_run_async
11111102
async def step_when_multiple_async_entities_created(context):
11121103
"""Create multiple entities in a single async atomic transaction."""
11131104
logger = getattr(context, "logger", logging.getLogger("behave.steps"))
@@ -1141,7 +1132,6 @@ async def create_multiple_entities():
11411132

11421133

11431134
@then("all async entities should be retrievable")
1144-
@safe_run_async
11451135
async def step_then_all_async_entities_retrievable(context):
11461136
"""Verify all entities exist after async atomic transaction."""
11471137
logger = getattr(context, "logger", logging.getLogger("behave.steps"))
@@ -1177,7 +1167,6 @@ async def verify_entities():
11771167

11781168

11791169
@when("complex async operations are performed in a transaction")
1180-
@safe_run_async
11811170
async def step_when_complex_async_operations(context):
11821171
"""Demonstrate more complex async operations with proper session management."""
11831172
logger = getattr(context, "logger", logging.getLogger("behave.steps"))
@@ -1227,7 +1216,6 @@ async def create_entity_with_relations():
12271216

12281217

12291218
@then("all related entities should be accessible")
1230-
@safe_run_async
12311219
async def step_then_related_entities_accessible(context):
12321220
"""Verify that related entities can be accessed through relationships."""
12331221
logger = getattr(context, "logger", logging.getLogger("behave.steps"))

0 commit comments

Comments
 (0)