|
5 | 5 | import sys |
6 | 6 | import typing |
7 | 7 | from types import TracebackType |
8 | | -from urllib.parse import SplitResult, parse_qsl, urlsplit, unquote |
| 8 | +from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit |
9 | 9 |
|
10 | 10 | from sqlalchemy import text |
11 | 11 | from sqlalchemy.sql import ClauseElement |
|
14 | 14 | from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend |
15 | 15 |
|
16 | 16 | if sys.version_info >= (3, 7): # pragma: no cover |
17 | | - from contextvars import ContextVar |
| 17 | + import contextvars as contextvars |
18 | 18 | else: # pragma: no cover |
19 | | - from aiocontextvars import ContextVar |
| 19 | + import aiocontextvars as contextvars |
20 | 20 |
|
21 | 21 | try: # pragma: no cover |
22 | 22 | import click |
@@ -68,7 +68,9 @@ def __init__( |
68 | 68 | self._backend = backend_cls(self.url, **self.options) |
69 | 69 |
|
70 | 70 | # Connections are stored as task-local state. |
71 | | - self._connection_context = ContextVar("connection_context") # type: ContextVar |
| 71 | + self._connection_context = contextvars.ContextVar( |
| 72 | + "connection_context" |
| 73 | + ) # type: contextvars.ContextVar |
72 | 74 |
|
73 | 75 | # When `force_rollback=True` is used, we use a single global |
74 | 76 | # connection, within a transaction that always rolls back. |
@@ -173,21 +175,35 @@ async def iterate( |
173 | 175 | async for record in connection.iterate(query, values): |
174 | 176 | yield record |
175 | 177 |
|
| 178 | + def _new_connection(self) -> "Connection": |
| 179 | + connection = Connection(self._backend) |
| 180 | + self._connection_context.set(connection) |
| 181 | + return connection |
| 182 | + |
176 | 183 | def connection(self) -> "Connection": |
177 | 184 | if self._global_connection is not None: |
178 | 185 | return self._global_connection |
179 | 186 |
|
180 | 187 | try: |
181 | 188 | return self._connection_context.get() |
182 | 189 | except LookupError: |
183 | | - connection = Connection(self._backend) |
184 | | - self._connection_context.set(connection) |
185 | | - return connection |
| 190 | + return self._new_connection() |
186 | 191 |
|
187 | 192 | def transaction( |
188 | 193 | self, *, force_rollback: bool = False, **kwargs: typing.Any |
189 | 194 | ) -> "Transaction": |
190 | | - return Transaction(self.connection, force_rollback=force_rollback, **kwargs) |
| 195 | + try: |
| 196 | + connection = self._connection_context.get() |
| 197 | + is_root = not connection._transaction_stack |
| 198 | + if is_root: |
| 199 | + newcontext = contextvars.copy_context() |
| 200 | + get_conn = lambda: newcontext.run(self._new_connection) |
| 201 | + else: |
| 202 | + get_conn = self.connection |
| 203 | + except LookupError: |
| 204 | + get_conn = self.connection |
| 205 | + |
| 206 | + return Transaction(get_conn, force_rollback=force_rollback, **kwargs) |
191 | 207 |
|
192 | 208 | @contextlib.contextmanager |
193 | 209 | def force_rollback(self) -> typing.Iterator[None]: |
|
0 commit comments