Skip to content

Commit 25baff1

Browse files
committed
refactor tests for api
1 parent 053289d commit 25baff1

9 files changed

Lines changed: 177 additions & 240 deletions

File tree

vetiver/handlers/base.py

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import List
12
from functools import singledispatch
23
from contextlib import suppress
34
import pandas as pd
@@ -119,8 +120,7 @@ def handler_startup():
119120
"""
120121
...
121122

122-
@singledispatch
123-
def _prepare_data(self, pred_data) -> pd.DataFrame:
123+
def _process_input(self, pred_data) -> pd.DataFrame:
124124
"""Convert prototype to dataframe data
125125
126126
Parameters
@@ -133,19 +133,10 @@ def _prepare_data(self, pred_data) -> pd.DataFrame:
133133
pd.DataFrame
134134
BaseModel data translated into DataFrame
135135
"""
136-
served_data = []
137-
for key, value in pred_data:
138-
served_data.append(value)
139136

140-
return served_data
137+
new = _prepare_data(pred_data)
141138

142-
@_prepare_data.register
143-
def _basemodel(self, pred_data: pydantic.BaseModel):
144-
return pd.DataFrame([dict(s) for s in pred_data])
145-
146-
@_prepare_data.register
147-
def _list(self, pred_data: list):
148-
return pd.DataFrame([dict(s) for s in pred_data])
139+
return new
149140

150141
def handler_predict(self, input_data, check_prototype):
151142
"""Generates method for /predict endpoint in VetiverAPI
@@ -179,3 +170,37 @@ def _(model: BaseHandler, prototype_data):
179170
model.prototype_data = prototype_data
180171

181172
return model
173+
174+
175+
@singledispatch
176+
def _prepare_data(pred_data) -> pd.DataFrame:
177+
"""Convert prototype to dataframe data
178+
179+
Parameters
180+
----------
181+
pred_data : pydantic.BaseModel
182+
User data from given to API endpoint
183+
184+
Returns
185+
-------
186+
pd.DataFrame
187+
BaseModel data translated into DataFrame
188+
"""
189+
190+
raise TypeError("Data should be list, dict, pd.DataFrame")
191+
192+
193+
@_prepare_data.register(pydantic.BaseModel)
194+
@_prepare_data.register(List)
195+
def _basemodel_list_data(pred_data):
196+
197+
return pd.DataFrame([dict(s) for s in pred_data])
198+
199+
200+
# @_prepare_data.register(dict)
201+
# def _dict_data(pred_data):
202+
# served_data = []
203+
# for key, value in pred_data:
204+
# served_data.append(value)
205+
206+
# return served_data

vetiver/handlers/sklearn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def handler_predict(self, input_data, check_prototype):
3434
Prediction from model
3535
"""
3636
if check_prototype:
37-
input_data = self._prepare_data(input_data)
37+
input_data = self._process_input(input_data)
3838

3939
if not check_prototype or isinstance(input_data, pd.DataFrame):
4040
prediction = self.model.predict(input_data)

vetiver/handlers/statsmodels.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ def handler_predict(self, input_data, check_prototype):
4141
"""
4242
if not sm_exists:
4343
raise ImportError("Cannot import `statsmodels`")
44-
44+
4545
if check_prototype:
46-
input_data = self._prepare_data(input_data)
46+
input_data = self._process_input(input_data)
4747
if isinstance(input_data, (list, pd.DataFrame)):
4848
prediction = self.model.predict(input_data)
4949
else:

vetiver/handlers/torch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ def handler_predict(self, input_data, check_prototype):
4141
"""
4242
if not torch_exists:
4343
raise ImportError("Cannot import `torch`.")
44-
44+
4545
if check_prototype:
46-
input_data = self._prepare_data(input_data)
46+
input_data = self._process_input(input_data)
4747
input_data = np.array(input_data, dtype=np.array(self.prototype_data).dtype)
4848
prediction = self.model(torch.from_numpy(input_data))
4949

vetiver/handlers/xgboost.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def handler_predict(self, input_data, check_prototype):
4444
raise ImportError("Cannot import `xgboost`")
4545

4646
if check_prototype:
47-
input_data = self._prepare_data(input_data)
48-
47+
input_data = self._process_input(input_data)
48+
4949
if not isinstance(input_data, pd.DataFrame):
5050
try:
5151
input_data = pd.DataFrame(input_data)

vetiver/server.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Callable, Dict, List, Union
1+
from typing import Callable, List, Union
22
from urllib.parse import urljoin
33

44
import httpx
@@ -167,15 +167,7 @@ def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
167167
if self.check_prototype is True:
168168

169169
@self.app.post(urljoin("/", endpoint_name), name=endpoint_name)
170-
async def custom_endpoint(
171-
input_data: Union[self.model.prototype, List[self.model.prototype]]
172-
):
173-
174-
# if isinstance(input_data, List):
175-
# served_data = _batch_data(input_data)
176-
# else:
177-
# served_data = _prepare_data(input_data)
178-
170+
async def custom_endpoint(input_data: List[self.model.prototype]):
179171
new = endpoint_fx(input_data, **kw)
180172
return {endpoint_name: new.tolist()}
181173

@@ -261,13 +253,13 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw) -> pd.Da
261253
# TO DO: dispatch
262254

263255
if isinstance(data, pd.DataFrame):
264-
data_json = data.to_json(orient="records")
265-
response = requester.post(endpoint, data=data_json, **kw)
256+
response = requester.post(endpoint, data=data.to_json(orient="records"), **kw)
266257
elif isinstance(data, pd.Series):
267-
data_dict = data.to_json()
268-
response = requester.post(endpoint, data=data_dict, **kw)
269-
elif isinstance(data, dict):
258+
response = requester.post(endpoint, data=data.to_json(), **kw)
259+
elif isinstance(data, list):
270260
response = requester.post(endpoint, json=data, **kw)
261+
elif isinstance(data, dict):
262+
response = requester.post(endpoint, json=[data], **kw)
271263
else:
272264
response = requester.post(endpoint, json=data, **kw)
273265

vetiver/tests/test_predict.py

Lines changed: 0 additions & 98 deletions
This file was deleted.

0 commit comments

Comments
 (0)