-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtyping.py
More file actions
275 lines (216 loc) · 8.37 KB
/
typing.py
File metadata and controls
275 lines (216 loc) · 8.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
import keyword
from dataclasses import dataclass
from typing import NewType, assert_never, cast
# Characters that normalize_special_chars rewrites to "_" when turning a raw
# name into an identifier-friendly one. Kept as a list (order preserved).
SPECIAL_CHARS = [
    ".", "-", ":", "/", "@", " ", "$",
    "!", "?", "=", "&", "|", "~", "`",
]

# Distinct ``str`` aliases so a type checker can tell these kinds of strings
# apart; at runtime each NewType simply returns its argument unchanged.
ModuleName = NewType("ModuleName", str)
ClassName = NewType("ClassName", str)
FileContents = NewType("FileContents", str)
HandshakeType = NewType("HandshakeType", str)
RenderedPath = NewType("RenderedPath", str)
@dataclass(frozen=True)
class TypeName:
    """A named type; ``render_type_expr`` emits it via ``normalize_special_chars``.

    ``str()`` is deliberately disabled so the node cannot be interpolated
    into generated text without going through the renderer.
    """

    value: str  # raw, not-yet-normalized name

    def __str__(self) -> str:
        raise Exception("Complex type must be put through render_type_expr!")

    def __eq__(self, other: object) -> bool:
        return isinstance(other, TypeName) and other.value == self.value

    def __lt__(self, other: object) -> bool:
        # Order by repr, not hash: str hashes are randomized per interpreter
        # process (PYTHONHASHSEED), so hash-based ordering made sorted output
        # differ from run to run.
        return repr(self) < repr(other)
@dataclass(frozen=True)
class LiteralType:
    """A type expression whose ``value`` ``render_type_expr`` emits verbatim.

    ``str()`` is deliberately disabled so the node cannot be interpolated
    into generated text without going through the renderer.
    """

    value: str  # emitted as-is, without normalization

    def __str__(self) -> str:
        raise Exception("Complex type must be put through render_type_expr!")

    def __eq__(self, other: object) -> bool:
        return isinstance(other, LiteralType) and other.value == self.value

    def __lt__(self, other: object) -> bool:
        # Order by repr, not hash: str hashes are randomized per interpreter
        # process, so hash-based ordering was not reproducible across runs.
        return repr(self) < repr(other)
@dataclass(frozen=True)
class NoneTypeExpr:
    """The ``None`` type; ``render_type_expr`` emits the text ``None``.

    All instances compare equal. ``str()`` is deliberately disabled so the
    node cannot be interpolated without going through the renderer.
    """

    def __str__(self) -> str:
        raise Exception("Complex type must be put through render_type_expr!")

    def __eq__(self, other: object) -> bool:
        return isinstance(other, NoneTypeExpr)

    def __lt__(self, other: object) -> bool:
        # Order by repr, not hash, so mixed-type sorts are reproducible
        # across interpreter runs.
        return repr(self) < repr(other)
@dataclass(frozen=True)
class DictTypeExpr:
    """A string-keyed mapping; rendered as ``dict[str, <nested>]``.

    ``str()`` is deliberately disabled so the node cannot be interpolated
    into generated text without going through the renderer.
    """

    nested: "TypeExpression"  # value type of the mapping

    def __str__(self) -> str:
        raise Exception("Complex type must be put through render_type_expr!")

    def __eq__(self, other: object) -> bool:
        return isinstance(other, DictTypeExpr) and other.nested == self.nested

    def __lt__(self, other: object) -> bool:
        # Order by repr, not hash: hash-based ordering was not reproducible
        # across runs (str hash randomization) and raised TypeError whenever
        # ``nested`` was unhashable.
        return repr(self) < repr(other)
@dataclass(frozen=True)
class ListTypeExpr:
    """A homogeneous list; rendered as ``list[<nested>]``.

    ``str()`` is deliberately disabled so the node cannot be interpolated
    into generated text without going through the renderer.
    """

    nested: "TypeExpression"  # element type of the list

    def __str__(self) -> str:
        raise Exception("Complex type must be put through render_type_expr!")

    def __eq__(self, other: object) -> bool:
        return isinstance(other, ListTypeExpr) and other.nested == self.nested

    def __lt__(self, other: object) -> bool:
        # Order by repr, not hash: hash-based ordering was not reproducible
        # across runs (str hash randomization) and raised TypeError whenever
        # ``nested`` was unhashable.
        return repr(self) < repr(other)
