|
32 | 32 | from sqlalchemy import and_, or_ |
33 | 33 |
|
34 | 34 | if TYPE_CHECKING: |
35 | | - from sqlalchemy.sql import ColumnElement, Select |
| 35 | + from sqlalchemy.sql import Select |
36 | 36 |
|
37 | 37 | from airflow.api_fastapi.common.parameters import SortParam |
38 | 38 |
|
39 | 39 |
|
| 40 | +def _encode_value(val: Any) -> dict[str, Any]: |
| 41 | + """Encode a single Python value as a typed {"type": ..., "value": ...} object.""" |
| 42 | + if val is None: |
| 43 | + return {"type": "null", "value": None} |
| 44 | + if isinstance(val, uuid_mod.UUID): |
| 45 | + return {"type": "uuid", "value": str(val)} |
| 46 | + if isinstance(val, datetime): |
| 47 | + return {"type": "datetime", "value": val.isoformat()} |
| 48 | + if isinstance(val, int): |
| 49 | + return {"type": "int", "value": val} |
| 50 | + return {"type": "str", "value": str(val)} |
| 51 | + |
| 52 | + |
| 53 | +def _decode_value(entry: dict[str, Any]) -> Any: |
| 54 | + """Decode a typed cursor entry back to its Python value.""" |
| 55 | + type_tag = entry["type"] |
| 56 | + raw = entry["value"] |
| 57 | + if type_tag == "null": |
| 58 | + return None |
| 59 | + if type_tag == "uuid": |
| 60 | + return uuid_mod.UUID(str(raw)) |
| 61 | + if type_tag == "datetime": |
| 62 | + return datetime.fromisoformat(str(raw)) |
| 63 | + if type_tag == "int": |
| 64 | + return int(raw) |
| 65 | + return str(raw) |
| 66 | + |
| 67 | + |
40 | 68 | def encode_cursor(row: Any, sort_param: SortParam) -> str: |
41 | 69 | """ |
42 | 70 | Encode cursor token from the last row of a result set. |
43 | 71 |
|
44 | | - The token is a base64url-encoded JSON list containing the sort column |
45 | | - values in the same order as the resolved sort columns. |
| 72 | + The token is a base64url-encoded JSON list of typed objects, each |
| 73 | + containing ``{"type": "<tag>", "value": <serialized>}`` so the |
| 74 | + cursor is self-describing and can be decoded without column metadata. |
46 | 75 | """ |
47 | 76 | resolved = sort_param.get_resolved_columns() |
48 | 77 | if not resolved: |
49 | 78 | raise ValueError("SortParam has no resolved columns.") |
50 | 79 |
|
51 | | - values: list[Any] = [] |
52 | | - for attr_name, _col, _desc in resolved: |
53 | | - val = getattr(row, attr_name, None) |
54 | | - if val is None: |
55 | | - values.append(None) |
56 | | - elif isinstance(val, datetime): |
57 | | - values.append(val.isoformat()) |
58 | | - else: |
59 | | - values.append(str(val)) |
60 | | - |
61 | | - return base64.urlsafe_b64encode(json.dumps(values).encode()).decode() |
| 80 | + entries = [_encode_value(getattr(row, attr_name, None)) for attr_name, _col, _desc in resolved] |
| 81 | + return base64.urlsafe_b64encode(json.dumps(entries).encode()).decode() |
62 | 82 |
|
63 | 83 |
|
64 | | -def decode_cursor(token: str) -> list[Any]: |
65 | | - """Decode a cursor token and return the list of values.""" |
| 84 | +def decode_cursor(token: str) -> list[dict[str, Any]]: |
| 85 | + """Decode a cursor token and return the list of typed value entries.""" |
66 | 86 | try: |
67 | 87 | data = json.loads(base64.urlsafe_b64decode(token)) |
68 | 88 | except Exception: |
69 | 89 | raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid cursor token") |
70 | 90 |
|
71 | | - if not isinstance(data, list): |
| 91 | + if not isinstance(data, list) or any( |
| 92 | + not isinstance(entry, dict) or "type" not in entry or "value" not in entry for entry in data |
| 93 | + ): |
72 | 94 | raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid cursor token structure") |
73 | 95 |
|
74 | 96 | return data |
75 | 97 |
|
76 | 98 |
|
77 | | -def _coerce_cursor_value(raw: Any, col: ColumnElement) -> Any: |
78 | | - """Convert a JSON-serialized cursor value to the Python type expected by the column.""" |
79 | | - if raw is None: |
80 | | - return None |
81 | | - |
82 | | - from sqlalchemy import Integer, String |
83 | | - from sqlalchemy.sql.sqltypes import Uuid |
84 | | - |
85 | | - col_type = getattr(col, "type", None) |
86 | | - if col_type is None: |
87 | | - return raw |
88 | | - |
89 | | - if isinstance(col_type, Uuid): |
90 | | - return uuid_mod.UUID(str(raw)) |
91 | | - if isinstance(col_type, Integer): |
92 | | - return int(raw) |
93 | | - if isinstance(col_type, String): |
94 | | - return str(raw) |
95 | | - |
96 | | - type_name = type(col_type).__name__.lower() |
97 | | - if "datetime" in type_name or "timestamp" in type_name or "date" in type_name: |
98 | | - return datetime.fromisoformat(str(raw)) |
99 | | - |
100 | | - return raw |
101 | | - |
102 | | - |
103 | 99 | def apply_cursor_filter(statement: Select, cursor: str, sort_param: SortParam) -> Select: |
104 | 100 | """ |
105 | 101 | Apply a keyset pagination WHERE clause from a cursor token. |
106 | 102 |
|
107 | 103 | Builds a composite comparison that respects mixed ASC/DESC ordering |
108 | 104 | on the resolved sort columns. |
109 | 105 | """ |
110 | | - cursor_values = decode_cursor(cursor) |
| 106 | + cursor_entries = decode_cursor(cursor) |
111 | 107 |
|
112 | 108 | resolved = sort_param.get_resolved_columns() |
113 | | - if len(cursor_values) != len(resolved): |
| 109 | + if len(cursor_entries) != len(resolved): |
114 | 110 | raise HTTPException(status.HTTP_400_BAD_REQUEST, "Cursor token does not match current query shape") |
115 | 111 |
|
116 | | - parsed_values: list[Any] = [] |
117 | | - for i, (_name, col, _desc) in enumerate(resolved): |
118 | | - parsed_values.append(_coerce_cursor_value(cursor_values[i], col)) |
| 112 | + parsed_values = [_decode_value(entry) for entry in cursor_entries] |
119 | 113 |
|
120 | 114 | # Build the keyset WHERE clause for mixed ASC/DESC ordering. |
121 | 115 | # For columns (c1 ASC, c2 DESC, c3 ASC) with cursor values (v1, v2, v3): |
|
0 commit comments