Skip to content

Commit 2edf364

Browse files
committed
no ptype and batch tests
1 parent 5b6854b commit 2edf364

3 files changed

Lines changed: 17 additions & 13 deletions

File tree

vetiver/handlers/statsmodels.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,8 @@ def handler_predict(self, input_data, check_ptype):
5858
prediction
5959
Prediction from model
6060
"""
61-
if not check_ptype:
62-
input_data = pd.DataFrame(input_data)
63-
if isinstance(input_data, pd.DataFrame):
61+
62+
if isinstance(input_data, (list, pd.DataFrame)):
6463
prediction = self.model.predict(input_data)
6564
else:
6665
prediction = self.model.predict([input_data])

vetiver/server.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,8 @@ async def prediction(
8383
@app.post("/predict")
8484
async def prediction(input_data: Request):
8585
y = await input_data.json()
86-
from io import BytesIO # noqa
8786

88-
df = pd.read_csv(BytesIO(y))
89-
prediction = self.model.handler_predict(
90-
df, check_ptype=self.check_ptype
91-
)
87+
prediction = self.model.handler_predict(y, check_ptype=self.check_ptype)
9288

9389
return {"prediction": prediction.tolist()}
9490

vetiver/tests/test_statsmodels.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
)
88

99
import numpy as np # noqa
10+
import pandas as pd # noqa
1011
from fastapi.testclient import TestClient # noqa
1112

1213
import vetiver # noqa
@@ -25,20 +26,28 @@ def build_sm():
2526
def test_vetiver_build(build_sm):
2627
api = vetiver.VetiverAPI(build_sm)
2728
client = TestClient(api.app)
28-
data = {"B": 0, "C": 0, "D": 0}
29+
data = [{"B": 0, "C": 0, "D": 0}]
2930

3031
response = vetiver.predict(endpoint=client, data=data)
3132

3233
assert response.iloc[0, 0] == 0.0
3334
assert len(response) == 1
3435

3536

37+
def test_batch(build_sm):
38+
api = vetiver.VetiverAPI(build_sm)
39+
client = TestClient(api.app)
40+
data = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list("ABCD"))
41+
42+
response = vetiver.predict(endpoint=client, data=data)
43+
44+
assert len(response) == 100
45+
46+
3647
def test_no_ptype(build_sm):
37-
v = build_sm
38-
v.ptype = None
39-
api = vetiver.VetiverAPI(v)
48+
api = vetiver.VetiverAPI(build_sm, check_ptype=False)
4049
client = TestClient(api.app)
41-
data = 0, 0, 0
50+
data = [0, 0, 0]
4251

4352
response = vetiver.predict(endpoint=client, data=data)
4453

0 commit comments

Comments
 (0)