Skip to content

Commit ca9d552

Browse files
feat/finishing typeddict inputs (#95)
Why === We got pretty close to having TypedDicts for river-python inputs before, but had to roll back due to a protocol mismatch. Trying again, and also adding some tests to confirm that at the very least the Pydantic models can decode what was encoded by the TypedDict encoders. It's not a perfect science, but it should be good enough to start building more confidence as we make additional progress. ### The reason for "janky" tests There's a bit of a chicken-and-egg situation when trying to test code generation at runtime. We have three options: - write pytest handlers where each invocation runs the codegen with a temp target (like the shell script does here), writes a static file for each test into that directory, then executes a new Python process in that directory. The challenge with this is that it would suck to write or maintain. - write pytest handlers which run the codegen with unique module name targets (like `gen1`, `gen2`, `gen3`, one for each codegen run necessary) and carefully juggle the imports to make sure we don't try to import something that's not there yet. This _might_ be the best option, but I'm not convinced about the ergonomics at the moment. It might be OK though, with highly targeted `.gitignore`s. - maintain a bespoke test runner, optimize for writing and maintaining these tests, and just acknowledge that we are doing something obscure and difficult. I definitely wrote the tests here in a way that would give some coverage and also provide confidence, while intentionally deferring the above decision so we can keep making progress in the meantime. What changed ============ - Added some janky tests for comparing the encoding of both models - Fixed many bugs in the TypedDict codegen and encoders Test plan ========= ``` $ bash scripts/parity.sh Using /tmp/river-codegen-parity.bAZ Starting... Verified ```
1 parent 62b236e commit ca9d552

5 files changed

Lines changed: 326 additions & 32 deletions

File tree

mypy.ini

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,13 @@ disallow_untyped_defs = True
44
warn_return_any = True
55

66
[mypy-grpc.*]
7-
ignore_missing_imports = True
7+
ignore_missing_imports = True
8+
9+
[mypy-parity.gen.*]
10+
ignore_missing_imports = True
11+
12+
[mypy-pyd.*]
13+
ignore_missing_imports = True
14+
15+
[mypy-tyd.*]
16+
ignore_missing_imports = True

replit_river/codegen/client.py

Lines changed: 70 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
Set,
1313
Tuple,
1414
Union,
15+
cast,
1516
)
1617

1718
import black
@@ -80,8 +81,17 @@ def reindent(prefix: str, code: str) -> str:
8081
return indent(dedent(code), prefix)
8182

8283

84+
def is_literal(tpe: RiverType) -> bool:
    """Report whether ``tpe`` is built only from string/number/boolean types.

    A union qualifies when every member of its ``anyOf`` qualifies (checked
    recursively); a concrete type qualifies when its ``type`` is "string",
    "number" or "boolean". Any other kind of type does not qualify.
    """
    if isinstance(tpe, RiverUnionType):
        # Stop at the first member that is not itself a literal.
        for member in tpe.anyOf:
            if not is_literal(member):
                return False
        return True
    if isinstance(tpe, RiverConcreteType):
        return tpe.type in {"string", "number", "boolean"}
    return False
91+
92+
8393
def encode_type(
84-
type: RiverType, prefix: str, base_model: str = "BaseModel"
94+
type: RiverType, prefix: str, base_model: str
8595
) -> Tuple[str, Sequence[str]]:
8696
chunks: List[str] = []
8797
if isinstance(type, RiverNotType):
@@ -219,14 +229,6 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
219229
type = original_type
220230
any_of: List[str] = []
221231

