Skip to content

Commit 1dcfbde

Browse files
rohansen856pre-commit-ci[bot]PGijsbers
authored
[ENH] use async mysql db driver and httpclients (#243)
Solves: #229 Replacing the current implementation of `mysqlclient` with `aiomysql` for async db connection support. would need to replace the db drivers under `database/` and `routes/` with async implementation and convert the functions to `async def` for async support --------- Signed-off-by: rohansen856 <rohansen856@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Pieter Gijsbers <p.gijsbers@tue.nl>
1 parent 33e86bc commit 1dcfbde

39 files changed

Lines changed: 834 additions & 609 deletions

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@ classifiers = [
1414
]
1515
dependencies = [
1616
"fastapi",
17+
"cryptography",
1718
"pydantic",
1819
"uvicorn",
1920
"sqlalchemy",
2021
"mysqlclient",
22+
"aiomysql",
2123
"python_dotenv",
2224
"xmltodict",
2325
]
@@ -28,6 +30,7 @@ dev = [
2830
"pre-commit",
2931
"pytest",
3032
"pytest-mock",
33+
"pytest-asyncio",
3134
"httpx",
3235
"hypothesis",
3336
"deepdiff",
@@ -110,6 +113,9 @@ plugins = [
110113
pythonpath = [
111114
"src"
112115
]
116+
asyncio_mode = "auto"
117+
asyncio_default_fixture_loop_scope = "session"
118+
asyncio_default_test_loop_scope = "session"
113119
markers = [
114120
"slow: test or sets of tests which take more than a few seconds to run.",
115121
# While the `mut`ation marker below is not strictly necessary as every change is

src/config.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ root_path=""
1111
host="openml-test-database"
1212
port="3306"
1313
# SQLAlchemy `dialect` and `driver`: https://docs.sqlalchemy.org/en/20/dialects/index.html
14-
drivername="mysql"
14+
drivername="mysql+aiomysql"
1515

1616
[databases.expdb]
1717
database="openml_expdb"

src/core/access.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
1+
from typing import Any
2+
13
from sqlalchemy.engine import Row
24

35
from database.users import User, UserGroup
46
from schemas.datasets.openml import Visibility
57

68

7-
def _user_has_access(
8-
dataset: Row,
9+
async def _user_has_access(
10+
dataset: Row[Any],
911
user: User | None = None,
1012
) -> bool:
1113
"""Determine if `user` has the right to view `dataset`."""
12-
is_public = dataset.visibility == Visibility.PUBLIC
13-
return is_public or (
14-
user is not None and (user.user_id == dataset.uploader or UserGroup.ADMIN in user.groups)
15-
)
14+
if dataset.visibility == Visibility.PUBLIC:
15+
return True
16+
if user is None:
17+
return False
18+
if user.user_id == dataset.uploader:
19+
return True
20+
user_groups = await user.get_groups()
21+
return UserGroup.ADMIN in user_groups

src/database/datasets.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
import datetime
44

5-
from sqlalchemy import Connection, text
5+
from sqlalchemy import text
66
from sqlalchemy.engine import Row
7+
from sqlalchemy.ext.asyncio import AsyncConnection
78

89
from schemas.datasets.openml import Feature
910

1011

11-
def get(id_: int, connection: Connection) -> Row | None:
12-
row = connection.execute(
12+
async def get(id_: int, connection: AsyncConnection) -> Row | None:
13+
row = await connection.execute(
1314
text(
1415
"""
1516
SELECT *
@@ -22,8 +23,8 @@ def get(id_: int, connection: Connection) -> Row | None:
2223
return row.one_or_none()
2324

2425

25-
def get_file(*, file_id: int, connection: Connection) -> Row | None:
26-
row = connection.execute(
26+
async def get_file(*, file_id: int, connection: AsyncConnection) -> Row | None:
27+
row = await connection.execute(
2728
text(
2829
"""
2930
SELECT *
@@ -36,8 +37,8 @@ def get_file(*, file_id: int, connection: Connection) -> Row | None:
3637
return row.one_or_none()
3738

3839

39-
def get_tags_for(id_: int, connection: Connection) -> list[str]:
40-
rows = connection.execute(
40+
async def get_tags_for(id_: int, connection: AsyncConnection) -> list[str]:
41+
row = await connection.execute(
4142
text(
4243
"""
4344
SELECT *
@@ -47,11 +48,12 @@ def get_tags_for(id_: int, connection: Connection) -> list[str]:
4748
),
4849
parameters={"dataset_id": id_},
4950
)
51+
rows = row.all()
5052
return [row.tag for row in rows]
5153

5254

53-
def tag(id_: int, tag_: str, *, user_id: int, connection: Connection) -> None:
54-
connection.execute(
55+
async def tag(id_: int, tag_: str, *, user_id: int, connection: AsyncConnection) -> None:
56+
await connection.execute(
5557
text(
5658
"""
5759
INSERT INTO dataset_tag(`id`, `tag`, `uploader`)
@@ -66,12 +68,12 @@ def tag(id_: int, tag_: str, *, user_id: int, connection: Connection) -> None:
6668
)
6769

6870

69-
def get_description(
71+
async def get_description(
7072
id_: int,
71-
connection: Connection,
73+
connection: AsyncConnection,
7274
) -> Row | None:
7375
"""Get the most recent description for the dataset."""
74-
row = connection.execute(
76+
row = await connection.execute(
7577
text(
7678
"""
7779
SELECT *
@@ -85,9 +87,9 @@ def get_description(
8587
return row.first()
8688

8789

88-
def get_status(id_: int, connection: Connection) -> Row | None:
90+
async def get_status(id_: int, connection: AsyncConnection) -> Row | None:
8991
"""Get most recent status for the dataset."""
90-
row = connection.execute(
92+
row = await connection.execute(
9193
text(
9294
"""
9395
SELECT *
@@ -101,8 +103,8 @@ def get_status(id_: int, connection: Connection) -> Row | None:
101103
return row.first()
102104

103105

104-
def get_latest_processing_update(dataset_id: int, connection: Connection) -> Row | None:
105-
row = connection.execute(
106+
async def get_latest_processing_update(dataset_id: int, connection: AsyncConnection) -> Row | None:
107+
row = await connection.execute(
106108
text(
107109
"""
108110
SELECT *
@@ -116,8 +118,8 @@ def get_latest_processing_update(dataset_id: int, connection: Connection) -> Row
116118
return row.first()
117119

118120

119-
def get_features(dataset_id: int, connection: Connection) -> list[Feature]:
120-
rows = connection.execute(
121+
async def get_features(dataset_id: int, connection: AsyncConnection) -> list[Feature]:
122+
row = await connection.execute(
121123
text(
122124
"""
123125
SELECT `index`,`name`,`data_type`,`is_target`,
@@ -128,11 +130,17 @@ def get_features(dataset_id: int, connection: Connection) -> list[Feature]:
128130
),
129131
parameters={"dataset_id": dataset_id},
130132
)
131-
return [Feature(**row, nominal_values=None) for row in rows.mappings()]
133+
rows = row.mappings().all()
134+
return [Feature(**row, nominal_values=None) for row in rows]
132135

133136

134-
def get_feature_values(dataset_id: int, *, feature_index: int, connection: Connection) -> list[str]:
135-
rows = connection.execute(
137+
async def get_feature_values(
138+
dataset_id: int,
139+
*,
140+
feature_index: int,
141+
connection: AsyncConnection,
142+
) -> list[str]:
143+
row = await connection.execute(
136144
text(
137145
"""
138146
SELECT `value`
@@ -142,17 +150,18 @@ def get_feature_values(dataset_id: int, *, feature_index: int, connection: Conne
142150
),
143151
parameters={"dataset_id": dataset_id, "feature_index": feature_index},
144152
)
153+
rows = row.all()
145154
return [row.value for row in rows]
146155

147156

148-
def update_status(
157+
async def update_status(
149158
dataset_id: int,
150159
status: str,
151160
*,
152161
user_id: int,
153-
connection: Connection,
162+
connection: AsyncConnection,
154163
) -> None:
155-
connection.execute(
164+
await connection.execute(
156165
text(
157166
"""
158167
INSERT INTO dataset_status(`did`,`status`,`status_date`,`user_id`)
@@ -168,8 +177,8 @@ def update_status(
168177
)
169178

170179

171-
def remove_deactivated_status(dataset_id: int, connection: Connection) -> None:
172-
connection.execute(
180+
async def remove_deactivated_status(dataset_id: int, connection: AsyncConnection) -> None:
181+
await connection.execute(
173182
text(
174183
"""
175184
DELETE FROM dataset_status

src/database/evaluations.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,32 @@
11
from collections.abc import Sequence
22
from typing import cast
33

4-
from sqlalchemy import Connection, Row, text
4+
from sqlalchemy import Row, text
5+
from sqlalchemy.ext.asyncio import AsyncConnection
56

67
from core.formatting import _str_to_bool
78
from schemas.datasets.openml import EstimationProcedure
89

910

10-
def get_math_functions(function_type: str, connection: Connection) -> Sequence[Row]:
11-
return cast(
12-
"Sequence[Row]",
13-
connection.execute(
14-
text(
15-
"""
11+
async def get_math_functions(function_type: str, connection: AsyncConnection) -> Sequence[Row]:
12+
rows = await connection.execute(
13+
text(
14+
"""
1615
SELECT *
1716
FROM math_function
1817
WHERE `functionType` = :function_type
1918
""",
20-
),
21-
parameters={"function_type": function_type},
22-
).all(),
19+
),
20+
parameters={"function_type": function_type},
21+
)
22+
return cast(
23+
"Sequence[Row]",
24+
rows.all(),
2325
)
2426

2527

26-
def get_estimation_procedures(connection: Connection) -> list[EstimationProcedure]:
27-
rows = connection.execute(
28+
async def get_estimation_procedures(connection: AsyncConnection) -> list[EstimationProcedure]:
29+
row = await connection.execute(
2830
text(
2931
"""
3032
SELECT `id` as 'id_', `ttid` as 'task_type_id', `name`, `type` as 'type_',
@@ -33,11 +35,12 @@ def get_estimation_procedures(connection: Connection) -> list[EstimationProcedur
3335
""",
3436
),
3537
)
38+
rows = row.mappings().all()
3639
typed_rows = [
3740
{
3841
k: v if k != "stratified_sampling" or v is None else _str_to_bool(v)
3942
for k, v in row.items()
4043
}
41-
for row in rows.mappings()
44+
for row in rows
4245
]
4346
return [EstimationProcedure(**typed_row) for typed_row in typed_rows]

src/database/flows.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,29 @@
11
from collections.abc import Sequence
22
from typing import cast
33

4-
from sqlalchemy import Connection, Row, text
4+
from sqlalchemy import Row, text
5+
from sqlalchemy.ext.asyncio import AsyncConnection
56

67

7-
def get_subflows(for_flow: int, expdb: Connection) -> Sequence[Row]:
8-
return cast(
9-
"Sequence[Row]",
10-
expdb.execute(
11-
text(
12-
"""
8+
async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]:
9+
rows = await expdb.execute(
10+
text(
11+
"""
1312
SELECT child as child_id, identifier
1413
FROM implementation_component
1514
WHERE parent = :flow_id
1615
""",
17-
),
18-
parameters={"flow_id": for_flow},
1916
),
17+
parameters={"flow_id": for_flow},
18+
)
19+
return cast(
20+
"Sequence[Row]",
21+
rows.all(),
2022
)
2123

2224

23-
def get_tags(flow_id: int, expdb: Connection) -> list[str]:
24-
tag_rows = expdb.execute(
25+
async def get_tags(flow_id: int, expdb: AsyncConnection) -> list[str]:
26+
rows = await expdb.execute(
2527
text(
2628
"""
2729
SELECT tag
@@ -31,28 +33,30 @@ def get_tags(flow_id: int, expdb: Connection) -> list[str]:
3133
),
3234
parameters={"flow_id": flow_id},
3335
)
36+
tag_rows = rows.all()
3437
return [tag.tag for tag in tag_rows]
3538

3639

37-
def get_parameters(flow_id: int, expdb: Connection) -> Sequence[Row]:
38-
return cast(
39-
"Sequence[Row]",
40-
expdb.execute(
41-
text(
42-
"""
40+
async def get_parameters(flow_id: int, expdb: AsyncConnection) -> Sequence[Row]:
41+
rows = await expdb.execute(
42+
text(
43+
"""
4344
SELECT *, defaultValue as default_value, dataType as data_type
4445
FROM input
4546
WHERE implementation_id = :flow_id
4647
""",
47-
),
48-
parameters={"flow_id": flow_id},
4948
),
49+
parameters={"flow_id": flow_id},
50+
)
51+
return cast(
52+
"Sequence[Row]",
53+
rows.all(),
5054
)
5155

5256

53-
def get_by_name(name: str, external_version: str, expdb: Connection) -> Row | None:
57+
async def get_by_name(name: str, external_version: str, expdb: AsyncConnection) -> Row | None:
5458
"""Get flow by name and external version."""
55-
return expdb.execute(
59+
row = await expdb.execute(
5660
text(
5761
"""
5862
SELECT *, uploadDate as upload_date
@@ -61,11 +65,12 @@ def get_by_name(name: str, external_version: str, expdb: Connection) -> Row | No
6165
""",
6266
),
6367
parameters={"name": name, "external_version": external_version},
64-
).one_or_none()
68+
)
69+
return row.one_or_none()
6570

6671

67-
def get(id_: int, expdb: Connection) -> Row | None:
68-
return expdb.execute(
72+
async def get(id_: int, expdb: AsyncConnection) -> Row | None:
73+
row = await expdb.execute(
6974
text(
7075
"""
7176
SELECT *, uploadDate as upload_date
@@ -74,4 +79,5 @@ def get(id_: int, expdb: Connection) -> Row | None:
7479
""",
7580
),
7681
parameters={"flow_id": id_},
77-
).one_or_none()
82+
)
83+
return row.one_or_none()

0 commit comments

Comments
 (0)