Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit cc3246a

Browse files
goteguruMészáros Gergelyaminalaee
authored
fix concurrency of parallel transactions (#328)
Co-authored-by: Mészáros Gergely <monk@geotronic.hu> Co-authored-by: Amin Alaee <mohammadamin.alaee@gmail.com>
1 parent a45d42f commit cc3246a

4 files changed

Lines changed: 56 additions & 11 deletions

File tree

databases/core.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import sys
66
import typing
77
from types import TracebackType
8-
from urllib.parse import SplitResult, parse_qsl, urlsplit, unquote
8+
from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit
99

1010
from sqlalchemy import text
1111
from sqlalchemy.sql import ClauseElement
@@ -14,9 +14,9 @@
1414
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend
1515

1616
if sys.version_info >= (3, 7): # pragma: no cover
17-
from contextvars import ContextVar
17+
import contextvars as contextvars
1818
else: # pragma: no cover
19-
from aiocontextvars import ContextVar
19+
import aiocontextvars as contextvars
2020

2121
try: # pragma: no cover
2222
import click
@@ -68,7 +68,9 @@ def __init__(
6868
self._backend = backend_cls(self.url, **self.options)
6969

7070
# Connections are stored as task-local state.
71-
self._connection_context = ContextVar("connection_context") # type: ContextVar
71+
self._connection_context = contextvars.ContextVar(
72+
"connection_context"
73+
) # type: contextvars.ContextVar
7274

7375
# When `force_rollback=True` is used, we use a single global
7476
# connection, within a transaction that always rolls back.
@@ -173,21 +175,35 @@ async def iterate(
173175
async for record in connection.iterate(query, values):
174176
yield record
175177

178+
def _new_connection(self) -> "Connection":
179+
connection = Connection(self._backend)
180+
self._connection_context.set(connection)
181+
return connection
182+
176183
def connection(self) -> "Connection":
177184
if self._global_connection is not None:
178185
return self._global_connection
179186

180187
try:
181188
return self._connection_context.get()
182189
except LookupError:
183-
connection = Connection(self._backend)
184-
self._connection_context.set(connection)
185-
return connection
190+
return self._new_connection()
186191

187192
def transaction(
188193
self, *, force_rollback: bool = False, **kwargs: typing.Any
189194
) -> "Transaction":
190-
return Transaction(self.connection, force_rollback=force_rollback, **kwargs)
195+
try:
196+
connection = self._connection_context.get()
197+
is_root = not connection._transaction_stack
198+
if is_root:
199+
newcontext = contextvars.copy_context()
200+
get_conn = lambda: newcontext.run(self._new_connection)
201+
else:
202+
get_conn = self.connection
203+
except LookupError:
204+
get_conn = self.connection
205+
206+
return Transaction(get_conn, force_rollback=force_rollback, **kwargs)
191207

192208
@contextlib.contextmanager
193209
def force_rollback(self) -> typing.Iterator[None]:

docs/connections_and_transactions.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,21 @@ database = Database('postgresql://localhost/example', ssl=True, min_size=5, max_
5959

6060
## Transactions
6161

62-
Transactions are managed by async context blocks:
62+
Transactions are managed by async context blocks.
63+
64+
A transaction can be acquired from the database connection pool:
6365

6466
```python
6567
async with database.transaction():
6668
...
6769
```
70+
It can also be acquired from a specific database connection:
71+
72+
```python
73+
async with database.connection() as connection:
74+
async with connection.transaction():
75+
...
76+
```
6877

6978
For a lower-level transaction API:
7079

tests/test_database_url.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
from databases import DatabaseURL
21
from urllib.parse import quote
2+
33
import pytest
44

5+
from databases import DatabaseURL
6+
57

68
def test_database_url_repr():
79
u = DatabaseURL("postgresql://localhost/name")

tests/test_databases.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,7 @@ async def test_queries_with_expose_backend_connection(database_url):
800800
"""
801801
async with Database(database_url) as database:
802802
async with database.connection() as connection:
803-
async with database.transaction(force_rollback=True):
803+
async with connection.transaction(force_rollback=True):
804804
# Get the raw connection
805805
raw_connection = connection.raw_connection
806806

@@ -996,3 +996,21 @@ async def test_column_names(database_url, select_query):
996996
assert sorted(results[0].keys()) == ["completed", "id", "text"]
997997
assert results[0]["text"] == "example1"
998998
assert results[0]["completed"] == True
999+
1000+
1001+
@pytest.mark.parametrize("database_url", DATABASE_URLS)
1002+
@async_adapter
1003+
async def test_parallel_transactions(database_url):
1004+
"""
1005+
Test parallel transaction execution.
1006+
"""
1007+
1008+
async def test_task(db):
1009+
async with db.transaction():
1010+
await db.fetch_one("SELECT 1")
1011+
1012+
async with Database(database_url) as database:
1013+
await database.fetch_one("SELECT 1")
1014+
1015+
tasks = [test_task(database) for i in range(4)]
1016+
await asyncio.gather(*tasks)

0 commit comments

Comments
 (0)