Skip to content

Commit a466e36

Browse files
authored
fix: pinned bool columns show up as None (#225)
* dont generalize falsiness * add test
1 parent f3e53e6 commit a466e36

2 files changed

Lines changed: 38 additions & 1 deletion

File tree

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,41 @@ def test_vetiver_model_dict_like_prototype(prototype_data):
110110
assert json_schema == expected
111111

112112

113+
@pytest.mark.skipif(
114+
pydantic.__version__.startswith("1"), reason="only run for pydantic v2"
115+
)
116+
@pytest.mark.parametrize(
117+
"prototype_data,expected",
118+
[
119+
(
120+
{"B": 0, "C": False, "D": None},
121+
{
122+
"properties": {
123+
"B": {"example": 0, "title": "B", "type": "integer"},
124+
"C": {"example": False, "title": "C", "type": "boolean"},
125+
"D": {"example": None, "title": "D", "type": "null"},
126+
},
127+
"required": ["B", "C", "D"],
128+
"title": "prototype",
129+
"type": "object",
130+
},
131+
)
132+
],
133+
)
134+
def test_falsy_prototypes(prototype_data, expected):
135+
v = VetiverModel(
136+
model=model,
137+
prototype_data=prototype_data,
138+
model_name="model",
139+
versioned=None,
140+
description=None,
141+
metadata=None,
142+
)
143+
144+
assert isinstance(v.prototype.construct(), pydantic.BaseModel)
145+
assert v.prototype.model_json_schema() == expected
146+
147+
113148
@pytest.mark.parametrize("prototype_data", [MockPrototype(B=4, C=0, D=0), None])
114149
def test_vetiver_model_prototypes(prototype_data):
115150
v = VetiverModel(

vetiver/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def serialize_prototype(prototype):
6464

6565
serialized_schema = dict()
6666
for key, value in schema.items():
67-
serialized_schema[key] = value.get("example") or value.get("default")
67+
example = value.get("example", None)
68+
default = value.get("default", None)
69+
serialized_schema[key] = example if example is not None else default
6870

6971
return json.dumps(serialized_schema)

0 commit comments

Comments
 (0)