Skip to content

Commit f94808c

Browse files
igennovaigen novaigennovaPGijsberspre-commit-ci[bot]
authored
[ENH] Add POST /setup/untag endpoint (#246)
Fixes #65 This PR migrates the `POST /setup/untag` endpoint from the legacy PHP API to the new FastAPI backend. It is the first endpoint migrated under the Setup Epic (#60). --------- Co-authored-by: igen nova <igennova@igens-MacBook-Pro.local> Co-authored-by: igennova <luckynegi025@gmail.com> Co-authored-by: Pieter Gijsbers <p.gijsbers@tue.nl> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1dcfbde commit f94808c

9 files changed

Lines changed: 379 additions & 8 deletions

File tree

src/core/errors.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,24 @@ class TagAlreadyExistsError(ProblemDetailError):
219219
_default_code = 473
220220

221221

222+
class TagNotFoundError(ProblemDetailError):
223+
"""Raised when trying to remove or retrieve a tag that does not exist."""
224+
225+
uri = "https://openml.org/problems/tag-not-found"
226+
title = "Tag Not Found"
227+
_default_status_code = HTTPStatus.NOT_FOUND
228+
_default_code = 475
229+
230+
231+
class TagNotOwnedError(ProblemDetailError):
232+
"""Raised when trying to remove a tag that was created by someone else."""
233+
234+
uri = "https://openml.org/problems/tag-not-owned"
235+
title = "Tag Not Owned"
236+
_default_status_code = HTTPStatus.FORBIDDEN
237+
_default_code = 476
238+
239+
222240
# =============================================================================
223241
# Search/List Errors
224242
# =============================================================================
@@ -329,6 +347,20 @@ class FlowNotFoundError(ProblemDetailError):
329347
_default_status_code = HTTPStatus.NOT_FOUND
330348

331349

350+
# =============================================================================
351+
# Setup Errors
352+
# =============================================================================
353+
354+
355+
class SetupNotFoundError(ProblemDetailError):
356+
"""Raised when a setup cannot be found."""
357+
358+
uri = "https://openml.org/problems/setup-not-found"
359+
title = "Setup Not Found"
360+
_default_status_code = HTTPStatus.NOT_FOUND
361+
_default_code = 472
362+
363+
332364
# =============================================================================
333365
# Service Errors
334366
# =============================================================================

src/database/setups.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""All database operations that directly operate on setups."""
2+
3+
from sqlalchemy import text
4+
from sqlalchemy.engine import Row
5+
from sqlalchemy.ext.asyncio import AsyncConnection
6+
7+
8+
async def get(setup_id: int, connection: AsyncConnection) -> Row | None:
9+
"""Get the setup with id `setup_id` from the database."""
10+
row = await connection.execute(
11+
text(
12+
"""
13+
SELECT *
14+
FROM algorithm_setup
15+
WHERE sid = :setup_id
16+
""",
17+
),
18+
parameters={"setup_id": setup_id},
19+
)
20+
return row.first()
21+
22+
23+
async def get_tags(setup_id: int, connection: AsyncConnection) -> list[Row]:
24+
"""Get all tags for setup with `setup_id` from the database."""
25+
rows = await connection.execute(
26+
text(
27+
"""
28+
SELECT *
29+
FROM setup_tag
30+
WHERE id = :setup_id
31+
""",
32+
),
33+
parameters={"setup_id": setup_id},
34+
)
35+
return list(rows.all())
36+
37+
38+
async def untag(setup_id: int, tag: str, connection: AsyncConnection) -> None:
39+
"""Remove tag `tag` from setup with id `setup_id`."""
40+
await connection.execute(
41+
text(
42+
"""
43+
DELETE FROM setup_tag
44+
WHERE id = :setup_id AND tag = :tag
45+
""",
46+
),
47+
parameters={"setup_id": setup_id, "tag": tag},
48+
)

src/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from routers.openml.evaluations import router as evaluationmeasures_router
1616
from routers.openml.flows import router as flows_router
1717
from routers.openml.qualities import router as qualities_router
18+
from routers.openml.setups import router as setup_router
1819
from routers.openml.study import router as study_router
1920
from routers.openml.tasks import router as task_router
2021
from routers.openml.tasktype import router as ttype_router
@@ -68,6 +69,7 @@ def create_api() -> FastAPI:
6869
app.include_router(task_router)
6970
app.include_router(flows_router)
7071
app.include_router(study_router)
72+
app.include_router(setup_router)
7173
return app
7274

7375

src/routers/dependencies.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pydantic import BaseModel
66
from sqlalchemy.ext.asyncio import AsyncConnection
77

8+
from core.errors import AuthenticationFailedError
89
from database.setup import expdb_database, user_database
910
from database.users import APIKey, User
1011

@@ -28,6 +29,15 @@ async def fetch_user(
2829
return await User.fetch(api_key, user_data) if api_key and user_data else None
2930

3031

32+
def fetch_user_or_raise(
33+
user: Annotated[User | None, Depends(fetch_user)] = None,
34+
) -> User:
35+
if user is None:
36+
msg = "Authentication failed"
37+
raise AuthenticationFailedError(msg)
38+
return user
39+
40+
3141
class Pagination(BaseModel):
3242
offset: int = 0
3343
limit: int = 100

src/routers/openml/datasets.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import database.qualities
1313
from core.access import _user_has_access
1414
from core.errors import (
15-
AuthenticationFailedError,
1615
AuthenticationRequiredError,
1716
DatasetAdminOnlyError,
1817
DatasetNoAccessError,
@@ -33,7 +32,13 @@
3332
_format_parquet_url,
3433
)
3534
from database.users import User, UserGroup
36-
from routers.dependencies import Pagination, expdb_connection, fetch_user, userdb_connection
35+
from routers.dependencies import (
36+
Pagination,
37+
expdb_connection,
38+
fetch_user,
39+
fetch_user_or_raise,
40+
userdb_connection,
41+
)
3742
from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex
3843
from schemas.datasets.openml import DatasetMetadata, DatasetStatus, Feature, FeatureType
3944

@@ -46,7 +51,7 @@
4651
async def tag_dataset(
4752
data_id: Annotated[int, Body()],
4853
tag: Annotated[str, SystemString64],
49-
user: Annotated[User | None, Depends(fetch_user)] = None,
54+
user: Annotated[User, Depends(fetch_user_or_raise)],
5055
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
5156
) -> dict[str, dict[str, Any]]:
5257
assert expdb_db is not None # noqa: S101
@@ -55,10 +60,6 @@ async def tag_dataset(
5560
msg = f"Dataset {data_id} already tagged with {tag!r}."
5661
raise TagAlreadyExistsError(msg)
5762

58-
if user is None:
59-
msg = "Authentication failed."
60-
raise AuthenticationFailedError(msg)
61-
6263
await database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db)
6364
return {
6465
"data_tag": {"id": str(data_id), "tag": [*tags, tag]},

src/routers/openml/setups.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""All endpoints that relate to setups."""
2+
3+
from typing import Annotated
4+
5+
from fastapi import APIRouter, Body, Depends
6+
from sqlalchemy.ext.asyncio import AsyncConnection
7+
8+
import database.setups
9+
from core.errors import SetupNotFoundError, TagNotFoundError, TagNotOwnedError
10+
from database.users import User, UserGroup
11+
from routers.dependencies import expdb_connection, fetch_user_or_raise
12+
from routers.types import SystemString64
13+
14+
router = APIRouter(prefix="/setup", tags=["setup"])
15+
16+
17+
@router.post(path="/untag")
18+
async def untag_setup(
19+
setup_id: Annotated[int, Body()],
20+
tag: Annotated[str, SystemString64],
21+
user: Annotated[User, Depends(fetch_user_or_raise)],
22+
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
23+
) -> dict[str, dict[str, str | list[str]]]:
24+
"""Remove tag `tag` from setup with id `setup_id`."""
25+
if not await database.setups.get(setup_id, expdb_db):
26+
msg = f"Setup {setup_id} not found."
27+
raise SetupNotFoundError(msg)
28+
29+
setup_tags = await database.setups.get_tags(setup_id, expdb_db)
30+
matched_tag_row = next((t for t in setup_tags if t.tag.casefold() == tag.casefold()), None)
31+
32+
if not matched_tag_row:
33+
msg = f"Setup {setup_id} does not have tag {tag!r}."
34+
raise TagNotFoundError(msg)
35+
36+
if matched_tag_row.uploader != user.user_id and UserGroup.ADMIN not in await user.get_groups():
37+
msg = (
38+
f"You may not remove tag {tag!r} of setup {setup_id} because it was not created by you."
39+
)
40+
raise TagNotOwnedError(msg)
41+
42+
await database.setups.untag(setup_id, matched_tag_row.tag, expdb_db)
43+
remaining_tags = [t.tag.casefold() for t in setup_tags if t != matched_tag_row]
44+
return {"setup_untag": {"id": str(setup_id), "tag": remaining_tags}}

tests/routers/openml/dataset_tag_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ async def test_dataset_tag_invalid_tag_is_rejected(
8383
py_api: httpx.AsyncClient,
8484
) -> None:
8585
new = await py_api.post(
86-
f"/datasets/tag?api_key{ApiKey.ADMIN}",
86+
f"/datasets/tag?api_key={ApiKey.ADMIN}",
8787
json={"data_id": 1, "tag": tag},
8888
)
8989

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import contextlib
2+
import re
3+
from collections.abc import AsyncGenerator, Iterable
4+
from http import HTTPStatus
5+
6+
import httpx
7+
import pytest
8+
from sqlalchemy import text
9+
from sqlalchemy.ext.asyncio import AsyncConnection
10+
11+
from tests.users import OWNER_USER, ApiKey
12+
13+
14+
@pytest.mark.parametrize(
15+
"api_key",
16+
[ApiKey.ADMIN, ApiKey.SOME_USER, ApiKey.OWNER_USER],
17+
ids=["Administrator", "non-owner", "tag owner"],
18+
)
19+
@pytest.mark.parametrize(
20+
"other_tags",
21+
[[], ["some_other_tag"], ["foo_some_other_tag", "bar_some_other_tag"]],
22+
ids=["none", "one tag", "two tags"],
23+
)
24+
async def test_setup_untag_response_is_identical_when_tag_exists(
25+
api_key: str,
26+
other_tags: list[str],
27+
py_api: httpx.AsyncClient,
28+
php_api: httpx.AsyncClient,
29+
expdb_test: AsyncConnection,
30+
) -> None:
31+
setup_id = 1
32+
tag = "totally_new_tag_for_migration_testing"
33+
34+
@contextlib.asynccontextmanager
35+
async def temporary_tags(
36+
tags: Iterable[str], setup_id: int, *, persist: bool = False
37+
) -> AsyncGenerator[None]:
38+
for tag in tags:
39+
await expdb_test.execute(
40+
text(
41+
"INSERT INTO setup_tag(`id`,`tag`,`uploader`) VALUES (:setup_id, :tag, :user_id);" # noqa: E501
42+
),
43+
parameters={"setup_id": setup_id, "tag": tag, "user_id": OWNER_USER.user_id},
44+
)
45+
if persist:
46+
await expdb_test.commit()
47+
yield
48+
for tag in tags:
49+
await expdb_test.execute(
50+
text("DELETE FROM setup_tag WHERE `id`=:setup_id AND `tag`=:tag"),
51+
parameters={"setup_id": setup_id, "tag": tag},
52+
)
53+
if persist:
54+
await expdb_test.commit()
55+
56+
all_tags = [tag, *other_tags]
57+
async with temporary_tags(tags=all_tags, setup_id=setup_id, persist=True):
58+
original = await php_api.post(
59+
"/setup/untag",
60+
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
61+
)
62+
63+
# expdb_test transaction shared with Python API,
64+
# no commit needed and rolled back at the end of the test
65+
async with temporary_tags(tags=all_tags, setup_id=setup_id):
66+
new = await py_api.post(
67+
f"/setup/untag?api_key={api_key}",
68+
json={"setup_id": setup_id, "tag": tag},
69+
)
70+
71+
if new.status_code == HTTPStatus.OK:
72+
assert original.status_code == new.status_code
73+
original_untag = original.json()["setup_untag"]
74+
new_untag = new.json()["setup_untag"]
75+
assert original_untag["id"] == new_untag["id"]
76+
if tags := original_untag.get("tag"):
77+
if isinstance(tags, str):
78+
assert tags == new_untag["tag"][0]
79+
else:
80+
assert tags == new_untag["tag"]
81+
else:
82+
assert new_untag["tag"] == []
83+
return
84+
85+
code, message = original.json()["error"].values()
86+
assert original.status_code == HTTPStatus.PRECONDITION_FAILED
87+
assert new.status_code == HTTPStatus.FORBIDDEN
88+
assert code == new.json()["code"]
89+
assert message == "Tag is not owned by you"
90+
assert re.match(
91+
r"You may not remove tag \S+ of setup \d+ because it was not created by you.",
92+
new.json()["detail"],
93+
)
94+
95+
96+
async def test_setup_untag_response_is_identical_setup_doesnt_exist(
97+
py_api: httpx.AsyncClient,
98+
php_api: httpx.AsyncClient,
99+
) -> None:
100+
setup_id = 999999
101+
tag = "totally_new_tag_for_migration_testing"
102+
api_key = ApiKey.SOME_USER
103+
104+
original = await php_api.post(
105+
"/setup/untag",
106+
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
107+
)
108+
109+
new = await py_api.post(
110+
f"/setup/untag?api_key={api_key}",
111+
json={"setup_id": setup_id, "tag": tag},
112+
)
113+
114+
assert original.status_code == HTTPStatus.PRECONDITION_FAILED
115+
assert new.status_code == HTTPStatus.NOT_FOUND
116+
assert original.json()["error"]["message"] == "Entity not found."
117+
assert original.json()["error"]["code"] == new.json()["code"]
118+
assert re.match(
119+
r"Setup \d+ not found.",
120+
new.json()["detail"],
121+
)
122+
123+
124+
async def test_setup_untag_response_is_identical_tag_doesnt_exist(
125+
py_api: httpx.AsyncClient,
126+
php_api: httpx.AsyncClient,
127+
) -> None:
128+
setup_id = 1
129+
tag = "totally_new_tag_for_migration_testing"
130+
api_key = ApiKey.SOME_USER
131+
132+
original = await php_api.post(
133+
"/setup/untag",
134+
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
135+
)
136+
137+
new = await py_api.post(
138+
f"/setup/untag?api_key={api_key}",
139+
json={"setup_id": setup_id, "tag": tag},
140+
)
141+
142+
assert original.status_code == HTTPStatus.PRECONDITION_FAILED
143+
assert new.status_code == HTTPStatus.NOT_FOUND
144+
assert original.json()["error"]["code"] == new.json()["code"]
145+
assert original.json()["error"]["message"] == "Tag not found."
146+
assert re.match(
147+
r"Setup \d+ does not have tag '\S+'.",
148+
new.json()["detail"],
149+
)

0 commit comments

Comments
 (0)