@dataclass(frozen=True)
class LiteralTypeExpr:
nested: int | str
def __str__(self) -> str:
raise Exception("Complex type must be put through render_type_expr!")
def __eq__(self, other: object) -> bool:
return isinstance(other, LiteralTypeExpr) and other.nested == self.nested
def __lt__(self, other: object) -> bool:
return hash(self) < hash(other)
@dataclass(frozen=True)
class UnionTypeExpr:
    """A union of type expressions, rendered by ``render_type_expr`` with ``|``.

    Equality ignores member order and duplicates. Hashing follows the same
    rule, so equal unions hash equally and a union can itself be a member of
    another union (``__eq__`` builds sets of members, which requires every
    member to be hashable).
    """

    # Kept as a list for backward compatibility. Do not mutate it after the
    # node has been hashed or stored in a set/dict.
    nested: list["TypeExpression"]

    def __str__(self) -> str:
        raise Exception("Complex type must be put through render_type_expr!")

    def __eq__(self, other: object) -> bool:
        return isinstance(other, UnionTypeExpr) and set(other.nested) == set(
            self.nested
        )

    def __hash__(self) -> int:
        # The dataclass-generated hash hashed the list field and always raised
        # TypeError. Hash the member set instead, matching __eq__.
        return hash(frozenset(self.nested))

    def __lt__(self, other: object) -> bool:
        # Order by repr, not hash, so sorting is reproducible across runs.
        return repr(self) < repr(other)
@dataclass(frozen=True)
class OpenUnionTypeExpr:
    """A union rendered with an extra fallback member and a ``WrapValidator``.

    ``render_type_expr`` emits
    ``Annotated[<union> | <fallback_type>,WrapValidator(<validator_function>)]``.
    Equality and hashing consider only ``union``.
    """

    union: UnionTypeExpr
    fallback_type: str  # emitted verbatim after the union members
    validator_function: str  # emitted verbatim inside WrapValidator(...)

    def __str__(self) -> str:
        raise Exception("Complex type must be put through render_type_expr!")

    def __eq__(self, other: object) -> bool:
        return isinstance(other, OpenUnionTypeExpr) and other.union == self.union

    def __hash__(self) -> int:
        # Must agree with __eq__, which compares only ``union``. The
        # dataclass-generated hash also mixed in fallback_type and
        # validator_function, so equal objects could hash differently.
        return hash(self.union)

    def __lt__(self, other: object) -> bool:
        # Order by repr, not hash, so sorting is reproducible across runs.
        return repr(self) < repr(other)
# The closed set of node kinds a type expression is built from.
# ``render_type_expr`` and ``extract_inner_type`` match on every member and
# end in ``assert_never``, so adding a member means updating both.
TypeExpression = (
    TypeName | LiteralType | NoneTypeExpr | DictTypeExpr
    | ListTypeExpr | LiteralTypeExpr | UnionTypeExpr | OpenUnionTypeExpr
)
def _flatten_nested_unions(value: TypeExpression) -> TypeExpression:
    """Collapse nested ``UnionTypeExpr`` values into a single flat union.

    Non-union expressions are returned unchanged. Only union-inside-union
    nesting is flattened; unions wrapped in dict/list expressions are left
    alone. Members keep their depth-first, left-to-right order.

    Raises:
        ValueError: if ``value`` is a union with no non-union members at any
            depth (e.g. an empty union).
    """

    def collect_leaves(expr: TypeExpression) -> list[TypeExpression]:
        # A non-union is its own single leaf; a union contributes the leaves
        # of each member in turn.
        if not isinstance(expr, UnionTypeExpr):
            return [expr]
        leaves: list[TypeExpression] = []
        for member in expr.nested:
            leaves.extend(collect_leaves(member))
        return leaves

    if not isinstance(value, UnionTypeExpr):
        return value

    members = collect_leaves(value)
    if not members:
        raise ValueError("Incoherent state when trying to flatten unions")
    return UnionTypeExpr(members)
