Skip to content

Commit 4da0740

Browse files
committed
move data handling to handler
1 parent 7a2eaf8 commit 4da0740

8 files changed

Lines changed: 46 additions & 25 deletions

File tree

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ lint:
5757
test: clean-test
5858
pytest -m 'not rsc_test and not docker'
5959

60+
test-pdb: clean-test
61+
pytest -m 'not rsc_test and not docker' --pdb
62+
6063
test-rsc: clean-test
6164
pytest
6265

vetiver/handlers/base.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from vetiver.handlers import base
1+
from typing import List
22
from functools import singledispatch
33
from contextlib import suppress
4+
import pandas as pd
45

56
from ..prototype import vetiver_create_prototype
67
from ..meta import VetiverMeta
@@ -118,6 +119,30 @@ def handler_startup():
118119
"""
119120
...
120121

122+
def _prepare_data(self, pred_data) -> pd.DataFrame:
123+
"""Convert prototype to dataframe data
124+
125+
Parameters
126+
----------
127+
prototype_data : pd.DataFrame, np.ndarray, or None
128+
Training data to create prototype
129+
130+
Returns
131+
-------
132+
prototype : pd.DataFrame or None
133+
Zero-row DataFrame for storing data types
134+
"""
135+
if isinstance(pred_data, List):
136+
columns = pred_data[0].dict().keys()
137+
data = [line.dict() for line in pred_data]
138+
served_data = pd.DataFrame(data, columns=columns)
139+
else:
140+
served_data = []
141+
for key, value in pred_data:
142+
served_data.append(value)
143+
144+
return served_data
145+
121146
def handler_predict(self, input_data, check_prototype):
122147
"""Generates method for /predict endpoint in VetiverAPI
123148
@@ -145,7 +170,7 @@ def handler_predict(self, input_data, check_prototype):
145170

146171

147172
@create_handler.register
148-
def _(model: base.BaseHandler, prototype_data):
173+
def _(model: BaseHandler, prototype_data):
149174
if model.prototype_data is None and prototype_data is not None:
150175
model.prototype_data = prototype_data
151176

vetiver/handlers/sklearn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ def handler_predict(self, input_data, check_prototype):
3333
prediction
3434
Prediction from model
3535
"""
36+
if check_prototype:
37+
input_data = self._prepare_data(input_data)
3638

3739
if not check_prototype or isinstance(input_data, pd.DataFrame):
3840
prediction = self.model.predict(input_data)

vetiver/handlers/statsmodels.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ def handler_predict(self, input_data, check_prototype):
4141
"""
4242
if not sm_exists:
4343
raise ImportError("Cannot import `statsmodels`")
44-
44+
45+
if check_prototype:
46+
input_data = self._prepare_data(input_data)
4547
if isinstance(input_data, (list, pd.DataFrame)):
4648
prediction = self.model.predict(input_data)
4749
else:

vetiver/handlers/torch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ def handler_predict(self, input_data, check_prototype):
4141
"""
4242
if not torch_exists:
4343
raise ImportError("Cannot import `torch`.")
44+
4445
if check_prototype:
46+
input_data = self._prepare_data(input_data)
4547
input_data = np.array(input_data, dtype=np.array(self.prototype_data).dtype)
4648
prediction = self.model(torch.from_numpy(input_data))
4749

vetiver/handlers/xgboost.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ def handler_predict(self, input_data, check_prototype):
4343
if not xgb_exists:
4444
raise ImportError("Cannot import `xgboost`")
4545

46+
if check_prototype:
47+
input_data = self._prepare_data(input_data)
48+
4649
if not isinstance(input_data, pd.DataFrame):
4750
try:
4851
input_data = pd.DataFrame(input_data)

vetiver/server.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,12 @@ async def custom_endpoint(
171171
input_data: Union[self.model.prototype, List[self.model.prototype]]
172172
):
173173

174-
if isinstance(input_data, List):
175-
served_data = _batch_data(input_data)
176-
else:
177-
served_data = _prepare_data(input_data)
174+
# if isinstance(input_data, List):
175+
# served_data = _batch_data(input_data)
176+
# else:
177+
# served_data = _prepare_data(input_data)
178178

179-
new = endpoint_fx(served_data, **kw)
179+
new = endpoint_fx(input_data, **kw)
180180
return {endpoint_name: new.tolist()}
181181

182182
else:
@@ -287,22 +287,6 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw) -> pd.Da
287287
return response_df
288288

289289

290-
def _prepare_data(pred_data: Dict[str, Any]) -> List[Any]:
291-
served_data = []
292-
for key, value in pred_data:
293-
served_data.append(value)
294-
return served_data
295-
296-
297-
def _batch_data(pred_data: List[Any]) -> pd.DataFrame:
298-
columns = pred_data[0].dict().keys()
299-
300-
data = [line.dict() for line in pred_data]
301-
302-
served_data = pd.DataFrame(data, columns=columns)
303-
return served_data
304-
305-
306290
def vetiver_endpoint(url: str = "http://127.0.0.1:8000/predict") -> str:
307291
"""Wrap url where VetiverModel will be deployed
308292

vetiver/tests/test_add_endpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def vetiver_model():
2525

2626

2727
def sum_values(x):
28-
return x.sum()
28+
return pd.DataFrame([dict(s) for s in x]).sum()
2929

3030

3131
def sum_dict(x):

0 commit comments

Comments
 (0)