Skip to content

Commit b533875

Browse files
authored
Merge pull request #174 from rstudio/prototype-endpoint
2 parents 4d6090b + 9e8e1e9 commit b533875

6 files changed

Lines changed: 134 additions & 2 deletions

File tree

.github/workflows/tests.yml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,23 @@ jobs:
115115
run: |
116116
pytest vetiver/tests/test_sklearn.py
117117
118+
test-pydantic-old:
119+
name: "Test pydantic v1"
120+
runs-on: ubuntu-latest
121+
steps:
122+
- uses: actions/checkout@v3
123+
- uses: actions/setup-python@v4
124+
with:
125+
python-version: 3.8
126+
- name: Install dependencies
127+
run: |
128+
python -m pip install --upgrade pip
129+
python -m pip install -e .[dev]
130+
python -m pip install 'pydantic<2.0.0'
131+
132+
- name: Run tests
133+
run: |
134+
make test
118135
typecheck:
119136
runs-on: ubuntu-latest
120137

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,5 @@ NOTES.md
144144
.Rproj.user
145145

146146
/.quarto/
147+
148+
/.luarc.json

vetiver/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
# Change to import.metadata when minimum python>=3.8
33
from importlib_metadata import version as _version
44

5-
from .prototype import vetiver_create_prototype, vetiver_create_ptype # noqa
5+
from .prototype import ( # noqa
6+
vetiver_create_prototype, # noqa
7+
vetiver_create_ptype, # noqa
8+
InvalidPTypeError, # noqa
9+
) # noqa
610
from .vetiver_model import VetiverModel # noqa
711
from .server import VetiverAPI, vetiver_endpoint, predict # noqa
812
from .mock import get_mock_data, get_mock_model # noqa

vetiver/prototype.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,22 @@ def _(data: dict):
175175
data : dict
176176
Dictionary
177177
"""
178+
179+
# if it comes from vetiver's /prototype endpoint
180+
dict_data = {}
181+
if data.keys() >= {"properties", "title", "type"}:
182+
# automatically create for simple prototypes
183+
try:
184+
for key, value in data["properties"].items():
185+
dict_data.update({key: (type(value["default"]), value["default"])})
186+
# error for complex objects
187+
except KeyError:
188+
raise InvalidPTypeError(
189+
"Failed to create dict prototype for this data. "
190+
"Please use pandas DataFrame or pydantic BaseModel."
191+
)
192+
return create_prototype(**dict_data)
193+
178194
return create_prototype(**_to_field(data))
179195

180196

vetiver/server.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import re
55
import httpx
6+
import json
67
import pandas as pd
78
import requests
89
import uvicorn
@@ -27,6 +28,8 @@ class VetiverAPI:
2728
----------
2829
model : VetiverModel
2930
Model to be deployed in API
31+
show_prototype: bool = True
32+
3033
check_prototype : bool
3134
Determine if data prototype should be enforced
3235
app_factory :
@@ -60,6 +63,7 @@ class VetiverAPI:
6063
def __init__(
6164
self,
6265
model: VetiverModel,
66+
show_prototype: bool = True,
6367
check_prototype: bool = True,
6468
app_factory=FastAPI,
6569
**kwargs,
@@ -80,6 +84,7 @@ def __init__(
8084
self.model.prototype = self.model.ptype
8185
delattr(self.model, "ptype")
8286

87+
self.show_prototype = show_prototype
8388
self.check_prototype = check_prototype
8489

8590
self._init_app()
@@ -114,6 +119,23 @@ async def get_metadata():
114119
"""Get metadata from model"""
115120
return self.model.metadata.to_dict()
116121

122+
if self.show_prototype is True:
123+
124+
@app.get("/prototype")
125+
async def get_prototype():
126+
# to handle pydantic<2 and >=2
127+
prototype_schema = getattr(
128+
self.model.prototype,
129+
"model_json_schema",
130+
self.model.prototype.schema_json,
131+
)()
132+
# pydantic<2 returns a string, need to handle to json format
133+
if isinstance(prototype_schema, str):
134+
prototype_schema = json.loads(prototype_schema)
135+
for key, value in prototype_schema["properties"].items():
136+
value.pop("title", None)
137+
return prototype_schema
138+
117139
self.vetiver_post(
118140
self.model.handler_predict, "predict", check_prototype=self.check_prototype
119141
)

vetiver/tests/test_server.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
1-
from vetiver import mock, VetiverModel, VetiverAPI
1+
from vetiver import (
2+
mock,
3+
VetiverModel,
4+
VetiverAPI,
5+
vetiver_create_prototype,
6+
InvalidPTypeError,
7+
)
8+
from pydantic import BaseModel, conint
29
from fastapi.testclient import TestClient
10+
import numpy as np
311
import pytest
412
import sys
513

614

715
@pytest.fixture
816
def vetiver_model():
17+
np.random.seed(500)
918
X, y = mock.get_mock_data()
1019
model = mock.get_mock_model().fit(X, y)
1120
v = VetiverModel(
@@ -25,6 +34,31 @@ def client(vetiver_model):
2534
return TestClient(app.app)
2635

2736

37+
@pytest.fixture
38+
def complex_prototype_model():
39+
np.random.seed(500)
40+
41+
class CustomPrototype(BaseModel):
42+
B: conint(gt=42)
43+
C: conint(gt=42)
44+
D: conint(gt=42)
45+
46+
X, y = mock.get_mock_data()
47+
model = mock.get_mock_model().fit(X, y)
48+
v = VetiverModel(
49+
model=model,
50+
# move to model_construct for pydantic 3
51+
prototype_data=CustomPrototype.construct(),
52+
model_name="my_model",
53+
versioned=None,
54+
description="A regression model for testing purposes",
55+
)
56+
# dont actually want to make predictions, just for looking at schema
57+
app = VetiverAPI(v, check_prototype=False)
58+
59+
return TestClient(app.app)
60+
61+
2862
def test_get_ping(client):
2963
response = client.get("/ping")
3064
assert response.status_code == 200, response.text
@@ -46,3 +80,40 @@ def test_get_metadata(client):
4680
"required_pkgs": ["scikit-learn"],
4781
"python_version": list(sys.version_info), # JSON will return a list
4882
}
83+
84+
85+
def test_get_prototype(client, vetiver_model):
86+
response = client.get("/prototype")
87+
assert response.status_code == 200, response.text
88+
assert response.json() == {
89+
"properties": {
90+
"B": {"default": 55, "type": "integer"},
91+
"C": {"default": 65, "type": "integer"},
92+
"D": {"default": 17, "type": "integer"},
93+
},
94+
"title": "prototype",
95+
"type": "object",
96+
}
97+
98+
assert (
99+
vetiver_model.prototype.construct().dict()
100+
== vetiver_create_prototype(response.json()).construct().dict()
101+
)
102+
103+
104+
def test_complex_prototype(complex_prototype_model):
105+
response = complex_prototype_model.get("/prototype")
106+
assert response.status_code == 200, response.text
107+
assert response.json() == {
108+
"properties": {
109+
"B": {"exclusiveMinimum": 42, "type": "integer"},
110+
"C": {"exclusiveMinimum": 42, "type": "integer"},
111+
"D": {"exclusiveMinimum": 42, "type": "integer"},
112+
},
113+
"required": ["B", "C", "D"],
114+
"title": "CustomPrototype",
115+
"type": "object",
116+
}
117+
118+
with pytest.raises(InvalidPTypeError):
119+
vetiver_create_prototype(response.json())

0 commit comments

Comments
 (0)