Skip to content

Commit 562c413

Browse files
feat: Add GET/POST /tasks/list endpoint (#277)
## Description Implements `GET /tasks/list` and `POST /tasks/list` endpoints, migrating from the PHP API. Fixes: #23 ### Filters supported `task_type_id`, `tag`, `data_tag`, `status`, `limit`, `offset`, `task_id`, `data_id`, `data_name`, `number_instances`, `number_features`, `number_classes`, `number_missing_values` ### Implementation notes - Follows the pattern established in `datasets.py` — filters via request body, both GET and POST decorators on the same function - Inline SQL in router (no new `database/` file needed) - Default status filter: `active` (matches PHP behavior) - Qualities returned are the same 10 as in `list_datasets` - Tables: `task`, `task_type`, `task_inputs`, `dataset`, `dataset_status`, `data_quality`, `task_tag` ## Checklist - [x] I have performed a self-review of my own pull request - [x] Tests pass locally - [x] I have commented my code in hard-to-understand areas, and provided or updated docstrings as needed - [x] I have added tests that cover the changes
1 parent 9066125 commit 562c413

3 files changed

Lines changed: 656 additions & 5 deletions

File tree

src/routers/openml/tasks.py

Lines changed: 257 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
import asyncio
22
import json
33
import re
4-
from typing import Annotated, cast
4+
from enum import StrEnum
5+
from typing import Annotated, Any, cast
56

67
import xmltodict
7-
from fastapi import APIRouter, Depends
8-
from sqlalchemy import RowMapping, text
8+
from fastapi import APIRouter, Body, Depends
9+
from sqlalchemy import bindparam, text
10+
from sqlalchemy.engine import RowMapping
911
from sqlalchemy.ext.asyncio import AsyncConnection
1012

1113
import config
1214
import database.datasets
1315
import database.tasks
14-
from core.errors import InternalError, TaskNotFoundError
15-
from routers.dependencies import expdb_connection
16+
from core.errors import InternalError, NoResultsError, TaskNotFoundError
17+
from routers.dependencies import Pagination, expdb_connection
18+
from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex
1619
from schemas.datasets.openml import Task
1720

1821
router = APIRouter(prefix="/tasks", tags=["tasks"])
@@ -158,6 +161,255 @@ async def _fill_json_template( # noqa: C901
158161
return template.replace("[CONSTANT:base_url]", server_url)
159162

160163

164+
class TaskStatusFilter(StrEnum):
    """Valid values for the status filter."""

    ACTIVE = "active"
    DEACTIVATED = "deactivated"
    IN_PREPARATION = "in_preparation"
    # Sentinel value: "all" disables status filtering entirely (see `list_tasks`).
    ALL = "all"
171+
172+
173+
# Dataset qualities attached to each task in the list response.
# Per the migration notes this mirrors the quality set of the dataset list
# endpoint — confirm against `list_datasets` if either set changes.
QUALITIES_TO_SHOW = [
    "MajorityClassSize",
    "MaxNominalAttDistinctValues",
    "MinorityClassSize",
    "NumberOfClasses",
    "NumberOfFeatures",
    "NumberOfInstances",
    "NumberOfInstancesWithMissingValues",
    "NumberOfMissingValues",
    "NumberOfNumericFeatures",
    "NumberOfSymbolicFeatures",
]
185+
186+
# `task_inputs.input` names that are returned as each task's "input" entries.
BASIC_TASK_INPUTS = [
    "source_data",
    "target_feature",
    "estimation_procedure",
    "evaluation_measures",
]
192+
193+
194+
def _quality_clause(quality: str, range_: str | None) -> str:
    """Build a SQL fragment restricting tasks by a dataset-quality range.

    The fragment selects tasks whose 'source_data' dataset has `quality`
    within `range_`. `range_` is either a single integer ("100") or an
    inclusive range ("50..200"); a falsy value disables the filter and
    yields an empty string.

    Raises:
        ValueError: if `range_` does not match `integer_range_regex`.
    """
    if not range_:
        return ""
    match = re.match(integer_range_regex, range_)
    if match is None:
        msg = f"`range_` not a valid range: {range_}"
        raise ValueError(msg)
    start, end = match.groups()
    if end:
        # The second group captures the "..200" tail; drop the leading dots.
        value = f"`value` BETWEEN {start} AND {end[2:]}"
    else:
        value = f"`value`={start}"
    # Inner query: datasets whose quality matches; outer query: tasks whose
    # 'source_data' input points at one of those datasets.
    return f"""
    AND t.`task_id` IN (
        SELECT ti.`task_id` FROM task_inputs ti
        WHERE ti.`input`='source_data' AND ti.`value` IN (
            SELECT `data` FROM data_quality
            WHERE `quality`='{quality}' AND {value}
        )
    )
    """  # noqa: S608
218+
219+
220+
@router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.")
@router.get(path="/list")
async def list_tasks(  # noqa: PLR0913, PLR0912, C901, PLR0915
    pagination: Annotated[Pagination, Body(default_factory=Pagination)],
    task_type_id: Annotated[int | None, Body(description="Filter by task type id.")] = None,
    tag: Annotated[str | None, SystemString64] = None,
    data_tag: Annotated[str | None, SystemString64] = None,
    status: Annotated[TaskStatusFilter, Body()] = TaskStatusFilter.ACTIVE,
    task_id: Annotated[
        list[int] | None,
        Body(description="Filter by task id(s).", min_length=1),
    ] = None,
    data_id: Annotated[
        list[int] | None,
        Body(description="Filter by dataset id(s).", min_length=1),
    ] = None,
    data_name: Annotated[str | None, CasualString128] = None,
    number_instances: Annotated[str | None, IntegerRange] = None,
    number_features: Annotated[str | None, IntegerRange] = None,
    number_classes: Annotated[str | None, IntegerRange] = None,
    number_missing_values: Annotated[str | None, IntegerRange] = None,
    expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
) -> list[dict[str, Any]]:
    """List tasks, optionally filtered by type, tag, status, dataset properties, and more.

    Each returned dict has the base columns (task_id, task_type_id, task_type,
    did, name, format, status) plus, only when rows exist for it, an "input"
    list (from BASIC_TASK_INPUTS), a "quality" list (from QUALITIES_TO_SHOW,
    values stringified), and a "tag" list.

    Raises:
        NoResultsError: if no task matches the given filters (code "482").
    """
    # `expdb` defaults to None only so FastAPI can inject it; always set at runtime.
    assert expdb is not None  # noqa: S101

    # Extra "AND ..." WHERE fragments and their bound parameters, built per filter.
    clauses: list[str] = []
    parameters: dict[str, Any] = {
        "offset": max(0, pagination.offset),  # clamp negative values to 0
        "limit": max(0, pagination.limit),
    }

    # Status filtering: datasets with no status row count as 'in_preparation';
    # the special value "all" skips the clause entirely.
    if status != TaskStatusFilter.ALL:
        clauses.append("AND IFNULL(ds.`status`, 'in_preparation') = :status")
        parameters["status"] = status

    if task_type_id is not None:
        clauses.append("AND t.`ttid` = :task_type_id")
        parameters["task_type_id"] = task_type_id

    # task_tag/dataset_tag store the tagged entity's id in their `id` column.
    if tag is not None:
        clauses.append("AND t.`task_id` IN (SELECT `id` FROM task_tag WHERE `tag` = :tag)")
        parameters["tag"] = tag

    if data_tag is not None:
        clauses.append("AND d.`did` IN (SELECT `id` FROM dataset_tag WHERE `tag` = :data_tag)")
        parameters["data_tag"] = data_tag

    # Exact name match (no wildcard search) — matches the PHP behavior per the
    # migration notes; confirm if substring matching is ever needed.
    if data_name is not None:
        clauses.append("AND d.`name` = :data_name")
        parameters["data_name"] = data_name

    if task_id is not None:
        clauses.append("AND t.`task_id` IN :task_ids")
        parameters["task_ids"] = task_id

    if data_id is not None:
        clauses.append("AND d.`did` IN :data_ids")
        parameters["data_ids"] = data_id

    # Quality filters each produce a self-contained "AND task_id IN (...)"
    # fragment (or "" when the filter is unset); values are regex-validated
    # integers, so interpolation here is safe.
    where_number_instances = _quality_clause("NumberOfInstances", number_instances)
    where_number_features = _quality_clause("NumberOfFeatures", number_features)
    where_number_classes = _quality_clause("NumberOfClasses", number_classes)
    where_number_missing_values = _quality_clause("NumberOfMissingValues", number_missing_values)

    # subquery to get the latest status per dataset (dataset_status is a history table)
    status_subquery = """
        SELECT ds1.did, ds1.status
        FROM dataset_status ds1
        WHERE ds1.status_date = (
            SELECT MAX(ds2.status_date) FROM dataset_status ds2
            WHERE ds1.did = ds2.did
        )
    """

    main_query = text(
        f"""
        SELECT
            t.`task_id`,
            t.`ttid` AS task_type_id,
            tt.`name` AS task_type,
            d.`did`,
            d.`name`,
            d.`format`,
            IFNULL(ds.`status`, 'in_preparation') AS status
        FROM task t
        JOIN task_type tt
            ON tt.`ttid` = t.`ttid`
        JOIN task_inputs ti_source
            ON ti_source.`task_id` = t.`task_id`
            AND ti_source.`input` = 'source_data'
        JOIN dataset d
            ON d.`did` = ti_source.`value`
        LEFT JOIN ({status_subquery}) ds
            ON ds.`did` = d.`did`
        WHERE 1=1
        {where_number_instances}
        {where_number_features}
        {where_number_classes}
        {where_number_missing_values}
        {" ".join(clauses)}
        GROUP BY t.`task_id`, t.`ttid`, tt.`name`, d.`did`, d.`name`, d.`format`, ds.`status`
        ORDER BY t.`task_id`
        LIMIT :limit OFFSET :offset
        """,  # noqa: S608
    )

    # "IN :param" with a Python list requires an expanding bindparam so the
    # list is spliced into the statement as individual placeholders.
    if task_id is not None:
        main_query = main_query.bindparams(bindparam("task_ids", expanding=True))
    if data_id is not None:
        main_query = main_query.bindparams(bindparam("data_ids", expanding=True))

    result = await expdb.execute(main_query, parameters=parameters)
    rows = result.mappings().all()

    if not rows:
        msg = "No tasks match the search criteria."
        raise NoResultsError(msg, code="482")

    columns = ["task_id", "task_type_id", "task_type", "did", "name", "format", "status"]
    # Keyed by task_id; preserves the query's ORDER BY task_id ordering.
    tasks: dict[int, dict[str, Any]] = {
        row["task_id"]: {col: row[col] for col in columns} for row in rows
    }
    task_ids: list[int] = list(tasks.keys())
    dataset_ids: list[int] = list({t["did"] for t in tasks.values()})

    inputs_query = text(
        """
        SELECT `task_id`, `input`, `value`
        FROM task_inputs
        WHERE `task_id` IN :task_ids
        AND `input` IN :basic_inputs
        """,
    ).bindparams(
        bindparam("task_ids", expanding=True),
        bindparam("basic_inputs", expanding=True),
    )
    qualities_query = text(
        """
        SELECT `data`, `quality`, `value`
        FROM data_quality
        WHERE `data` IN :dataset_ids
        AND `quality` IN :quality_names
        """,
    ).bindparams(
        bindparam("dataset_ids", expanding=True),
        bindparam("quality_names", expanding=True),
    )

    tags_query = text(
        """
        SELECT `id`, `tag`
        FROM task_tag
        WHERE `id` IN :task_ids
        """,
    ).bindparams(bindparam("task_ids", expanding=True))

    # NOTE(review): these three queries are gathered concurrently on a single
    # AsyncConnection — whether they truly run in parallel (or are serialized /
    # rejected) depends on the driver; confirm this is safe with the DB backend.
    inputs_result, qualities_result, tags_result = await asyncio.gather(
        expdb.execute(
            inputs_query,
            parameters={"task_ids": task_ids, "basic_inputs": BASIC_TASK_INPUTS},
        ),
        expdb.execute(
            qualities_query,
            parameters={"dataset_ids": dataset_ids, "quality_names": QUALITIES_TO_SHOW},
        ),
        expdb.execute(
            tags_query,
            parameters={"task_ids": task_ids},
        ),
    )

    # Attach inputs; "input" key only appears for tasks that have any.
    for row in inputs_result.all():
        tasks[row.task_id].setdefault("input", []).append(
            {"name": row.input, "value": row.value},
        )

    # multiple tasks can reference the same dataset; map dataset_id -> [task_id, ...]
    did_to_task_ids: dict[int, list[int]] = {}
    for tid, t in tasks.items():
        did_to_task_ids.setdefault(t["did"], []).append(tid)
    for row in qualities_result.all():
        for tid in did_to_task_ids.get(row.data, []):
            tasks[tid].setdefault("quality", []).append(
                # Quality values are stringified for the response payload.
                {"name": row.quality, "value": str(row.value)},
            )

    for row in tags_result.all():
        tasks[row.id].setdefault("tag", []).append(row.tag)

    return list(tasks.values())
411+
412+
161413
@router.get("/{task_id}")
162414
async def get_task(
163415
task_id: int,

0 commit comments

Comments
 (0)