|
81 | 81 |
|
82 | 82 | from pydantic import BaseModel, Field, TypeAdapter, WrapValidator |
83 | 83 | from replit_river.error_schema import RiverError |
84 | | -from replit_river.client import RiverUnknownValue, translate_unknown_value |
| 84 | +from replit_river.client import RiverUnknownError, translate_unknown_error, \ |
| 85 | + RiverUnknownValue, translate_unknown_value |
85 | 86 |
|
86 | 87 | import replit_river as river |
87 | 88 |
|
@@ -154,6 +155,20 @@ def encode_type( |
154 | 155 | in_module: list[ModuleName], |
155 | 156 | permit_unknown_members: bool, |
156 | 157 | ) -> tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]: |
| 158 | + def _make_open_union_type_expr(one_of: list[TypeExpression]) -> OpenUnionTypeExpr: |
| 159 | + if base_model == "RiverError": |
| 160 | + return OpenUnionTypeExpr( |
| 161 | + UnionTypeExpr(one_of), |
| 162 | + fallback_type="RiverUnknownError", |
| 163 | + validator_function="translate_unknown_error", |
| 164 | + ) |
| 165 | + else: |
| 166 | + return OpenUnionTypeExpr( |
| 167 | + UnionTypeExpr(one_of), |
| 168 | + fallback_type="RiverUnknownValue", |
| 169 | + validator_function="translate_unknown_value", |
| 170 | + ) |
| 171 | + |
157 | 172 | encoder_name: TypeName | None = None # defining this up here to placate mypy |
158 | 173 | chunks: list[FileContents] = [] |
159 | 174 | if isinstance(type, RiverNotType): |
@@ -304,7 +319,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: |
304 | 319 | ) |
305 | 320 | union: TypeExpression |
306 | 321 | if permit_unknown_members: |
307 | | - union = OpenUnionTypeExpr(UnionTypeExpr(one_of)) |
| 322 | + union = _make_open_union_type_expr(one_of) |
308 | 323 | else: |
309 | 324 | union = UnionTypeExpr(one_of) |
310 | 325 | chunks.append( |
@@ -383,7 +398,7 @@ def {_field_name}( |
383 | 398 | ) |
384 | 399 | raise ValueError(f"What does it mean to have {_o2} here?") |
385 | 400 | if permit_unknown_members: |
386 | | - union = OpenUnionTypeExpr(UnionTypeExpr(any_of)) |
| 401 | + union = _make_open_union_type_expr(any_of) |
387 | 402 | else: |
388 | 403 | union = UnionTypeExpr(any_of) |
389 | 404 | if is_literal(type): |
@@ -795,17 +810,18 @@ def _type_adapter_definition( |
795 | 810 | _type: TypeExpression, |
796 | 811 | module_info: list[ModuleName], |
797 | 812 | ) -> tuple[list[TypeName], list[ModuleName], list[FileContents]]: |
| 813 | + varname = render_type_expr(type_adapter_name) |
798 | 814 | rendered_type_expr = render_type_expr(_type) |
799 | 815 | return ( |
800 | 816 | [type_adapter_name], |
801 | 817 | module_info, |
802 | 818 | [ |
803 | 819 | FileContents( |
804 | 820 | dedent(f""" |
805 | | - {render_type_expr(type_adapter_name)}: TypeAdapter[Any] = ( |
806 | | - TypeAdapter({rendered_type_expr}) |
807 | | - ) |
808 | | - """) |
| 821 | + {varname}: TypeAdapter[{rendered_type_expr}] = ( |
| 822 | + TypeAdapter({rendered_type_expr}) |
| 823 | + ) |
| 824 | + """) |
809 | 825 | ) |
810 | 826 | ], |
811 | 827 | ) |
|
0 commit comments