Skip to content

Commit d11e2a5

Browse files
Adjustments
1 parent edcefc1 commit d11e2a5

3 files changed

Lines changed: 73 additions & 56 deletions

File tree

airflow-core/src/airflow/api_fastapi/common/cursors.py

Lines changed: 42 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -32,90 +32,84 @@
3232
from sqlalchemy import and_, or_
3333

3434
if TYPE_CHECKING:
35-
from sqlalchemy.sql import ColumnElement, Select
35+
from sqlalchemy.sql import Select
3636

3737
from airflow.api_fastapi.common.parameters import SortParam
3838

3939

40+
def _encode_value(val: Any) -> dict[str, Any]:
41+
"""Encode a single Python value as a typed {"type": ..., "value": ...} object."""
42+
if val is None:
43+
return {"type": "null", "value": None}
44+
if isinstance(val, uuid_mod.UUID):
45+
return {"type": "uuid", "value": str(val)}
46+
if isinstance(val, datetime):
47+
return {"type": "datetime", "value": val.isoformat()}
48+
if isinstance(val, int):
49+
return {"type": "int", "value": val}
50+
return {"type": "str", "value": str(val)}
51+
52+
53+
def _decode_value(entry: dict[str, Any]) -> Any:
54+
"""Decode a typed cursor entry back to its Python value."""
55+
type_tag = entry["type"]
56+
raw = entry["value"]
57+
if type_tag == "null":
58+
return None
59+
if type_tag == "uuid":
60+
return uuid_mod.UUID(str(raw))
61+
if type_tag == "datetime":
62+
return datetime.fromisoformat(str(raw))
63+
if type_tag == "int":
64+
return int(raw)
65+
return str(raw)
66+
67+
4068
def encode_cursor(row: Any, sort_param: SortParam) -> str:
4169
"""
4270
Encode cursor token from the last row of a result set.
4371
44-
The token is a base64url-encoded JSON list containing the sort column
45-
values in the same order as the resolved sort columns.
72+
The token is a base64url-encoded JSON list of typed objects, each
73+
containing ``{"type": "<tag>", "value": <serialized>}`` so the
74+
cursor is self-describing and can be decoded without column metadata.
4675
"""
4776
resolved = sort_param.get_resolved_columns()
4877
if not resolved:
4978
raise ValueError("SortParam has no resolved columns.")
5079

51-
values: list[Any] = []
52-
for attr_name, _col, _desc in resolved:
53-
val = getattr(row, attr_name, None)
54-
if val is None:
55-
values.append(None)
56-
elif isinstance(val, datetime):
57-
values.append(val.isoformat())
58-
else:
59-
values.append(str(val))
60-
61-
return base64.urlsafe_b64encode(json.dumps(values).encode()).decode()
80+
entries = [_encode_value(getattr(row, attr_name, None)) for attr_name, _col, _desc in resolved]
81+
return base64.urlsafe_b64encode(json.dumps(entries).encode()).decode()
6282

6383

64-
def decode_cursor(token: str) -> list[Any]:
65-
"""Decode a cursor token and return the list of values."""
84+
def decode_cursor(token: str) -> list[dict[str, Any]]:
85+
"""Decode a cursor token and return the list of typed value entries."""
6686
try:
6787
data = json.loads(base64.urlsafe_b64decode(token))
6888
except Exception:
6989
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid cursor token")
7090

71-
if not isinstance(data, list):
91+
if not isinstance(data, list) or any(
92+
not isinstance(entry, dict) or "type" not in entry or "value" not in entry for entry in data
93+
):
7294
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid cursor token structure")
7395

7496
return data
7597

7698

77-
def _coerce_cursor_value(raw: Any, col: ColumnElement) -> Any:
78-
"""Convert a JSON-serialized cursor value to the Python type expected by the column."""
79-
if raw is None:
80-
return None
81-
82-
from sqlalchemy import Integer, String
83-
from sqlalchemy.sql.sqltypes import Uuid
84-
85-
col_type = getattr(col, "type", None)
86-
if col_type is None:
87-
return raw
88-
89-
if isinstance(col_type, Uuid):
90-
return uuid_mod.UUID(str(raw))
91-
if isinstance(col_type, Integer):
92-
return int(raw)
93-
if isinstance(col_type, String):
94-
return str(raw)
95-
96-
type_name = type(col_type).__name__.lower()
97-
if "datetime" in type_name or "timestamp" in type_name or "date" in type_name:
98-
return datetime.fromisoformat(str(raw))
99-
100-
return raw
101-
102-
10399
def apply_cursor_filter(statement: Select, cursor: str, sort_param: SortParam) -> Select:
104100
"""
105101
Apply a keyset pagination WHERE clause from a cursor token.
106102
107103
Builds a composite comparison that respects mixed ASC/DESC ordering
108104
on the resolved sort columns.
109105
"""
110-
cursor_values = decode_cursor(cursor)
106+
cursor_entries = decode_cursor(cursor)
111107

112108
resolved = sort_param.get_resolved_columns()
113-
if len(cursor_values) != len(resolved):
109+
if len(cursor_entries) != len(resolved):
114110
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Cursor token does not match current query shape")
115111

