Skip to content

Commit c2b7f8b

Browse files
committed
use helpers to convert api data to frame
1 parent a8472df commit c2b7f8b

8 files changed

Lines changed: 64 additions & 65 deletions

File tree

vetiver/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .handlers.torch import TorchHandler # noqa
1717
from .handlers.statsmodels import StatsmodelsHandler # noqa
1818
from .handlers.xgboost import XGBoostHandler # noqa
19+
from .helpers import api_data_to_frame # noqa
1920
from .rsconnect import deploy_rsconnect # noqa
2021
from .monitor import compute_metrics, pin_metrics, plot_metrics, _rolling_df # noqa
2122
from .model_card import model_card # noqa

vetiver/handlers/base.py

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from functools import singledispatch
22
from contextlib import suppress
3-
import pandas as pd
4-
import pydantic
53

64
from ..prototype import vetiver_create_prototype
75
from ..meta import VetiverMeta
@@ -119,24 +117,6 @@ def handler_startup():
119117
"""
120118
...
121119

122-
def _process_input(self, pred_data) -> pd.DataFrame:
123-
"""Convert prototype to dataframe data
124-
125-
Parameters
126-
----------
127-
pred_data : pydantic.BaseModel
128-
User data from given to API endpoint
129-
130-
Returns
131-
-------
132-
pd.DataFrame
133-
BaseModel data translated into DataFrame
134-
"""
135-
136-
new = _prepare_data(pred_data)
137-
138-
return new
139-
140120
def handler_predict(self, input_data, check_prototype):
141121
"""Generates method for /predict endpoint in VetiverAPI
142122
@@ -169,37 +149,3 @@ def _(model: BaseHandler, prototype_data):
169149
model.prototype_data = prototype_data
170150

171151
return model
172-
173-
174-
@singledispatch
175-
def _prepare_data(pred_data) -> pd.DataFrame:
176-
"""Convert prototype to dataframe data
177-
178-
Parameters
179-
----------
180-
pred_data : pydantic.BaseModel
181-
User data from given to API endpoint
182-
183-
Returns
184-
-------
185-
pd.DataFrame
186-
BaseModel data translated into DataFrame
187-
"""
188-
189-
raise TypeError("Data should be list, dict, pd.DataFrame")
190-
191-
192-
@_prepare_data.register(pydantic.BaseModel)
193-
@_prepare_data.register(list)
194-
def _basemodel_list_data(pred_data):
195-
196-
return pd.DataFrame([dict(s) for s in pred_data])
197-
198-
199-
# @_prepare_data.register(dict)
200-
# def _dict_data(pred_data):
201-
# served_data = []
202-
# for key, value in pred_data:
203-
# served_data.append(value)
204-
205-
# return served_data

vetiver/handlers/sklearn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sklearn
33

44
from .base import BaseHandler
5+
from ..helpers import api_data_to_frame
56

67

78
class SKLearnHandler(BaseHandler):
@@ -34,7 +35,7 @@ def handler_predict(self, input_data, check_prototype):
3435
Prediction from model
3536
"""
3637
if check_prototype:
37-
input_data = self._process_input(input_data)
38+
input_data = api_data_to_frame(input_data)
3839

3940
if not check_prototype or isinstance(input_data, pd.DataFrame):
4041
prediction = self.model.predict(input_data)

vetiver/handlers/statsmodels.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pandas as pd
22

33
from .base import BaseHandler
4+
from ..helpers import api_data_to_frame
45

56
sm_exists = True
67
try:
@@ -43,7 +44,8 @@ def handler_predict(self, input_data, check_prototype):
4344
raise ImportError("Cannot import `statsmodels`")
4445

4546
if check_prototype:
46-
input_data = self._process_input(input_data)
47+
input_data = api_data_to_frame(input_data)
48+
4749
if isinstance(input_data, (list, pd.DataFrame)):
4850
prediction = self.model.predict(input_data)
4951
else:

vetiver/handlers/torch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from .base import BaseHandler
4+
from ..helpers import api_data_to_frame
45

56
torch_exists = True
67
try:
@@ -43,7 +44,7 @@ def handler_predict(self, input_data, check_prototype):
4344
raise ImportError("Cannot import `torch`.")
4445

4546
if check_prototype:
46-
input_data = self._process_input(input_data)
47+
input_data = api_data_to_frame(input_data)
4748
input_data = np.array(input_data, dtype=np.array(self.prototype_data).dtype)
4849
prediction = self.model(torch.from_numpy(input_data))
4950

vetiver/handlers/xgboost.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pandas as pd
22

33
from .base import BaseHandler
4+
from ..helpers import api_data_to_frame
45

56
xgb_exists = True
67
try:
@@ -44,7 +45,7 @@ def handler_predict(self, input_data, check_prototype):
4445
raise ImportError("Cannot import `xgboost`")
4546

4647
if check_prototype:
47-
input_data = self._process_input(input_data)
48+
input_data = api_data_to_frame(input_data)
4849

4950
if not isinstance(input_data, pd.DataFrame):
5051
try:

vetiver/helpers.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from functools import singledispatch
2+
import pandas as pd
3+
import pydantic
4+
5+
6+
@singledispatch
7+
def api_data_to_frame(pred_data) -> pd.DataFrame:
8+
"""Convert prototype to dataframe data
9+
10+
Parameters
11+
----------
12+
pred_data : pydantic.BaseModel
13+
User data from given to API endpoint
14+
15+
Returns
16+
-------
17+
pd.DataFrame
18+
BaseModel data translated into DataFrame
19+
"""
20+
21+
raise TypeError("Data should be list, pydantic.BaseModel, pd.DataFrame")
22+
23+
24+
@api_data_to_frame.register(pydantic.BaseModel)
25+
@api_data_to_frame.register(list)
26+
def _(pred_data):
27+
28+
return pd.DataFrame([dict(s) for s in pred_data])
29+
30+
31+
@api_data_to_frame.register(dict)
32+
def _dict(pred_data):
33+
return api_data_to_frame([pred_data])
34+
35+
36+
# possible other names
37+
# prototype_to_frame
38+
# prototype_to_data
39+
# prototype_to_dataframe
40+
# prototype_to_datatype
41+
# prototype_to_type
42+
# api_data_to_
43+
# json_to_
44+
# server_data_to_
45+
# transport_data_to_
46+
# request
47+
# query
48+
# transit
49+
# interchange/exchange
50+
# transfer
51+
# through

vetiver/tests/test_add_endpoint.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from fastapi.testclient import TestClient
66

77
from vetiver import mock, VetiverModel, VetiverAPI
8+
from vetiver.helpers import api_data_to_frame
89
import vetiver
910

1011

@@ -25,12 +26,7 @@ def vetiver_model():
2526

2627

2728
def sum_values(x):
28-
return pd.DataFrame([dict(s) for s in x]).sum()
29-
30-
31-
def sum_dict(x):
32-
x = pd.DataFrame(x)
33-
return x.sum()
29+
return api_data_to_frame(x).sum()
3430

3531

3632
@pytest.fixture
@@ -49,7 +45,7 @@ def vetiver_client(vetiver_model): # With check_prototype=True
4945
def vetiver_client_check_ptype_false(vetiver_model): # With check_prototype=False
5046

5147
app = VetiverAPI(vetiver_model, check_prototype=False)
52-
app.vetiver_post(sum_dict, "sum")
48+
app.vetiver_post(sum_values, "sum")
5349

5450
app.app.root_path = "/sum"
5551
client = TestClient(app.app)

0 commit comments

Comments
 (0)