Skip to content

Commit 41647dc

Browse files
feat/exclude_none=True for TypedDict encoder (#96)
Why === I missed `exclude_none=True` semantics for TypedDict, so let's add that in there. What changed ============ - Add a structured `oneOf` test - Add a `deep_equals` method - Remove `cast` usage in `parity` - Add "exclude_none" functionality. This should make TypeDict exactly equal to pydantic for inputs. Test plan ========= Tests included
1 parent ca9d552 commit 41647dc

3 files changed

Lines changed: 151 additions & 13 deletions

File tree

replit_river/codegen/client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,12 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
420420
current_chunks.append(f" {name}: {type_name}")
421421
typeddict_encoder.append(",")
422422
typeddict_encoder.append("}")
423+
# exclude_none
424+
typeddict_encoder = (
425+
["{k: v for (k, v) in ("]
426+
+ typeddict_encoder
427+
+ [").items() if v is not None}"]
428+
)
423429
else:
424430
typeddict_encoder.append("{}")
425431
current_chunks.append(" pass")

scripts/parity/check_parity.py

Lines changed: 140 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Callable, Literal, TypedDict, TypeVar, Union, cast
1+
from typing import Any, Callable, Literal, TypedDict, TypeVar, Union
22

33
import pyd
44
import tyd
@@ -16,15 +16,45 @@
1616

1717
A = TypeVar("A")
1818

19+
PrimitiveType = (
20+
bool | str | int | float | dict[str, "PrimitiveType"] | list["PrimitiveType"]
21+
)
22+
23+
24+
def deep_equal(a: PrimitiveType, b: PrimitiveType) -> Literal[True]:
25+
if a == b:
26+
return True
27+
elif isinstance(a, dict) and isinstance(b, dict):
28+
a_keys: PrimitiveType = list(a.keys())
29+
b_keys: PrimitiveType = list(b.keys())
30+
assert deep_equal(a_keys, b_keys)
31+
32+
# We do this dance again because Python variance is hard. Feel free to fix it.
33+
keys = set(a.keys())
34+
keys.update(b.keys())
35+
for k in keys:
36+
aa: PrimitiveType = a[k]
37+
bb: PrimitiveType = b[k]
38+
assert deep_equal(aa, bb)
39+
return True
40+
elif isinstance(a, list) and isinstance(b, list):
41+
assert len(a) == len(b)
42+
for i in range(len(a)):
43+
assert deep_equal(a[i], b[i])
44+
return True
45+
else:
46+
assert a == b, f"{a} != {b}"
47+
return True
48+
1949

2050
def baseTestPattern(
2151
x: A, encode: Callable[[A], Any], adapter: TypeAdapter[Any]
2252
) -> None:
2353
a = encode(x)
2454
m = adapter.validate_python(a)
25-
z = adapter.dump_python(m)
55+
z = adapter.dump_python(m, by_alias=True, exclude_none=True)
2656

27-
assert a == z
57+
assert deep_equal(a, z)
2858

2959

3060
def testAiexecExecInit() -> None:
@@ -93,7 +123,39 @@ def testAgenttoollanguageserverGetcodesymbolInput() -> None:
93123
"line": gen_float(),
94124
"character": gen_float(),
95125
},
96-
"kind": cast(kind_type, gen_opt(gen_choice(list(range(1, 27))))()),
126+
"kind": gen_choice(
127+
list[kind_type](
128+
[
129+
1,
130+
2,
131+
3,
132+
4,
133+
5,
134+
6,
135+
7,
136+
8,
137+
9,
138+
10,
139+
11,
140+
12,
141+
13,
142+
14,
143+
15,
144+
16,
145+
17,
146+
18,
147+
19,
148+
20,
149+
21,
150+
22,
151+
23,
152+
24,
153+
25,
154+
26,
155+
None,
156+
]
157+
)
158+
)(),
97159
}
98160

