Skip to content

Commit 58b4209

Browse files
committed
tests
1 parent f490010 commit 58b4209

2 files changed

Lines changed: 125 additions & 3 deletions

File tree

vetiver/server.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from fastapi.responses import HTMLResponse, RedirectResponse
33
from fastapi.staticfiles import StaticFiles
44
from fastapi.openapi.utils import get_openapi
5+
from fastapi import testclient
56

67
import uvicorn
78
import requests
@@ -184,13 +185,19 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw):
184185
dict
185186
Key: endpoint_name Value: Output of endpoint_fx, in list format
186187
"""
188+
if isinstance(endpoint, testclient.TestClient):
189+
requester = endpoint
190+
endpoint = "/predict/"
191+
else:
192+
requester = requests
193+
187194
if isinstance(data, (pd.DataFrame,pd.Series)):
188195
data_json = data.to_json(orient="records")
189-
response = requests.post(endpoint, data=data_json, **kw)
196+
response = requester.post(endpoint, data=data_json, **kw)
190197
elif isinstance(data, dict):
191-
response = requests.post(endpoint, json=data, **kw)
198+
response = requester.post(endpoint, json=data, **kw)
192199
else:
193-
raise TypeError(f"Accepted data types are dictionary or DataFrame, given type is {type(data)} \n {data}")
200+
raise TypeError(f"Given type is {type(data)}")
194201

195202
response_df = pd.DataFrame.from_dict(response.json())
196203

vetiver/tests/test_predict.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import pytest
2+
3+
import numpy as np
4+
import pandas as pd
5+
from fastapi.testclient import TestClient
6+
7+
from vetiver import mock, VetiverModel, VetiverAPI
8+
from vetiver.server import predict
9+
10+
11+
def test_predict_sklearn_dict_ptype():
12+
np.random.seed(500)
13+
X, y = mock.get_mock_data()
14+
model = mock.get_mock_model().fit(X, y)
15+
v = VetiverModel(
16+
model=model,
17+
save_ptype=True,
18+
ptype_data=X,
19+
model_name="my_model",
20+
versioned=None,
21+
description="A regression model for testing purposes",
22+
)
23+
app = VetiverAPI(v, check_ptype=True)
24+
client = TestClient(app.app)
25+
data = {"B": 0, "C": 0, "D": 0}
26+
27+
response = predict(endpoint=client, data=data)
28+
29+
assert isinstance(response, pd.DataFrame), response
30+
assert response.iloc[0,0] == 44.47
31+
32+
33+
def test_predict_sklearn_no_ptype():
34+
np.random.seed(500)
35+
X, y = mock.get_mock_data()
36+
model = mock.get_mock_model().fit(X, y)
37+
v = VetiverModel(
38+
model=model,
39+
save_ptype=True,
40+
ptype_data=X,
41+
model_name="my_model",
42+
versioned=None,
43+
description="A regression model for testing purposes",
44+
)
45+
app = VetiverAPI(v, check_ptype=False)
46+
client = TestClient(app.app)
47+
48+
response = predict(endpoint=client, data=X)
49+
50+
assert isinstance(response, pd.DataFrame), response
51+
assert response.iloc[0,0] == 44.47
52+
assert len(response) == 100
53+
54+
55+
def test_predict_sklearn_df_check_ptype():
56+
np.random.seed(500)
57+
X, y = mock.get_mock_data()
58+
model = mock.get_mock_model().fit(X, y)
59+
v = VetiverModel(
60+
model=model,
61+
save_ptype=True,
62+
ptype_data=X,
63+
model_name="my_model",
64+
versioned=None,
65+
description="A regression model for testing purposes",
66+
)
67+
app = VetiverAPI(v, check_ptype=True)
68+
client = TestClient(app.app)
69+
70+
response = predict(endpoint=client, data=X)
71+
72+
assert isinstance(response, pd.DataFrame), response
73+
assert response.iloc[0,0] == 44.47
74+
assert len(response) == 100
75+
76+
77+
def test_predict_sklearn_df_no_ptype():
78+
np.random.seed(500)
79+
X, y = mock.get_mock_data()
80+
model = mock.get_mock_model().fit(X, y)
81+
v = VetiverModel(
82+
model=model,
83+
save_ptype=True,
84+
ptype_data=X,
85+
model_name="my_model",
86+
versioned=None,
87+
description="A regression model for testing purposes",
88+
)
89+
app = VetiverAPI(v, check_ptype=False)
90+
client = TestClient(app.app)
91+
92+
response = predict(endpoint=client, data=X)
93+
94+
assert isinstance(response, pd.DataFrame), response
95+
assert response.iloc[0,0] == 44.47
96+
97+
98+
def test_predict_sklearn_type_error():
99+
np.random.seed(500)
100+
X, y = mock.get_mock_data()
101+
model = mock.get_mock_model().fit(X, y)
102+
v = VetiverModel(
103+
model=model,
104+
save_ptype=True,
105+
ptype_data=X,
106+
model_name="my_model",
107+
versioned=None,
108+
description="A regression model for testing purposes",
109+
)
110+
app = VetiverAPI(v, check_ptype=True)
111+
client = TestClient(app.app)
112+
data = '0,0,0'
113+
114+
with pytest.raises(TypeError):
115+
predict(endpoint=client, data=data)

0 commit comments

Comments
 (0)