Skip to content

Commit 1705bdb

Browse files
committed
batch processing for sklearn
1 parent 0c5953b commit 1705bdb

4 files changed

Lines changed: 63 additions & 15 deletions

File tree

vetiver/handlers/sklearn_vt.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from ..ptype import _vetiver_create_ptype
2-
import sklearn
2+
3+
import pandas as pd
34
import numpy as np
45

56
class SKLearnHandler:
@@ -70,11 +71,16 @@ def handler_predict(self, input_data, check_ptype):
7071
Prediction from model
7172
"""
7273
if check_ptype == True:
73-
prediction = self.model.predict([input_data])
74+
if isinstance(input_data, pd.DataFrame):
75+
prediction = self.model.predict(input_data)
76+
else:
77+
prediction = self.model.predict([input_data])
78+
79+
# do not check ptype
7480
else:
75-
input_data = input_data.split(",") # user delimiter ?
76-
input_data = np.asarray(input_data)
77-
reshape_data = input_data.reshape(1, -1)
78-
prediction = self.model.predict(reshape_data)
81+
if not isinstance(input_data, list):
82+
input_data = [input_data.split(",")] # user delimiter ?
83+
84+
prediction = self.model.predict(input_data)
7985

8086
return prediction

vetiver/server.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from fastapi import FastAPI, Request
22
from fastapi.responses import HTMLResponse
33
import uvicorn
4-
from typing import Callable, Optional
4+
from typing import Callable, Optional, Union, List
55
import requests
66
import pandas as pd
77

@@ -81,9 +81,12 @@ async def rapidoc():
8181
if self.check_ptype == True:
8282

8383
@app.post("/predict/")
84-
async def prediction(input_data: self.model.ptype):
85-
86-
served_data = _prepare_data(input_data)
84+
async def prediction(input_data: Union[self.model.ptype, List[self.model.ptype]]):
85+
86+
if isinstance(input_data, List):
87+
served_data = _batch_data(input_data)
88+
else:
89+
served_data = _prepare_data(input_data)
8790

8891
y = self.model.handler_predict(served_data, check_ptype=self.check_ptype)
8992

@@ -165,6 +168,16 @@ def _prepare_data(pred_data):
165168
served_data.append(value)
166169
return served_data
167170

171+
def _batch_data(pred_data):
    """Convert a batch of prototype records into a single DataFrame.

    Parameters
    ----------
    pred_data : list
        Records that each expose a ``dict()`` method returning a mapping of
        column name to value.

    Returns
    -------
    pd.DataFrame
        One row per record; columns follow the key order of the first
        record. An empty batch yields an empty DataFrame.
    """
    rows = [line.dict() for line in pred_data]
    if not rows:
        # No record to take column names from; avoid IndexError on [0].
        return pd.DataFrame()

    # The first record fixes the column order for the whole batch.
    return pd.DataFrame(rows, columns=list(rows[0]))
179+
180+
168181

169182
def vetiver_endpoint(url="http://127.0.0.1:8000/predict"):
170183
"""Wrap url

vetiver/tests/test_pytorch.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,20 @@ def test_torch_predict_ptype():
7070

7171
assert response.status_code == 200, response.text
7272

73+
74+
def test_torch_predict_ptype_batch():
    """Posting a list of ptype-shaped records to /predict/ returns HTTP 200."""
    x_train, torch_model = _build_torch_v()
    api = VetiverAPI(
        VetiverModel(torch_model, save_ptype=True, ptype_data=x_train)
    )

    # Two identical single-feature records form the batch.
    payload = [{"0": 3.3}, {"0": 3.3}]
    reply = TestClient(api.app).post("/predict/", json=payload)

    assert reply.status_code == 200, reply.text
85+
86+
7387
def test_torch_predict_ptype_error():
7488

7589
x_train, torch_model = _build_torch_v()
@@ -80,7 +94,7 @@ def test_torch_predict_ptype_error():
8094
data = {"0": "bad"}
8195
response = client.post("/predict/", json=data)
8296

83-
assert response.status_code == 422, response.text # value is not a valid float
97+
assert response.status_code == 422, response.text # value is not a valid float
8498

8599

86100
def test_torch_predict_no_ptype():
@@ -90,7 +104,7 @@ def test_torch_predict_no_ptype():
90104
v_api = VetiverAPI(v, check_ptype=False)
91105

92106
client = TestClient(v_api.app)
93-
data = '3.3'
107+
data = "3.3"
94108
response = client.post("/predict/", json=data)
95109
assert response.status_code == 200, response.text
96110

@@ -102,6 +116,6 @@ def test_torch_predict_no_ptype_error():
102116
v_api = VetiverAPI(v, check_ptype=False)
103117

104118
client = TestClient(v_api.app)
105-
data = 'bad'
119+
data = "bad"
106120
with pytest.raises(ValueError):
107121
client.post("/predict/", json=data)

vetiver/tests/test_sklearn.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ def test_predict_endpoint_ptype():
3131
assert response.status_code == 200, response.text
3232
assert response.json() == {"prediction": [44.47]}, response.json()
3333

34+
def test_predict_endpoint_ptype_batch():
    """A two-record batch with ptype checking yields two predictions."""
    np.random.seed(500)
    api = _start_application(save_ptype=True)
    payload = [{"B": 0, "C": 0, "D": 0}, {"B": 0, "C": 0, "D": 0}]

    reply = TestClient(api.app).post("/predict/", json=payload)

    assert reply.status_code == 200, reply.text
    assert reply.json() == {"prediction": [44.47, 44.47]}, reply.json()
41+
3442

3543
def test_predict_endpoint_ptype_error():
3644
np.random.seed(500)
@@ -48,10 +56,17 @@ def test_predict_endpoint_no_ptype():
4856
assert response.status_code == 200, response.text
4957
assert response.json() == {"prediction": [44.47]}, response.json()
5058

59+
def test_predict_endpoint_no_ptype_batch():
    """A two-row batch without ptype checking yields two predictions."""
    np.random.seed(500)
    api = _start_application(save_ptype=False)
    # Each row is a one-element list holding a comma-delimited string.
    payload = [["0,0,0"], ["0,0,0"]]

    reply = TestClient(api.app).post("/predict/", json=payload)

    assert reply.status_code == 200, reply.text
    assert reply.json() == {"prediction": [44.47, 44.47]}, reply.json()
5166

5267
def test_predict_endpoint_no_ptype_error():
5368
np.random.seed(500)
5469
client = TestClient(_start_application(save_ptype=False).app)
55-
data = ['hell0',9,32.0]
56-
with pytest.raises(AttributeError):
70+
data = {'hell0',9,32.0}
71+
with pytest.raises(TypeError):
5772
client.post("/predict/", json=data)

0 commit comments

Comments
 (0)