99161
baseTestPattern(
@@ -116,17 +178,17 @@ def testShellexecSpawnInput() -> None:
116178
"env": gen_opt(gen_dict(gen_str))(),
117179
"cwd": gen_opt(gen_str)(),
118180
"size": gen_opt(
119-
lambda: cast(
120-
size_type,
181+
lambda: size_type(
121182
{
122183
"rows": gen_int(),
123184
"cols": gen_int(),
124-
},
125-
)
185+
}
186+
),
126187
)(),
127188
"useReplitRunEnv": gen_opt(gen_bool)(),
128189
"useCgroupMagic": gen_opt(gen_bool)(),
129190
"interactive": gen_opt(gen_bool)(),
191+
"onlySpawnIfNoProcesses": gen_opt(gen_bool)(),
130192
}
131193

132194
baseTestPattern(
@@ -146,12 +208,82 @@ def testConmanfilesystemPersistInput() -> None:
146208
)
147209

148210

211+
closeFile = tyd.ReplspaceapiInitInputOneOf_closeFile
212+
githubToken = tyd.ReplspaceapiInitInputOneOf_githubToken
213+
sshToken0 = tyd.ReplspaceapiInitInputOneOf_sshToken0
214+
sshToken1 = tyd.ReplspaceapiInitInputOneOf_sshToken1
215+
allowDefaultBucketAccess = tyd.ReplspaceapiInitInputOneOf_allowDefaultBucketAccess
216+
217+
allowDefaultBucketAccessResultOk = (
218+
tyd.ReplspaceapiInitInputOneOf_allowDefaultBucketAccessResultOneOf_ok
219+
)
220+
allowDefaultBucketAccessResultError = (
221+
tyd.ReplspaceapiInitInputOneOf_allowDefaultBucketAccessResultOneOf_error
222+
)
223+
224+
225+
def testReplspaceapiInitInput() -> None:
226+
x: tyd.ReplspaceapiInitInput = gen_choice(
227+
list[tyd.ReplspaceapiInitInput](
228+
[
229+
closeFile(
230+
{"kind": "closeFile", "filename": gen_str(), "nonce": gen_str()}
231+
),
232+
githubToken(
233+
{"kind": "githubToken", "token": gen_str(), "nonce": gen_str()}
234+
),
235+
sshToken0(
236+
{
237+
"kind": "sshToken",
238+
"nonce": gen_str(),
239+
"SSHHostname": gen_str(),
240+
"token": gen_str(),
241+
}
242+
),
243+
sshToken1({"kind": "sshToken", "nonce": gen_str(), "error": gen_str()}),
244+
allowDefaultBucketAccess(
245+
{
246+
"kind": "allowDefaultBucketAccess",
247+
"nonce": gen_str(),
248+
"result": gen_choice(
249+
list[
250+
tyd.ReplspaceapiInitInputOneOf_allowDefaultBucketAccessResult
251+
](
252+
[
253+
allowDefaultBucketAccessResultOk(
254+
{
255+
"bucketId": gen_str(),
256+
"sourceReplId": gen_str(),
257+
"status": "ok",
258+
"targetReplId": gen_str(),
259+
}
260+
),
261+
allowDefaultBucketAccessResultError(
262+
{"message": gen_str(), "status": "error"}
263+
),
264+
]
265+
)
266+
)(),
267+
}
268+
),
269+
]
270+
)
271+
)()
272+
273+
baseTestPattern(
274+
x,
275+
tyd.encode_ReplspaceapiInitInput,
276+
TypeAdapter(pyd.ReplspaceapiInitInput),
277+
)
278+
279+
149280
def main() -> None:
150281
testAiexecExecInit()
151282
testAgenttoollanguageserverOpendocumentInput()
152283
testAgenttoollanguageserverGetcodesymbolInput()
153284
testShellexecSpawnInput()
154285
testConmanfilesystemPersistInput()
286+
testReplspaceapiInitInput()
155287

156288

157289
if __name__ == "__main__":

scripts/parity/gen.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import random
2+
import string
23
from typing import Callable, Optional, TypeVar
34

45
A = TypeVar("A")
56

67

8+
printable_chars = string.ascii_letters + string.digits
9+
10+
711
def gen_char() -> str:
8-
pos = random.randint(0, 26 * 2)
9-
if pos < 26:
10-
return chr(ord("A") + pos)
11-
else:
12-
return chr(ord("a") + pos - 26)
12+
return random.choice(printable_chars)
1313

1414

1515
def gen_str() -> str:

0 commit comments

Comments
 (0)