Skip to content

Commit c6af844

Browse files
committed
Fix dependency tests
1 parent 7135e25 commit c6af844

3 files changed

Lines changed: 11 additions & 4 deletions

File tree

src/routers/dependencies.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ async def userdb_connection() -> AsyncGenerator[AsyncConnection, None]:
2626
async def fetch_user(
2727
api_key: APIKey | None = None,
2828
user_data: Annotated[AsyncConnection | None, Depends(userdb_connection)] = None,
29-
) -> AsyncGenerator[User | None]:
29+
) -> AsyncGenerator[User | None, None]:
3030
if not (api_key and user_data):
3131
yield None
3232
return

src/routers/openml/study.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ async def attach_to_study(
103103
"User {user_id} attached entities to study {study_id}.",
104104
study_id=study_id,
105105
entity_ids=entity_ids,
106+
user_id=user.user_id,
106107
)
107108
return AttachDetachResponse(study_id=study_id, main_entity_type=study.type_)
108109

@@ -138,6 +139,7 @@ async def create_study(
138139
logger.info(
139140
"User {user_id} created study {study_id}.",
140141
study_id=study_id,
142+
user_id=user.user_id,
141143
)
142144
# Make sure that invalid fields raise an error (e.g., "task_ids")
143145
return {"study_id": study_id}

tests/dependencies/fetch_user_test.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from contextlib import aclosing
2+
13
import pytest
24
from sqlalchemy.ext.asyncio import AsyncConnection
35

@@ -16,19 +18,22 @@
1618
],
1719
)
1820
async def test_fetch_user(api_key: str, user: User, user_test: AsyncConnection) -> None:
19-
db_user = await fetch_user(api_key, user_data=user_test)
21+
async with aclosing(fetch_user(api_key, user_data=user_test)) as agen:
22+
db_user = await anext(agen)
2023
assert isinstance(db_user, User)
2124
assert user.user_id == db_user.user_id
2225
assert set(await user.get_groups()) == set(await db_user.get_groups())
2326

2427

2528
async def test_fetch_user_no_key_no_user() -> None:
26-
assert await fetch_user(api_key=None) is None
29+
async with aclosing(fetch_user(api_key=None)) as agen:
30+
assert await anext(agen) is None
2731

2832

2933
async def test_fetch_user_invalid_key_raises(user_test: AsyncConnection) -> None:
3034
with pytest.raises(AuthenticationFailedError):
31-
await fetch_user(api_key=ApiKey.INVALID, user_data=user_test)
35+
async with aclosing(fetch_user(api_key=ApiKey.INVALID, user_data=user_test)) as agen:
36+
await anext(agen)
3237

3338

3439
async def test_fetch_user_or_raise_raises_if_no_user() -> None:

0 commit comments

Comments
 (0)