22
33import datetime
44
5- from sqlalchemy import Connection , text
5+ from sqlalchemy import text
66from sqlalchemy .engine import Row
7+ from sqlalchemy .ext .asyncio import AsyncConnection
78
89from schemas .datasets .openml import Feature
910
1011
11- def get (id_ : int , connection : Connection ) -> Row | None :
12- row = connection .execute (
12+ async def get (id_ : int , connection : AsyncConnection ) -> Row | None :
13+ row = await connection .execute (
1314 text (
1415 """
1516 SELECT *
@@ -22,8 +23,8 @@ def get(id_: int, connection: Connection) -> Row | None:
2223 return row .one_or_none ()
2324
2425
25- def get_file (* , file_id : int , connection : Connection ) -> Row | None :
26- row = connection .execute (
26+ async def get_file (* , file_id : int , connection : AsyncConnection ) -> Row | None :
27+ row = await connection .execute (
2728 text (
2829 """
2930 SELECT *
@@ -36,8 +37,8 @@ def get_file(*, file_id: int, connection: Connection) -> Row | None:
3637 return row .one_or_none ()
3738
3839
39- def get_tags_for (id_ : int , connection : Connection ) -> list [str ]:
40- rows = connection .execute (
40+ async def get_tags_for (id_ : int , connection : AsyncConnection ) -> list [str ]:
41+ row = await connection .execute (
4142 text (
4243 """
4344 SELECT *
@@ -47,11 +48,12 @@ def get_tags_for(id_: int, connection: Connection) -> list[str]:
4748 ),
4849 parameters = {"dataset_id" : id_ },
4950 )
51+ rows = row .all ()
5052 return [row .tag for row in rows ]
5153
5254
53- def tag (id_ : int , tag_ : str , * , user_id : int , connection : Connection ) -> None :
54- connection .execute (
55+ async def tag (id_ : int , tag_ : str , * , user_id : int , connection : AsyncConnection ) -> None :
56+ await connection .execute (
5557 text (
5658 """
5759 INSERT INTO dataset_tag(`id`, `tag`, `uploader`)
@@ -66,12 +68,12 @@ def tag(id_: int, tag_: str, *, user_id: int, connection: Connection) -> None:
6668 )
6769
6870
69- def get_description (
71+ async def get_description (
7072 id_ : int ,
71- connection : Connection ,
73+ connection : AsyncConnection ,
7274) -> Row | None :
7375 """Get the most recent description for the dataset."""
74- row = connection .execute (
76+ row = await connection .execute (
7577 text (
7678 """
7779 SELECT *
@@ -85,9 +87,9 @@ def get_description(
8587 return row .first ()
8688
8789
88- def get_status (id_ : int , connection : Connection ) -> Row | None :
90+ async def get_status (id_ : int , connection : AsyncConnection ) -> Row | None :
8991 """Get most recent status for the dataset."""
90- row = connection .execute (
92+ row = await connection .execute (
9193 text (
9294 """
9395 SELECT *
@@ -101,8 +103,8 @@ def get_status(id_: int, connection: Connection) -> Row | None:
101103 return row .first ()
102104
103105
104- def get_latest_processing_update (dataset_id : int , connection : Connection ) -> Row | None :
105- row = connection .execute (
106+ async def get_latest_processing_update (dataset_id : int , connection : AsyncConnection ) -> Row | None :
107+ row = await connection .execute (
106108 text (
107109 """
108110 SELECT *
@@ -116,8 +118,8 @@ def get_latest_processing_update(dataset_id: int, connection: Connection) -> Row
116118 return row .first ()
117119
118120
119- def get_features (dataset_id : int , connection : Connection ) -> list [Feature ]:
120- rows = connection .execute (
121+ async def get_features (dataset_id : int , connection : AsyncConnection ) -> list [Feature ]:
122+ row = await connection .execute (
121123 text (
122124 """
123125 SELECT `index`,`name`,`data_type`,`is_target`,
@@ -128,11 +130,17 @@ def get_features(dataset_id: int, connection: Connection) -> list[Feature]:
128130 ),
129131 parameters = {"dataset_id" : dataset_id },
130132 )
131- return [Feature (** row , nominal_values = None ) for row in rows .mappings ()]
133+ rows = row .mappings ().all ()
134+ return [Feature (** row , nominal_values = None ) for row in rows ]
132135
133136
134- def get_feature_values (dataset_id : int , * , feature_index : int , connection : Connection ) -> list [str ]:
135- rows = connection .execute (
137+ async def get_feature_values (
138+ dataset_id : int ,
139+ * ,
140+ feature_index : int ,
141+ connection : AsyncConnection ,
142+ ) -> list [str ]:
143+ row = await connection .execute (
136144 text (
137145 """
138146 SELECT `value`
@@ -142,17 +150,18 @@ def get_feature_values(dataset_id: int, *, feature_index: int, connection: Conne
142150 ),
143151 parameters = {"dataset_id" : dataset_id , "feature_index" : feature_index },
144152 )
153+ rows = row .all ()
145154 return [row .value for row in rows ]
146155
147156
148- def update_status (
157+ async def update_status (
149158 dataset_id : int ,
150159 status : str ,
151160 * ,
152161 user_id : int ,
153- connection : Connection ,
162+ connection : AsyncConnection ,
154163) -> None :
155- connection .execute (
164+ await connection .execute (
156165 text (
157166 """
158167 INSERT INTO dataset_status(`did`,`status`,`status_date`,`user_id`)
@@ -168,8 +177,8 @@ def update_status(
168177 )
169178
170179
171- def remove_deactivated_status (dataset_id : int , connection : Connection ) -> None :
172- connection .execute (
180+ async def remove_deactivated_status (dataset_id : int , connection : AsyncConnection ) -> None :
181+ await connection .execute (
173182 text (
174183 """
175184 DELETE FROM dataset_status
0 commit comments