Skip to content

Commit a56371d

Browse files
committed
catching errors at prediction time
1 parent 58b4209 commit a56371d

2 files changed

Lines changed: 22 additions & 7 deletions

File tree

vetiver/server.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,15 +191,27 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw):
191191
else:
192192
requester = requests
193193

194-
if isinstance(data, (pd.DataFrame,pd.Series)):
194+
# TO DO: arrow format
195+
196+
if isinstance(data, pd.DataFrame):
195197
data_json = data.to_json(orient="records")
196198
response = requester.post(endpoint, data=data_json, **kw)
199+
elif isinstance(data, pd.Series):
200+
data_dict = data.to_json()
201+
response = requester.post(endpoint, data=data_dict, **kw)
197202
elif isinstance(data, dict):
198203
response = requester.post(endpoint, json=data, **kw)
199204
else:
200-
raise TypeError(f"Given type is {type(data)}")
201-
205+
try:
206+
response = requester.post(endpoint, json=data, **kw)
207+
except:
208+
raise TypeError(f"Predict expects a DataFrame or dict. Given type is {type(data)}")
209+
202210
response_df = pd.DataFrame.from_dict(response.json())
211+
212+
if isinstance(response_df.iloc[0,0], dict):
213+
if 'type_error.dict' in response_df.iloc[0,0].values():
214+
raise TypeError(f"Predict expects a DataFrame or dict. Given type is {type(data)}")
203215

204216
return response_df
205217

vetiver/tests/test_predict.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def test_predict_sklearn_dict_ptype():
2828

2929
assert isinstance(response, pd.DataFrame), response
3030
assert response.iloc[0,0] == 44.47
31+
assert len(response) == 1
3132

3233

3334
def test_predict_sklearn_no_ptype():
@@ -74,9 +75,10 @@ def test_predict_sklearn_df_check_ptype():
7475
assert len(response) == 100
7576

7677

77-
def test_predict_sklearn_df_no_ptype():
78+
def test_predict_sklearn_series_check_ptype():
7879
np.random.seed(500)
7980
X, y = mock.get_mock_data()
81+
ser = pd.Series(data=[0,0,0])
8082
model = mock.get_mock_model().fit(X, y)
8183
v = VetiverModel(
8284
model=model,
@@ -86,13 +88,14 @@ def test_predict_sklearn_df_no_ptype():
8688
versioned=None,
8789
description="A regression model for testing purposes",
8890
)
89-
app = VetiverAPI(v, check_ptype=False)
91+
app = VetiverAPI(v, check_ptype=True)
9092
client = TestClient(app.app)
9193

92-
response = predict(endpoint=client, data=X)
94+
response = predict(endpoint=client, data=ser)
9395

9496
assert isinstance(response, pd.DataFrame), response
9597
assert response.iloc[0,0] == 44.47
98+
assert len(response) == 1
9699

97100

98101
def test_predict_sklearn_type_error():
@@ -109,7 +112,7 @@ def test_predict_sklearn_type_error():
109112
)
110113
app = VetiverAPI(v, check_ptype=True)
111114
client = TestClient(app.app)
112-
data = '0,0,0'
115+
data = (0,0)
113116

114117
with pytest.raises(TypeError):
115118
predict(endpoint=client, data=data)

0 commit comments

Comments
 (0)