|
3 | 3 |
|
4 | 4 | from sqlalchemy import pool |
5 | 5 | from sqlalchemy.engine import Connection |
6 | | -from sqlalchemy.ext.asyncio import async_engine_from_config |
| 6 | +from sqlalchemy.ext.asyncio import async_engine_from_config, AsyncEngine |
| 7 | +from sqlalchemy import engine_from_config |
7 | 8 |
|
8 | 9 | from alembic import context |
9 | 10 |
|
@@ -58,34 +59,43 @@ def run_migrations_offline() -> None: |
58 | 59 |
|
59 | 60 |
|
60 | 61 | def do_run_migrations(connection: Connection) -> None: |
61 | | - context.configure(connection=connection, target_metadata=target_metadata) |
62 | | - |
| 62 | + context.configure( |
| 63 | + connection=connection, |
| 64 | + target_metadata=target_metadata, |
| 65 | + compare_type=True, |
| 66 | + ) |
63 | 67 | with context.begin_transaction(): |
64 | 68 | context.run_migrations() |
65 | 69 |
|
66 | 70 |
|
67 | | -async def run_async_migrations() -> None: |
| 71 | +async def run_async_migrations(connectable) -> None: |
68 | 72 | """In this scenario we need to create an Engine |
69 | 73 | and associate a connection with the context. |
70 | 74 |
|
71 | 75 | """ |
72 | | - |
73 | | - connectable = async_engine_from_config( |
74 | | - config.get_section(config.config_ini_section, {}), |
75 | | - prefix="sqlalchemy.", |
76 | | - poolclass=pool.NullPool, |
77 | | - ) |
78 | | - |
79 | 76 | async with connectable.connect() as connection: |
80 | 77 | await connection.run_sync(do_run_migrations) |
81 | | - |
82 | 78 | await connectable.dispose() |
83 | 79 |
|
84 | 80 |
|
85 | 81 | def run_migrations_online() -> None: |
86 | 82 | """Run migrations in 'online' mode.""" |
87 | | - |
88 | | - asyncio.run(run_async_migrations()) |
| 83 | + connectable = context.config.attributes.get('connection', None) |
| 84 | + if connectable is None: |
| 85 | + connectable = AsyncEngine( |
| 86 | + engine_from_config( |
| 87 | + context.config.get_section( |
| 88 | + context.config.config_ini_section, |
| 89 | + ), |
| 90 | + prefix='sqlalchemy.', |
| 91 | + poolclass=pool.NullPool, |
| 92 | + future=True, |
| 93 | + ) |
| 94 | + ) |
| 95 | + if isinstance(connectable, AsyncEngine): |
| 96 | + asyncio.run(run_async_migrations(connectable)) |
| 97 | + else: |
| 98 | + do_run_migrations(connectable) |
89 | 99 |
|
90 | 100 |
|
91 | 101 | if context.is_offline_mode(): |
|
0 commit comments