1 | 1 | import asyncio |
2 | 2 | import json |
3 | 3 | import re |
4 | | -from typing import Annotated, cast |
| 4 | +from enum import StrEnum |
| 5 | +from typing import Annotated, Any, cast |
5 | 6 |
6 | 7 | import xmltodict |
7 | | -from fastapi import APIRouter, Depends |
8 | | -from sqlalchemy import RowMapping, text |
| 8 | +from fastapi import APIRouter, Body, Depends |
| 9 | +from sqlalchemy import bindparam, text |
| 10 | +from sqlalchemy.engine import RowMapping |
9 | 11 | from sqlalchemy.ext.asyncio import AsyncConnection |
10 | 12 |
11 | 13 | import config |
12 | 14 | import database.datasets |
13 | 15 | import database.tasks |
14 | | -from core.errors import InternalError, TaskNotFoundError |
15 | | -from routers.dependencies import expdb_connection |
| 16 | +from core.errors import InternalError, NoResultsError, TaskNotFoundError |
| 17 | +from routers.dependencies import Pagination, expdb_connection |
| 18 | +from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex |
16 | 19 | from schemas.datasets.openml import Task |
17 | 20 |
18 | 21 | router = APIRouter(prefix="/tasks", tags=["tasks"]) |
@@ -158,6 +161,255 @@ async def _fill_json_template( # noqa: C901 |
158 | 161 |     return template.replace("[CONSTANT:base_url]", server_url) |
159 | 162 |
160 | 163 |
| 164 | +class TaskStatusFilter(StrEnum): |
| 165 | +    """Valid values for the status filter.""" |
| 166 | + |
| 167 | +    ACTIVE = "active" |
| 168 | +    DEACTIVATED = "deactivated" |
| 169 | +    IN_PREPARATION = "in_preparation" |
| 170 | +    ALL = "all" |
| 171 | + |
| 172 | + |
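| | +# Dataset qualities attached to each task entry in the listing response. |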
| 173 | +QUALITIES_TO_SHOW = [ |
| 174 | +    "MajorityClassSize", |
| 175 | +    "MaxNominalAttDistinctValues", |
| 176 | +    "MinorityClassSize", |
| 177 | +    "NumberOfClasses", |
| 178 | +    "NumberOfFeatures", |
| 179 | +    "NumberOfInstances", |
| 180 | +    "NumberOfInstancesWithMissingValues", |
| 181 | +    "NumberOfMissingValues", |
| 182 | +    "NumberOfNumericFeatures", |
| 183 | +    "NumberOfSymbolicFeatures", |
| 184 | +] |
| 185 | + |
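| | +# Task inputs attached to each task entry in the listing response; other inputs are omitted. |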
| 186 | +BASIC_TASK_INPUTS = [ |
| 187 | +    "source_data", |
| 188 | +    "target_feature", |
| 189 | +    "estimation_procedure", |
| 190 | +    "evaluation_measures", |
| 191 | +] |
| 192 | + |
| 193 | + |
| 194 | +def _quality_clause(quality: str, range_: str | None) -> str: |
| 195 | +    """Return a SQL WHERE clause fragment filtering tasks by a dataset quality range. |
| 196 | + |
| 197 | +    Looks up tasks whose source dataset has the given quality within the range. |
| 198 | +    The range can be an exact value ('100') or an inclusive range ('50..200'). |
| 199 | +    """ |
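| | +    # For example, _quality_clause("NumberOfInstances", "50..200") returns a fragment |
| | +    # that keeps only tasks whose source dataset has between 50 and 200 instances. |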
| 200 | +    if not range_: |
| 201 | +        return "" |
| 202 | +    if not (match := re.match(integer_range_regex, range_)): |
| 203 | +        msg = f"`range_` not a valid range: {range_}" |
| 204 | +        raise ValueError(msg) |
| 205 | +    start, end = match.groups() |
| 206 | +    # end group looks like "..200", strip the ".." prefix to get just the number |
| 207 | +    value = f"`value` BETWEEN {start} AND {end[2:]}" if end else f"`value`={start}" |
| 208 | +    # nested subquery: find datasets with matching quality, then find tasks using those datasets |
| 209 | +    return f""" |
| 210 | +    AND t.`task_id` IN ( |
| 211 | +        SELECT ti.`task_id` FROM task_inputs ti |
| 212 | +        WHERE ti.`input`='source_data' AND ti.`value` IN ( |
| 213 | +            SELECT `data` FROM data_quality |
| 214 | +            WHERE `quality`='{quality}' AND {value} |
| 215 | +        ) |
| 216 | +    ) |
| 217 | +    """  # noqa: S608 |
| 218 | + |
| 219 | + |
| 220 | +@router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.") |
| 221 | +@router.get(path="/list") |
| 222 | +async def list_tasks(  # noqa: PLR0913, PLR0912, C901, PLR0915 |
| 223 | +    pagination: Annotated[Pagination, Body(default_factory=Pagination)], |
| 224 | +    task_type_id: Annotated[int | None, Body(description="Filter by task type id.")] = None, |
| 225 | +    tag: Annotated[str | None, SystemString64] = None, |
| 226 | +    data_tag: Annotated[str | None, SystemString64] = None, |
| 227 | +    status: Annotated[TaskStatusFilter, Body()] = TaskStatusFilter.ACTIVE, |
| 228 | +    task_id: Annotated[ |
| 229 | +        list[int] | None, |
| 230 | +        Body(description="Filter by task id(s).", min_length=1), |
| 231 | +    ] = None, |
| 232 | +    data_id: Annotated[ |
| 233 | +        list[int] | None, |
| 234 | +        Body(description="Filter by dataset id(s).", min_length=1), |
| 235 | +    ] = None, |
| 236 | +    data_name: Annotated[str | None, CasualString128] = None, |
| 237 | +    number_instances: Annotated[str | None, IntegerRange] = None, |
| 238 | +    number_features: Annotated[str | None, IntegerRange] = None, |
| 239 | +    number_classes: Annotated[str | None, IntegerRange] = None, |
| 240 | +    number_missing_values: Annotated[str | None, IntegerRange] = None, |
| 241 | +    expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, |
| 242 | +) -> list[dict[str, Any]]: |
| 243 | +    """List tasks, optionally filtered by type, tag, status, dataset properties, and more.""" |
| 244 | +    assert expdb is not None  # noqa: S101 |
| 245 | + |
| 246 | +    clauses: list[str] = [] |
| 247 | +    parameters: dict[str, Any] = { |
| 248 | +        "offset": max(0, pagination.offset), |
| 249 | +        "limit": max(0, pagination.limit), |
| 250 | +    } |
| 251 | + |
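| | +    # Each optional filter appends a clause; all user-supplied values are bound parameters. |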
| 252 | +    if status != TaskStatusFilter.ALL: |
| 253 | +        clauses.append("AND IFNULL(ds.`status`, 'in_preparation') = :status") |
| 254 | +        parameters["status"] = status |
| 255 | + |
| 256 | +    if task_type_id is not None: |
| 257 | +        clauses.append("AND t.`ttid` = :task_type_id") |
| 258 | +        parameters["task_type_id"] = task_type_id |
| 259 | + |
| 260 | +    if tag is not None: |
| 261 | +        clauses.append("AND t.`task_id` IN (SELECT `id` FROM task_tag WHERE `tag` = :tag)") |
| 262 | +        parameters["tag"] = tag |
| 263 | + |
| 264 | +    if data_tag is not None: |
| 265 | +        clauses.append("AND d.`did` IN (SELECT `id` FROM dataset_tag WHERE `tag` = :data_tag)") |
| 266 | +        parameters["data_tag"] = data_tag |
| 267 | + |
| 268 | +    if data_name is not None: |
| 269 | +        clauses.append("AND d.`name` = :data_name") |
| 270 | +        parameters["data_name"] = data_name |
| 271 | + |
| 272 | +    if task_id is not None: |
| 273 | +        clauses.append("AND t.`task_id` IN :task_ids") |
| 274 | +        parameters["task_ids"] = task_id |
| 275 | + |
| 276 | +    if data_id is not None: |
| 277 | +        clauses.append("AND d.`did` IN :data_ids") |
| 278 | +        parameters["data_ids"] = data_id |
| 279 | + |
| 280 | +    where_number_instances = _quality_clause("NumberOfInstances", number_instances) |
| 281 | +    where_number_features = _quality_clause("NumberOfFeatures", number_features) |
| 282 | +    where_number_classes = _quality_clause("NumberOfClasses", number_classes) |
| 283 | +    where_number_missing_values = _quality_clause("NumberOfMissingValues", number_missing_values) |
| 284 | + |
| 285 | +    # subquery to get the latest status per dataset (dataset_status is a history table) |
| 286 | +    status_subquery = """ |
| 287 | +        SELECT ds1.did, ds1.status |
| 288 | +        FROM dataset_status ds1 |
| 289 | +        WHERE ds1.status_date = ( |
| 290 | +            SELECT MAX(ds2.status_date) FROM dataset_status ds2 |
| 291 | +            WHERE ds1.did = ds2.did |
| 292 | +        ) |
| 293 | +    """ |
| 294 | + |
| 295 | +    main_query = text( |
| 296 | +        f""" |
| 297 | +        SELECT |
| 298 | +            t.`task_id`, |
| 299 | +            t.`ttid` AS task_type_id, |
| 300 | +            tt.`name` AS task_type, |
| 301 | +            d.`did`, |
| 302 | +            d.`name`, |
| 303 | +            d.`format`, |
| 304 | +            IFNULL(ds.`status`, 'in_preparation') AS status |
| 305 | +        FROM task t |
| 306 | +        JOIN task_type tt |
| 307 | +            ON tt.`ttid` = t.`ttid` |
| 308 | +        JOIN task_inputs ti_source |
| 309 | +            ON ti_source.`task_id` = t.`task_id` |
| 310 | +            AND ti_source.`input` = 'source_data' |
| 311 | +        JOIN dataset d |
| 312 | +            ON d.`did` = ti_source.`value` |
| 313 | +        LEFT JOIN ({status_subquery}) ds |
| 314 | +            ON ds.`did` = d.`did` |
| 315 | +        WHERE 1=1 |
| 316 | +        {where_number_instances} |
| 317 | +        {where_number_features} |
| 318 | +        {where_number_classes} |
| 319 | +        {where_number_missing_values} |
| 320 | +        {" ".join(clauses)} |
| 321 | +        GROUP BY t.`task_id`, t.`ttid`, tt.`name`, d.`did`, d.`name`, d.`format`, ds.`status` |
| 322 | +        ORDER BY t.`task_id` |
| 323 | +        LIMIT :limit OFFSET :offset |
| 324 | +        """,  # noqa: S608 |
| 325 | +    ) |
| 326 | + |
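| | +    # The id lists are bound as "expanding" parameters so they render as IN (...) value lists. |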
| 327 | +    if task_id is not None: |
| 328 | +        main_query = main_query.bindparams(bindparam("task_ids", expanding=True)) |
| 329 | +    if data_id is not None: |
| 330 | +        main_query = main_query.bindparams(bindparam("data_ids", expanding=True)) |
| 331 | + |
| 332 | +    result = await expdb.execute(main_query, parameters=parameters) |
| 333 | +    rows = result.mappings().all() |
| 334 | + |
| 335 | +    if not rows: |
| 336 | +        msg = "No tasks match the search criteria." |
| 337 | +        raise NoResultsError(msg, code="482") |
| 338 | + |
| 339 | +    columns = ["task_id", "task_type_id", "task_type", "did", "name", "format", "status"] |
| 340 | +    tasks: dict[int, dict[str, Any]] = { |
| 341 | +        row["task_id"]: {col: row[col] for col in columns} for row in rows |
| 342 | +    } |
| 343 | +    task_ids: list[int] = list(tasks.keys()) |
| 344 | +    dataset_ids: list[int] = list({t["did"] for t in tasks.values()}) |
| 345 | + |
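| | +    # Fetch inputs, qualities, and tags for all listed tasks in bulk rather than per task. |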
| 346 | +    inputs_query = text( |
| 347 | +        """ |
| 348 | +        SELECT `task_id`, `input`, `value` |
| 349 | +        FROM task_inputs |
| 350 | +        WHERE `task_id` IN :task_ids |
| 351 | +        AND `input` IN :basic_inputs |
| 352 | +        """, |
| 353 | +    ).bindparams( |
| 354 | +        bindparam("task_ids", expanding=True), |
| 355 | +        bindparam("basic_inputs", expanding=True), |
| 356 | +    ) |
| 357 | +    qualities_query = text( |
| 358 | +        """ |
| 359 | +        SELECT `data`, `quality`, `value` |
| 360 | +        FROM data_quality |
| 361 | +        WHERE `data` IN :dataset_ids |
| 362 | +        AND `quality` IN :quality_names |
| 363 | +        """, |
| 364 | +    ).bindparams( |
| 365 | +        bindparam("dataset_ids", expanding=True), |
| 366 | +        bindparam("quality_names", expanding=True), |
| 367 | +    ) |
| 368 | + |
| 369 | +    tags_query = text( |
| 370 | +        """ |
| 371 | +        SELECT `id`, `tag` |
| 372 | +        FROM task_tag |
| 373 | +        WHERE `id` IN :task_ids |
| 374 | +        """, |
| 375 | +    ).bindparams(bindparam("task_ids", expanding=True)) |
| 376 | + |
| 377 | +    # A single AsyncConnection cannot run queries concurrently, so the three |
| 378 | +    # follow-up lookups are executed sequentially rather than with asyncio.gather. |
| 379 | +    inputs_result = await expdb.execute( |
| 380 | +        inputs_query, |
| 381 | +        parameters={"task_ids": task_ids, "basic_inputs": BASIC_TASK_INPUTS}, |
| 382 | +    ) |
| 383 | +    qualities_result = await expdb.execute( |
| 384 | +        qualities_query, |
| 385 | +        parameters={"dataset_ids": dataset_ids, "quality_names": QUALITIES_TO_SHOW}, |
| 386 | +    ) |
| 387 | +    tags_result = await expdb.execute( |
| 388 | +        tags_query, |
| 389 | +        parameters={"task_ids": task_ids}, |
| 390 | +    ) |
| 391 | + |
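| | +    # Attach each task's basic inputs (source data, target feature, ...) under "input". |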
| 392 | +    for row in inputs_result.all(): |
| 393 | +        tasks[row.task_id].setdefault("input", []).append( |
| 394 | +            {"name": row.input, "value": row.value}, |
| 395 | +        ) |
| 396 | + |
| 397 | +    # multiple tasks can reference the same dataset; map dataset_id -> [task_id, ...] |
| 398 | +    did_to_task_ids: dict[int, list[int]] = {} |
| 399 | +    for tid, t in tasks.items(): |
| 400 | +        did_to_task_ids.setdefault(t["did"], []).append(tid) |
| 401 | +    for row in qualities_result.all(): |
| 402 | +        for tid in did_to_task_ids.get(row.data, []): |
| 403 | +            tasks[tid].setdefault("quality", []).append( |
| 404 | +                {"name": row.quality, "value": str(row.value)}, |
| 405 | +            ) |
| 406 | + |
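| | +    # Attach tags; tasks without tags simply have no "tag" entry. |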
| 407 | +    for row in tags_result.all(): |
| 408 | +        tasks[row.id].setdefault("tag", []).append(row.tag) |
| 409 | + |
| 410 | +    return list(tasks.values()) |
| 411 | + |
| 412 | + |
161 | 413 | @router.get("/{task_id}") |
162 | 414 | async def get_task( |
163 | 415 |     task_id: int, |