116-
parsed_values: list[Any] = []
117-
for i, (_name, col, _desc) in enumerate(resolved):
118-
parsed_values.append(_coerce_cursor_value(cursor_values[i], col))
112+
parsed_values = [_decode_value(entry) for entry in cursor_entries]
119113

120114
# Build the keyset WHERE clause for mixed ASC/DESC ordering.
121115
# For columns (c1 ASC, c2 DESC, c3 ASC) with cursor values (v1, v2, v3):

airflow-core/src/airflow/api_fastapi/core_api/datamodels/common.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,9 +158,6 @@ class BulkResponse(BaseModel):
158158
)
159159

160160

161-
# Common Pagination Base Models
162-
163-
164161
class OffsetPaginatedResponse(BaseModel):
165162
"""Base for offset-paginated collection responses."""
166163

airflow-core/tests/unit/api_fastapi/common/test_cursors.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ def test_encode_decode_cursor_roundtrip(self):
4949
token = encode_cursor(row, sp)
5050
decoded = decode_cursor(token)
5151

52-
assert decoded == ["2024-01-15T10:00:00+00:00", "019462ab-1234-5678-9abc-def012345678"]
52+
assert decoded == [
53+
{"type": "str", "value": "2024-01-15T10:00:00+00:00"},
54+
{"type": "str", "value": "019462ab-1234-5678-9abc-def012345678"},
55+
]
5356

5457
def test_decode_cursor_invalid_base64(self):
5558
with pytest.raises(HTTPException, match="Invalid cursor token"):
@@ -65,6 +68,21 @@ def test_decode_cursor_not_a_list(self):
6568
with pytest.raises(HTTPException, match="Invalid cursor token structure"):
6669
decode_cursor(token)
6770

71+
def test_decode_cursor_missing_type_key(self):
72+
token = base64.urlsafe_b64encode(json.dumps([{"value": "foo"}]).encode()).decode()
73+
with pytest.raises(HTTPException, match="Invalid cursor token structure"):
74+
decode_cursor(token)
75+
76+
def test_decode_cursor_missing_value_key(self):
77+
token = base64.urlsafe_b64encode(json.dumps([{"type": "str"}]).encode()).decode()
78+
with pytest.raises(HTTPException, match="Invalid cursor token structure"):
79+
decode_cursor(token)
80+
81+
def test_decode_cursor_entry_not_a_dict(self):
82+
token = base64.urlsafe_b64encode(json.dumps(["just-a-string"]).encode()).decode()
83+
with pytest.raises(HTTPException, match="Invalid cursor token structure"):
84+
decode_cursor(token)
85+
6886
def test_encode_cursor_works_without_prior_to_orm(self):
6987
"""get_resolved_columns now lazily resolves, so to_orm is no longer required before encode."""
7088
sp = SortParam(["id"], TaskInstance)
@@ -73,18 +91,23 @@ def test_encode_cursor_works_without_prior_to_orm(self):
7391
row.id = "019462ab-1234-5678-9abc-def012345678"
7492
token = encode_cursor(row, sp)
7593
decoded = decode_cursor(token)
76-
assert decoded == ["019462ab-1234-5678-9abc-def012345678"]
94+
assert decoded == [{"type": "str", "value": "019462ab-1234-5678-9abc-def012345678"}]
7795

7896
def test_apply_cursor_filter_wrong_value_count(self):
7997
sp = self._make_sort_param_with_resolved_columns(["start_date"])
80-
token = base64.urlsafe_b64encode(json.dumps(["only-one-value"]).encode()).decode()
98+
token = base64.urlsafe_b64encode(
99+
json.dumps([{"type": "str", "value": "only-one-value"}]).encode()
100+
).decode()
81101

82102
with pytest.raises(HTTPException, match="does not match"):
83103
apply_cursor_filter(select(TaskInstance), token, sp)
84104

85105
def test_apply_cursor_filter_ascending(self):
86106
sp = self._make_sort_param_with_resolved_columns(["start_date"])
87-
values = ["2024-01-15T10:00:00", "019462ab-1234-5678-9abc-def012345678"]
107+
values = [
108+
{"type": "datetime", "value": "2024-01-15T10:00:00"},
109+
{"type": "uuid", "value": "019462ab-1234-5678-9abc-def012345678"},
110+
]
88111
token = base64.urlsafe_b64encode(json.dumps(values).encode()).decode()
89112

90113
stmt = apply_cursor_filter(select(TaskInstance), token, sp)
@@ -93,7 +116,10 @@ def test_apply_cursor_filter_ascending(self):
93116

94117
def test_apply_cursor_filter_descending(self):
95118
sp = self._make_sort_param_with_resolved_columns(["-start_date"])
96-
values = ["2024-01-15T10:00:00", "019462ab-1234-5678-9abc-def012345678"]
119+
values = [
120+
{"type": "datetime", "value": "2024-01-15T10:00:00"},
121+
{"type": "uuid", "value": "019462ab-1234-5678-9abc-def012345678"},
122+
]
97123
token = base64.urlsafe_b64encode(json.dumps(values).encode()).decode()
98124

99125
stmt = apply_cursor_filter(select(TaskInstance), token, sp)

0 commit comments

Comments
 (0)