Skip to content

Commit 3f13ed5

Browse files
authored
add Field examples to model prototypes (#210)
* add Field examples * add test * clean up test * handle serializing prototypes + better checking * handle pydantic v1 * handle json strings, tests
1 parent e1ab9bc commit 3f13ed5

7 files changed

Lines changed: 121 additions & 29 deletions

File tree

vetiver/pin_read_write.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .vetiver_model import VetiverModel
22
from .meta import VetiverMeta
3-
from .utils import inform
3+
from .utils import inform, serialize_prototype
44
import warnings
55
import logging
66

@@ -72,10 +72,16 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool = True):
7272
"user": model.metadata.user,
7373
"vetiver_meta": {
7474
"required_pkgs": model.metadata.required_pkgs,
75-
"prototype": None if not model.prototype else model.prototype().json(),
76-
"python_version": None
77-
if not model.metadata.python_version
78-
else list(model.metadata.python_version),
75+
"prototype": (
76+
None
77+
if not model.prototype
78+
else serialize_prototype(model.prototype)
79+
),
80+
"python_version": (
81+
None
82+
if not model.metadata.python_version
83+
else list(model.metadata.python_version)
84+
),
7985
},
8086
},
8187
versioned=versioned,

vetiver/prototype.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pandas as pd
1010
import numpy as np
1111
import pydantic
12+
from pydantic import Field
1213
from warnings import warn
1314
from .types import create_prototype
1415

@@ -159,7 +160,8 @@ def _item(value):
159160
# if its a numpy type, we have to take the Python type due to Pydantic
160161

161162
dict_data = {
162-
f"{key}": (type(value.item()), _item(value)) for key, value in dict_data.items()
163+
f"{key}": (type(value.item()), Field(..., example=_item(value)))
164+
for key, value in dict_data.items()
163165
}
164166
prototype = create_prototype(**dict_data)
165167
return prototype
@@ -182,7 +184,14 @@ def _(data: dict):
182184
# automatically create for simple prototypes
183185
try:
184186
for key, value in data["properties"].items():
185-
dict_data.update({key: (type(value["default"]), value["default"])})
187+
dict_data.update(
188+
{
189+
key: (
190+
type(value["example"]),
191+
Field(..., example=value["example"]),
192+
)
193+
}
194+
)
186195
# error for complex objects
187196
except KeyError:
188197
raise InvalidPTypeError(
@@ -223,5 +232,5 @@ def _(data: NoneType):
223232
def _to_field(data):
224233
basemodel_input = dict()
225234
for key, value in data.items():
226-
basemodel_input[key] = (type(value), value)
235+
basemodel_input[key] = (type(value), Field(..., example=value))
227236
return basemodel_input

vetiver/tests/test_build_vetiver_model.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,36 @@ def test_vetiver_model_array_prototype():
4141
description=None,
4242
metadata=None,
4343
)
44+
try:
45+
json_schema = v.prototype.model_json_schema()
46+
expected = {
47+
"properties": {
48+
"0": {"example": 96, "title": "0", "type": "integer"},
49+
"1": {"example": 11, "title": "1", "type": "integer"},
50+
"2": {"example": 33, "title": "2", "type": "integer"},
51+
},
52+
"required": ["0", "1", "2"],
53+
"title": "prototype",
54+
"type": "object",
55+
}
56+
except AttributeError: # pydantic v1
57+
json_schema = v.prototype.schema_json()
58+
expected = '{\
59+
"title": "prototype", \
60+
"type": "object", \
61+
"properties": {"0": {"title": "0", "example": 96, "type": "integer"}, \
62+
"1": {"title": "1", "example": 11, "type": "integer"}, \
63+
"2": {"title": "2", "example": 33, "type": "integer"}}, \
64+
"required": ["0", "1", "2"]}'
4465

4566
assert v.model == model
4667
assert issubclass(v.prototype, vetiver.Prototype)
4768
# change to model_construct for pydantic v3
4869
assert isinstance(v.prototype.construct(), pydantic.BaseModel)
49-
assert v.prototype.construct().__dict__ == {"0": 96, "1": 11, "2": 33}
70+
assert json_schema == expected
5071

5172

52-
@pytest.mark.parametrize("prototype_data", [{"B": 96, "C": 0, "D": 0}, X_df])
73+
@pytest.mark.parametrize("prototype_data", [{"B": 96, "C": 11, "D": 33}, X_df])
5374
def test_vetiver_model_dict_like_prototype(prototype_data):
5475
v = VetiverModel(
5576
model=model,
@@ -63,7 +84,30 @@ def test_vetiver_model_dict_like_prototype(prototype_data):
6384
assert v.model == model
6485
# change to model_construct for pydantic v3
6586
assert isinstance(v.prototype.construct(), pydantic.BaseModel)
66-
assert v.prototype.construct().B == 96
87+
88+
try:
89+
json_schema = v.prototype.model_json_schema()
90+
expected = {
91+
"properties": {
92+
"B": {"example": 96, "title": "B", "type": "integer"},
93+
"C": {"example": 11, "title": "C", "type": "integer"},
94+
"D": {"example": 33, "title": "D", "type": "integer"},
95+
},
96+
"required": ["B", "C", "D"],
97+
"title": "prototype",
98+
"type": "object",
99+
}
100+
except AttributeError: # pydantic v1
101+
json_schema = v.prototype.schema_json()
102+
expected = '{\
103+
"title": "prototype", \
104+
"type": "object", \
105+
"properties": {"B": {"title": "B", "example": 96, "type": "integer"}, \
106+
"C": {"title": "C", "example": 11, "type": "integer"}, \
107+
"D": {"title": "D", "example": 33, "type": "integer"}}, \
108+
"required": ["B", "C", "D"]}'
109+
110+
assert json_schema == expected
67111