222-
def is_literal(tpe: RiverType) -> bool:
223-
if isinstance(tpe, RiverUnionType):
224-
return all(is_literal(t) for t in tpe.anyOf)
225-
elif isinstance(tpe, RiverConcreteType):
226-
return tpe.type in set(["string", "number", "boolean"])
227-
else:
228-
return False
229-
230232
typeddict_encoder = []
231233
for i, t in enumerate(type.anyOf):
232234
type_name, type_chunks = encode_type(t, f"{prefix}AnyOf_{i}", base_model)
@@ -273,44 +275,44 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
273275
# Handle the case where type is not specified
274276
typeddict_encoder.append("x")
275277
return ("Any", ())
276-
if type.type == "string":
278+
elif type.type == "string":
277279
if type.const:
278280
typeddict_encoder.append(f"'{type.const}'")
279281
return (f"Literal['{type.const}']", ())
280282
else:
281283
typeddict_encoder.append("x")
282284
return ("str", ())
283-
if type.type == "Uint8Array":
285+
elif type.type == "Uint8Array":
284286
typeddict_encoder.append("x.decode()")
285287
return ("bytes", ())
286-
if type.type == "number":
288+
elif type.type == "number":
287289
if type.const is not None:
288290
# enums are represented as const number in the schema
289291
typeddict_encoder.append(f"{type.const}")
290292
return (f"Literal[{type.const}]", ())
291293
typeddict_encoder.append("x")
292294
return ("float", ())
293-
if type.type == "integer":
295+
elif type.type == "integer":
294296
if type.const is not None:
295297
# enums are represented as const number in the schema
296298
typeddict_encoder.append(f"{type.const}")
297299
return (f"Literal[{type.const}]", ())
298300
typeddict_encoder.append("x")
299301
return ("int", ())
300-
if type.type == "boolean":
302+
elif type.type == "boolean":
301303
typeddict_encoder.append("x")
302304
return ("bool", ())
303-
if type.type == "null":
305+
elif type.type == "null":
304306
typeddict_encoder.append("None")
305307
return ("None", ())
306-
if type.type == "Date":
308+
elif type.type == "Date":
307309
typeddict_encoder.append("TODO: dstewart")
308310
return ("datetime.datetime", ())
309-
if type.type == "array" and type.items:
311+
elif type.type == "array" and type.items:
310312
type_name, type_chunks = encode_type(type.items, prefix, base_model)
311313
typeddict_encoder.append("TODO: dstewart")
312314
return (f"List[{type_name}]", type_chunks)
313-
if (
315+
elif (
314316
type.type == "object"
315317
and type.patternProperties
316318
and "^(.*)$" in type.patternProperties
@@ -323,7 +325,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
323325
assert type.type == "object", type.type
324326

325327
current_chunks: List[str] = [f"class {prefix}({base_model}):"]
328+
# For the encoder path, do we need "x" to be bound?
329+
# lambda x: ... vs lambda _: {}
330+
needs_binding = False
326331
if type.properties:
332+
needs_binding = True
327333
typeddict_encoder.append("{")
328334
for name, prop in type.properties.items():
329335
typeddict_encoder.append(f"'{name}':")
@@ -353,18 +359,35 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
353359
)
354360
if name not in prop.required:
355361
typeddict_encoder.append(
356-
f"if x['{safe_name}'] else None"
362+
dedent(
363+
f"""
364+
if '{safe_name}' in x
365+
and x['{safe_name}'] is not None
366+
else None
367+
"""
368+
)
357369
)
358370
elif prop.type == "array":
359-
assert type_name.startswith(
360-
"List["
361-
) # in case we change to list[...]
362-
_inner_type_name = type_name[len("List[") : -len("]")]
363-
typeddict_encoder.append(
364-
f"[encode_{_inner_type_name}(y) for y in x['{name}']]"
365-
)
371+
items = cast(RiverConcreteType, prop).items
372+
assert items, "Somehow items was none"
373+
if is_literal(cast(RiverType, items)):
374+
typeddict_encoder.append(f"x['{name}']")
375+
else:
376+
assert type_name.startswith(
377+
"List["
378+
) # in case we change to list[...]
379+
_inner_type_name = type_name[len("List[") : -len("]")]
380+
typeddict_encoder.append(
381+
f"""[
382+
encode_{_inner_type_name}(y)
383+
for y in x['{name}']
384+
]"""
385+
)
366386
else:
367-
typeddict_encoder.append(f"x['{safe_name}']")
387+
if name in prop.required:
388+
typeddict_encoder.append(f"x['{safe_name}']")
389+
else:
390+
typeddict_encoder.append(f"x.get('{safe_name}')")
368391

369392
if name == "$kind":
370393
# If the field is a literal, the Python type-checker will complain
@@ -403,8 +426,9 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
403426
current_chunks.append("")
404427

405428
if base_model == "TypedDict":
429+
binding = "x" if needs_binding else "_"
406430
current_chunks = (
407-
[f"encode_{prefix}: Callable[['{prefix}'], Any] = (lambda x: "]
431+
[f"encode_{prefix}: Callable[['{prefix}'], Any] = (lambda {binding}: "]
408432
+ typeddict_encoder
409433
+ [")"]
410434
+ current_chunks
@@ -449,7 +473,7 @@ def generate_river_client_module(
449473

450474
if schema_root.handshakeSchema is not None:
451475
(handshake_type, handshake_chunks) = encode_type(
452-
schema_root.handshakeSchema, "HandshakeSchema"
476+
schema_root.handshakeSchema, "HandshakeSchema", "BaseModel"
453477
)
454478
chunks.extend(handshake_chunks)
455479
else:
@@ -482,7 +506,9 @@ def __init__(self, client: river.Client[{handshake_type}]):
482506
)
483507
chunks.extend(input_chunks)
484508
output_type, output_chunks = encode_type(
485-
procedure.output, f"{schema_name.title()}{name.title()}Output"
509+
procedure.output,
510+
f"{schema_name.title()}{name.title()}Output",
511+
"BaseModel",
486512
)
487513
chunks.extend(output_chunks)
488514
if procedure.errors:
@@ -517,7 +543,20 @@ def __init__(self, client: river.Client[{handshake_type}]):
517543
""".rstrip()
518544

519545
if typed_dict_inputs:
520-
render_input_method = f"encode_{input_type}"
546+
if is_literal(procedure.input):
547+
render_input_method = "lambda x: x"
548+
elif isinstance(
549+
procedure.input, RiverConcreteType
550+
) and procedure.input.type in ["array"]:
551+
assert input_type.startswith(
552+
"List["
553+
) # in case we change to list[...]
554+
_input_type_name = input_type[len("List[") : -len("]")]
555+
render_input_method = (
556+
f"lambda xs: [encode_{_input_type_name}(x) for x in xs]"
557+
)
558+
else:
559+
render_input_method = f"encode_{input_type}"
521560
else:
522561
render_input_method = f"""\
523562
lambda x: TypeAdapter({input_type})

scripts/parity.sh

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#!/usr/bin/env bash
#
# parity.sh: Generate Pydantic and TypedDict models and check for deep equality.
# This script expects that ai-infra is cloned alongside river-python.
#
# Set DEBUG to a non-empty value to keep the generated sources for inspection.

set -e

# Resolve the scripts directory to an absolute path *before* changing
# directory: a relative "$0" (e.g. "./parity.sh") would otherwise refer to a
# different location once we cd to the repository root, which breaks the
# module search path derived from ${scripts} further down.
scripts="$(cd "$(dirname "$0")" && pwd)"
cd "${scripts}/.."

# Scratch directory that receives the generated client modules.
# The template path is spelled out (rather than relying on the GNU-only
# `--tmpdir` flag) so that BSD/macOS mktemp accepts it as well.
root="$(mktemp -d "${TMPDIR:-/tmp}/river-codegen-parity.XXXXXX")"
mkdir "$root/src"

echo "Using $root" >&2
15+
16+
# Remove the scratch directory, unless DEBUG is set to a non-empty value, in
# which case the generated files are left in place for inspection.
cleanup() {
  if [ -n "${DEBUG}" ]; then
    return 0
  fi
  echo "Cleaning up..." >&2
  rm -rfv "${root}" >&2
}
# Run on normal exit (0) as well as on SIGINT (2), SIGQUIT (3) and SIGTERM (15).
trap cleanup EXIT INT QUIT TERM
23+
24+
# gen FNAME CLIENT_NAME [EXTRA_ARGS...]
#
# Run the river client codegen against the pid2 schema, writing the generated
# module to "${root}/src/FNAME" and passing CLIENT_NAME as --client-name.
# Any remaining arguments are forwarded verbatim to the codegen CLI.
gen() {
  # `local` keeps these from leaking into (or clobbering) script-level
  # variables of the same name.
  local fname="$1"; shift
  local name="$1"; shift
  poetry run python -m replit_river.codegen \
    client \
    --output "${root}/src/${fname}" \
    --client-name "${name}" \
    ../ai-infra/pkgs/pid2_client/src/schema/schema.json \
    "$@"
}
34+
35+
# Generate clients from the same schema twice: once with TypedDict inputs and
# once with the default (Pydantic) models.
gen tyd.py Pid2TypedDict --typed-dict-inputs
gen pyd.py Pid2Pydantic

# Search path exposing the generated modules (tyd, pyd) plus the scripts
# directory (presumably where the `parity` package lives -- verify).
#
# The value is kept in a script-local variable and handed to each tool via
# `env`: assigning to PYTHONPATH directly would also alter the environment of
# `poetry` itself whenever the caller already exports that variable, and
# splicing the value into a `bash -c` string breaks on paths that contain a
# single quote.
module_path="${root}/src:${scripts}"
poetry run env MYPYPATH="${module_path}" mypy -m parity.check_parity
poetry run env PYTHONPATH="${module_path}" python -m parity.check_parity

0 commit comments

Comments
 (0)