Skip to content

Commit e9b8b6f

Browse files
authored
Merge pull request #24 from isabelizimm/dev-predict-return-df
return dataframes in `predict()`
2 parents d3f8205 + a56371d commit e9b8b6f

2 files changed

Lines changed: 147 additions & 6 deletions

File tree

vetiver/server.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from fastapi.responses import HTMLResponse, RedirectResponse
33
from fastapi.staticfiles import StaticFiles
44
from fastapi.openapi.utils import get_openapi
5+
from fastapi import testclient
56

67
import uvicorn
78
import requests
@@ -169,7 +170,7 @@ def _custom_openapi(self):
169170
self.app.openapi_schema = openapi_schema
170171
return self.app.openapi_schema
171172

172-
def predict(endpoint, data: dict, **kw):
173+
def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw):
173174
"""Make a prediction from model endpoint
174175
175176
Parameters
@@ -184,13 +185,35 @@ def predict(endpoint, data: dict, **kw):
184185
dict
185186
Key: endpoint_name Value: Output of endpoint_fx, in list format
186187
"""
188+
if isinstance(endpoint, testclient.TestClient):
189+
requester = endpoint
190+
endpoint = "/predict/"
191+
else:
192+
requester = requests
193+
194+
# TO DO: arrow format
195+
187196
if isinstance(data, pd.DataFrame):
188-
data = data.to_json(orient="records")
189-
response = requests.post(endpoint, data=data, **kw)
197+
data_json = data.to_json(orient="records")
198+
response = requester.post(endpoint, data=data_json, **kw)
199+
elif isinstance(data, pd.Series):
200+
data_dict = data.to_json()
201+
response = requester.post(endpoint, data=data_dict, **kw)
202+
elif isinstance(data, dict):
203+
response = requester.post(endpoint, json=data, **kw)
190204
else:
191-
response = requests.post(endpoint, json=data, **kw)
192-
193-
return response.json()
205+
try:
206+
response = requester.post(endpoint, json=data, **kw)
207+
except:
208+
raise TypeError(f"Predict expects a DataFrame or dict. Given type is {type(data)}")
209+
210+
response_df = pd.DataFrame.from_dict(response.json())
211+
212+
if isinstance(response_df.iloc[0,0], dict):
213+
if 'type_error.dict' in response_df.iloc[0,0].values():
214+
raise TypeError(f"Predict expects a DataFrame or dict. Given type is {type(data)}")
215+
216+
return response_df
194217

195218

196219
def _prepare_data(pred_data):

vetiver/tests/test_predict.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import pytest
2+
3+
import numpy as np
4+
import pandas as pd
5+
from fastapi.testclient import TestClient
6+
7+
from vetiver import mock, VetiverModel, VetiverAPI
8+
from vetiver.server import predict
9+
10+
11+
def test_predict_sklearn_dict_ptype():
12+
np.random.seed(500)
13+
X, y = mock.get_mock_data()
14+
model = mock.get_mock_model().fit(X, y)
15+
v = VetiverModel(
16+
model=model,
17+
save_ptype=True,
18+
ptype_data=X,
19+
model_name="my_model",
20+
versioned=None,
21+
description="A regression model for testing purposes",
22+
)
23+
app = VetiverAPI(v, check_ptype=True)
24+
client = TestClient(app.app)
25+
data = {"B": 0, "C": 0, "D": 0}
26+
27+
response = predict(endpoint=client, data=data)
28+
29+
assert isinstance(response, pd.DataFrame), response
30+
assert response.iloc[0,0] == 44.47
31+
assert len(response) == 1
32+
33+
34+
def test_predict_sklearn_no_ptype():
35+
np.random.seed(500)
36+
X, y = mock.get_mock_data()
37+
model = mock.get_mock_model().fit(X, y)
38+
v = VetiverModel(
39+
model=model,
40+
save_ptype=True,
41+
ptype_data=X,
42+
model_name="my_model",
43+
versioned=None,
44+
description="A regression model for testing purposes",
45+
)
46+
app = VetiverAPI(v, check_ptype=False)
47+
client = TestClient(app.app)
48+
49+
response = predict(endpoint=client, data=X)
50+
51+
assert isinstance(response, pd.DataFrame), response
52+
assert response.iloc[0,0] == 44.47
53+
assert len(response) == 100
54+
55+
56+
def test_predict_sklearn_df_check_ptype():
57+
np.random.seed(500)
58+
X, y = mock.get_mock_data()
59+
model = mock.get_mock_model().fit(X, y)
60+
v = VetiverModel(
61+
model=model,
62+
save_ptype=True,
63+
ptype_data=X,
64+
model_name="my_model",
65+
versioned=None,
66+
description="A regression model for testing purposes",
67+
)
68+
app = VetiverAPI(v, check_ptype=True)
69+
client = TestClient(app.app)
70+
71+
response = predict(endpoint=client, data=X)
72+
73+
assert isinstance(response, pd.DataFrame), response
74+
assert response.iloc[0,0] == 44.47
75+
assert len(response) == 100
76+
77+
78+
def test_predict_sklearn_series_check_ptype():
79+
np.random.seed(500)
80+
X, y = mock.get_mock_data()
81+
ser = pd.Series(data=[0,0,0])
82+
model = mock.get_mock_model().fit(X, y)
83+
v = VetiverModel(
84+
model=model,
85+
save_ptype=True,
86+
ptype_data=X,
87+
model_name="my_model",
88+
versioned=None,
89+
description="A regression model for testing purposes",
90+
)
91+
app = VetiverAPI(v, check_ptype=True)
92+
client = TestClient(app.app)
93+
94+
response = predict(endpoint=client, data=ser)
95+
96+
assert isinstance(response, pd.DataFrame), response
97+
assert response.iloc[0,0] == 44.47
98+
assert len(response) == 1
99+
100+
101+
def test_predict_sklearn_type_error():
102+
np.random.seed(500)
103+
X, y = mock.get_mock_data()
104+
model = mock.get_mock_model().fit(X, y)
105+
v = VetiverModel(
106+
model=model,
107+
save_ptype=True,
108+
ptype_data=X,
109+
model_name="my_model",
110+
versioned=None,
111+
description="A regression model for testing purposes",
112+
)
113+
app = VetiverAPI(v, check_ptype=True)
114+
client = TestClient(app.app)
115+
data = (0,0)
116+
117+
with pytest.raises(TypeError):
118+
predict(endpoint=client, data=data)

0 commit comments

Comments
 (0)