68112

69113
@pytest.mark.parametrize("prototype_data", [MockPrototype(B=4, C=0, D=0), None])

vetiver/tests/test_server.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
vetiver_create_prototype,
66
InvalidPTypeError,
77
vetiver_endpoint,
8+
predict,
89
)
910
from pydantic import BaseModel, conint
1011
from fastapi.testclient import TestClient
@@ -14,7 +15,7 @@
1415

1516

1617
@pytest.fixture
17-
def vetiver_model():
18+
def model():
1819
np.random.seed(500)
1920
X, y = mock.get_mock_data()
2021
model = mock.get_mock_model().fit(X, y)
@@ -29,14 +30,7 @@ def vetiver_model():
2930

3031

3132
@pytest.fixture
32-
def client(vetiver_model):
33-
app = VetiverAPI(vetiver_model)
34-
35-
return TestClient(app.app)
36-
37-
38-
@pytest.fixture
39-
def complex_prototype_model():
33+
def complex_prototype_client():
4034
np.random.seed(500)
4135

4236
class CustomPrototype(BaseModel):
@@ -83,27 +77,28 @@ def test_get_metadata(client):
8377
}
8478

8579

86-
def test_get_prototype(client, vetiver_model):
80+
def test_get_prototype(client, model):
8781
response = client.get("/prototype")
8882
assert response.status_code == 200, response.text
8983
assert response.json() == {
9084
"properties": {
91-
"B": {"default": 55, "type": "integer"},
92-
"C": {"default": 65, "type": "integer"},
93-
"D": {"default": 17, "type": "integer"},
85+
"B": {"example": 55, "type": "integer"},
86+
"C": {"example": 65, "type": "integer"},
87+
"D": {"example": 17, "type": "integer"},
9488
},
89+
"required": ["B", "C", "D"],
9590
"title": "prototype",
9691
"type": "object",
9792
}
9893

9994
assert (
100-
vetiver_model.prototype.construct().dict()
95+
model.prototype.construct().dict()
10196
== vetiver_create_prototype(response.json()).construct().dict()
10297
)
10398

10499

105-
def test_complex_prototype(complex_prototype_model):
106-
response = complex_prototype_model.get("/prototype")
100+
def test_complex_prototype(complex_prototype_client):
101+
response = complex_prototype_client.get("/prototype")
107102
assert response.status_code == 200, response.text
108103
assert response.json() == {
109104
"properties": {
@@ -120,6 +115,11 @@ def test_complex_prototype(complex_prototype_model):
120115
vetiver_create_prototype(response.json())
121116

122117

118+
def test_predict_wrong_input(client):
119+
with pytest.raises(TypeError):
120+
predict(endpoint="/predict/", data=[{"B": 43, "C": 43}], test_client=client)
121+
122+
123123
def test_vetiver_endpoint():
124124
url_raw = "http://127.0.0.1:8000/predict/"
125125
url = vetiver_endpoint(url_raw)

vetiver/tests/test_sklearn.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ def model() -> VetiverModel:
2626

2727
@pytest.mark.parametrize(
2828
"data,expected_length",
29-
[([{"B": 0, "C": 0, "D": 0}], 1), (pd.Series(data=[0, 0, 0]), 1), (X, 100)],
29+
[
30+
([{"B": 0, "C": 0, "D": 0}], 1),
31+
(pd.Series(data=[0, 0, 0], index=["B", "C", "D"]), 1),
32+
(X, 100),
33+
],
3034
)
3135
def test_predict_sklearn_ptype(data, expected_length, client):
3236
response = predict(endpoint="/predict/", data=data, test_client=client)

vetiver/tests/test_spacy.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,22 @@ def test_bad_prototype_shape(data, spacy_model):
6060
def test_good_prototype_shape(data, spacy_model):
6161
v = vetiver.VetiverModel(spacy_model, "animals", prototype_data=data)
6262

63-
assert v.prototype.construct().dict() == {"col": "1"}
63+
try:
64+
model_schema = v.prototype.model_json_schema()
65+
expected = {
66+
"properties": {
67+
"col": {"example": "1", "title": "Col", "type": "string"},
68+
},
69+
"required": ["col"],
70+
"title": "prototype",
71+
"type": "object",
72+
}
73+
except AttributeError: # pydantic v1
74+
model_schema = v.prototype.schema_json()
75+
expected = '{"title": "prototype", "type": "object", "properties": \
76+
{"col": {"title": "Col", "example": "1", "type": "string"}}, "required": ["col"]}'
77+
78+
assert model_schema == expected
6479

6580

6681
def test_vetiver_predict_with_prototype(client: TestClient):

vetiver/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
import os
55
import subprocess
6+
import json
67
from types import SimpleNamespace
78

89
no_notebook = False
@@ -53,3 +54,16 @@ def get_workbench_path(port):
5354
return path
5455
else:
5556
return None
57+
58+
59+
def serialize_prototype(prototype):
60+
try:
61+
schema = prototype.model_json_schema().get("properties")
62+
except AttributeError: # pydantic v1
63+
schema = json.loads(prototype.schema_json()).get("properties")
64+
65+
serialized_schema = dict()
66+
for key, value in schema.items():
67+
serialized_schema[key] = value.get("example") or value.get("default")
68+
69+
return json.dumps(serialized_schema)

0 commit comments

Comments
 (0)