Skip to content

Commit 2e99db2

Browse files
committed
move api_data_to_frame to server
1 parent 8f686e5 commit 2e99db2

6 files changed

Lines changed: 18 additions & 14 deletions

File tree

vetiver/handlers/sklearn.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import sklearn
33

44
from .base import BaseHandler
5-
from ..helpers import api_data_to_frame
65

76

87
class SKLearnHandler(BaseHandler):
@@ -34,12 +33,14 @@ def handler_predict(self, input_data, check_prototype):
3433
prediction
3534
Prediction from model
3635
"""
37-
if check_prototype:
38-
input_data = api_data_to_frame(input_data)
36+
# if check_prototype:
37+
# input_data = api_data_to_frame(input_data)
3938

4039
if not check_prototype or isinstance(input_data, pd.DataFrame):
4140
prediction = self.model.predict(input_data)
4241
else:
4342
prediction = self.model.predict([input_data])
4443

44+
# some sort of post-prediction/pre send back to user hook
45+
4546
return prediction.tolist()

vetiver/handlers/statsmodels.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pandas as pd
22

33
from .base import BaseHandler
4-
from ..helpers import api_data_to_frame
54

65
sm_exists = True
76
try:
@@ -43,8 +42,8 @@ def handler_predict(self, input_data, check_prototype):
4342
if not sm_exists:
4443
raise ImportError("Cannot import `statsmodels`")
4544

46-
if check_prototype:
47-
input_data = api_data_to_frame(input_data)
45+
# if check_prototype:
46+
# input_data = api_data_to_frame(input_data)
4847

4948
if isinstance(input_data, (list, pd.DataFrame)):
5049
prediction = self.model.predict(input_data)

vetiver/handlers/torch.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import numpy as np
22

33
from .base import BaseHandler
4-
from ..helpers import api_data_to_frame
54

65
torch_exists = True
76
try:
@@ -44,7 +43,7 @@ def handler_predict(self, input_data, check_prototype):
4443
raise ImportError("Cannot import `torch`.")
4544

4645
if check_prototype:
47-
input_data = api_data_to_frame(input_data)
46+
# input_data = api_data_to_frame(input_data)
4847
input_data = np.array(input_data, dtype=np.array(self.prototype_data).dtype)
4948
prediction = self.model(torch.from_numpy(input_data))
5049

vetiver/handlers/xgboost.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import pandas as pd
22

33
from .base import BaseHandler
4-
from ..helpers import api_data_to_frame
54

65
xgb_exists = True
76
try:
@@ -44,8 +43,8 @@ def handler_predict(self, input_data, check_prototype):
4443
if not xgb_exists:
4544
raise ImportError("Cannot import `xgboost`")
4645

47-
if check_prototype:
48-
input_data = api_data_to_frame(input_data)
46+
# if check_prototype:
47+
# input_data = api_data_to_frame(input_data)
4948

5049
if not isinstance(input_data, pd.DataFrame):
5150
try:

vetiver/server.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .utils import _jupyter_nb
1717
from .vetiver_model import VetiverModel
1818
from .meta import VetiverMeta
19+
from .helpers import api_data_to_frame
1920

2021

2122
class VetiverAPI:
@@ -175,7 +176,8 @@ def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
175176

176177
@self.app.post(urljoin("/", endpoint_name), name=endpoint_name)
177178
async def custom_endpoint(input_data: List[self.model.prototype]):
178-
predictions = endpoint_fx(input_data, **kw)
179+
_to_frame = api_data_to_frame(input_data)
180+
predictions = endpoint_fx(_to_frame, **kw)
179181
return {endpoint_name: predictions}
180182

181183
else:

vetiver/tests/test_add_endpoint.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ def vetiver_model():
2626

2727

2828
def sum_values(x):
29-
return api_data_to_frame(x).sum()
29+
return x.sum().to_list()
30+
31+
32+
def sum_values_no_prototype(x):
33+
return api_data_to_frame(x).sum().to_list()
3034

3135

3236
@pytest.fixture
@@ -45,7 +49,7 @@ def vetiver_client(vetiver_model): # With check_prototype=True
4549
def vetiver_client_check_ptype_false(vetiver_model): # With check_prototype=False
4650

4751
app = VetiverAPI(vetiver_model, check_prototype=False)
48-
app.vetiver_post(sum_values, "sum")
52+
app.vetiver_post(sum_values_no_prototype, "sum")
4953

5054
app.app.root_path = "/sum"
5155
client = TestClient(app.app)

0 commit comments

Comments
 (0)