def normalize_special_chars(value: str) -> str:
    """Turn ``value`` into an identifier-friendly name.

    Every entry of ``SPECIAL_CHARS`` is replaced with ``_``, leading
    underscores are then stripped, and ``_`` is appended when the result is a
    Python keyword (e.g. ``"from"`` -> ``"from_"``).

    NOTE(review): characters not listed in ``SPECIAL_CHARS`` (and leading
    digits) are kept as-is, and an input made only of special characters
    yields ``""``, so the result is not guaranteed to be a valid identifier.
    """
    cleaned = value
    for special in SPECIAL_CHARS:
        cleaned = cleaned.replace(special, "_")
    cleaned = cleaned.lstrip("_")
    return f"{cleaned}_" if keyword.iskeyword(cleaned) else cleaned
def render_type_expr(value: TypeExpression) -> str:
match _flatten_nested_unions(value):
case DictTypeExpr(nested):
return f"dict[str, {render_type_expr(nested)}]"
case ListTypeExpr(nested):
return f"list[{render_type_expr(nested)}]"
case LiteralTypeExpr(inner):
return f"Literal[{repr(inner)}]"
case UnionTypeExpr(inner):
literals: list[LiteralTypeExpr] = []
_other: list[TypeExpression] = []
for tpe in inner:
if isinstance(tpe, UnionTypeExpr):
raise ValueError("These should have been flattened")
elif isinstance(tpe, LiteralTypeExpr):
literals.append(tpe)
else:
_other.append(tpe)
without_none: list[TypeExpression] = [
x for x in _other if not isinstance(x, NoneTypeExpr)
]
has_none = len(_other) > len(without_none)
_other = without_none
retval: str = " | ".join(render_type_expr(x) for x in _other)
if literals:
_rendered: str = ", ".join(repr(x.nested) for x in literals)
if retval:
retval = f"Literal[{_rendered}] | {retval}"
else:
retval = f"Literal[{_rendered}]"
if has_none:
if retval:
retval = f"{retval} | None"
else:
retval = "None"
return retval
case OpenUnionTypeExpr(inner):
open_union = cast(OpenUnionTypeExpr, value)
return (
"Annotated["
f"{render_type_expr(inner)} | {open_union.fallback_type},"
f"WrapValidator({open_union.validator_function})"
"]"
)
case TypeName(name):
return normalize_special_chars(name)
case LiteralType(literal_value):
return literal_value
case NoneTypeExpr():
return "None"
case other:
assert_never(other)
def render_literal_type(value: TypeExpression) -> str:
    """Render ``value`` after requiring it to be a plain ``TypeName``.

    NOTE(review): despite the name, this accepts only ``TypeName`` (via
    ``ensure_literal_type``), not ``LiteralType``/``LiteralTypeExpr``.

    Raises:
        ValueError: if ``value`` is not a ``TypeName``.
    """
    type_name = ensure_literal_type(value)
    return render_type_expr(type_name)
def extract_inner_type(value: TypeExpression) -> TypeName:
match value:
case DictTypeExpr(nested):
return extract_inner_type(nested)
case ListTypeExpr(nested):
return extract_inner_type(nested)
case LiteralTypeExpr(_):
raise ValueError(f"Unexpected literal type: {repr(value)}")
case UnionTypeExpr(_):
raise ValueError(
"Attempting to extract from a union, "
f"currently not possible: {repr(value)}"
)
case OpenUnionTypeExpr(_):
raise ValueError(
"Attempting to extract from a union, "
f"currently not possible: {repr(value)}"
)
case TypeName(name):
return TypeName(name)
case LiteralType(name):
raise ValueError(
f"Attempting to extract from a literal type: {repr(value)}"
)
case NoneTypeExpr():
raise ValueError(
f"Attempting to extract from a literal 'None': {repr(value)}",
)
case other:
assert_never(other)
def ensure_literal_type(value: TypeExpression) -> TypeName:
    """Return ``value`` as a fresh ``TypeName``, rejecting any other expression.

    Raises:
        ValueError: if ``value`` is not a ``TypeName``.
    """
    if not isinstance(value, TypeName):
        raise ValueError(
            f"Unexpected expression when expecting a type name: {repr(value)}"
        )
    return TypeName(value.value)