Skip to content

Commit 92b9c7c

Browse files
authored
Merge pull request #13 from isabelizimm/dev-ptype
support batch processing for `sklearn` and `pytorch`
2 parents 0c5953b + 4ac2c89 commit 92b9c7c

5 files changed

Lines changed: 91 additions & 21 deletions

File tree

vetiver/handlers/pytorch_vt.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,16 @@ def handler_predict(self, input_data, check_ptype):
7272
if check_ptype == True:
7373
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
7474
prediction = self.model(torch.from_numpy(input_data))
75+
76+
# do not check ptype
7577
else:
76-
input_data = input_data.split(",") # user delimiter ?
78+
batch = True
79+
if not isinstance(input_data, list):
80+
batch = False
81+
input_data = input_data.split(",") # user delimiter ?
7782
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
78-
reshape_data = input_data.reshape(1, -1)
79-
prediction = self.model(torch.from_numpy(reshape_data))
83+
if not batch:
84+
input_data = input_data.reshape(1, -1)
85+
prediction = self.model(torch.from_numpy(input_data))
8086

8187
return prediction

vetiver/handlers/sklearn_vt.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from ..ptype import _vetiver_create_ptype
2-
import sklearn
2+
3+
import pandas as pd
34
import numpy as np
45

56
class SKLearnHandler:
@@ -70,11 +71,16 @@ def handler_predict(self, input_data, check_ptype):
7071
Prediction from model
7172
"""
7273
if check_ptype == True:
73-
prediction = self.model.predict([input_data])
74+
if isinstance(input_data, pd.DataFrame):
75+
prediction = self.model.predict(input_data)
76+
else:
77+
prediction = self.model.predict([input_data])
78+
79+
# do not check ptype
7480
else:
75-
input_data = input_data.split(",") # user delimiter ?
76-
input_data = np.asarray(input_data)
77-
reshape_data = input_data.reshape(1, -1)
78-
prediction = self.model.predict(reshape_data)
81+
if not isinstance(input_data, list):
82+
input_data = [input_data.split(",")] # user delimiter ?
83+
84+
prediction = self.model.predict(input_data)
7985

8086
return prediction

vetiver/server.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from fastapi import FastAPI, Request
22
from fastapi.responses import HTMLResponse
33
import uvicorn
4-
from typing import Callable, Optional
4+
from typing import Callable, Optional, Union, List
55
import requests
66
import pandas as pd
77

@@ -81,9 +81,12 @@ async def rapidoc():
8181
if self.check_ptype == True:
8282

8383
@app.post("/predict/")
84-
async def prediction(input_data: self.model.ptype):
85-
86-
served_data = _prepare_data(input_data)
84+
async def prediction(input_data: Union[self.model.ptype, List[self.model.ptype]]):
85+
86+
if isinstance(input_data, List):
87+
served_data = _batch_data(input_data)
88+
else:
89+
served_data = _prepare_data(input_data)
8790

8891
y = self.model.handler_predict(served_data, check_ptype=self.check_ptype)
8992

@@ -165,6 +168,16 @@ def _prepare_data(pred_data):
165168
served_data.append(value)
166169
return served_data
167170

171+
def _batch_data(pred_data):
    """Convert a batch of records into a single DataFrame.

    Parameters
    ----------
    pred_data : list
        Records to predict on. Each item must expose a ``dict()`` method
        that returns a mapping of field name to value.

    Returns
    -------
    pd.DataFrame
        One row per record. Column names and their order are taken from
        the first record. An empty batch yields an empty DataFrame.
    """
    records = [line.dict() for line in pred_data]
    if not records:
        # There is no first record to read column names from; return an
        # empty frame instead of raising IndexError.
        return pd.DataFrame()

    # The first record fixes the column names and order for the whole batch.
    return pd.DataFrame(records, columns=list(records[0]))
179+
180+
168181

169182
def vetiver_endpoint(url="http://127.0.0.1:8000/predict"):
170183
"""Wrap url

vetiver/tests/test_pytorch.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1+
from cgitb import reset
12
import pytest
23

34
from vetiver.vetiver_model import VetiverModel
45
from vetiver import VetiverAPI
56
from fastapi.testclient import TestClient
67

8+
import torch
79
import torch.nn as nn
810
import numpy as np
911

10-
np.random.seed(500)
11-
1212

1313
def _build_torch_v():
1414

@@ -59,7 +59,7 @@ def test_vetiver_build():
5959

6060

6161
def test_torch_predict_ptype():
62-
62+
torch.manual_seed(3)
6363
x_train, torch_model = _build_torch_v()
6464
v = VetiverModel(torch_model, save_ptype=True, ptype_data=x_train)
6565
v_api = VetiverAPI(v)
@@ -69,6 +69,22 @@ def test_torch_predict_ptype():
6969
response = client.post("/predict/", json=data)
7070

7171
assert response.status_code == 200, response.text
72+
assert response.json() == {"prediction":[-4.060722351074219]}, response.text
73+
74+
75+
def test_torch_predict_ptype_batch():
    """POST a two-row batch to the ptype-checked torch endpoint.

    Expects HTTP 200 and one prediction per submitted row.
    """
    # Seed torch before the model is built.
    torch.manual_seed(3)
    features, model = _build_torch_v()
    api = VetiverAPI(VetiverModel(model, save_ptype=True, ptype_data=features))

    # Both rows are identical, so both predictions must match.
    row = {"0": 3.3}
    reply = TestClient(api.app).post("/predict/", json=[row, row])

    assert reply.status_code == 200, reply.text
    expected_row = [-4.060722351074219]
    assert reply.json() == {"prediction": [expected_row, expected_row]}, reply.text
87+
7288

7389
def test_torch_predict_ptype_error():
7490

@@ -80,19 +96,33 @@ def test_torch_predict_ptype_error():
8096
data = {"0": "bad"}
8197
response = client.post("/predict/", json=data)
8298

83-
assert response.status_code == 422, response.text # value is not a valid float
99+
assert response.status_code == 422, response.text # value is not a valid float
84100

85101

86102
def test_torch_predict_no_ptype():
    """POST a single raw string to the torch endpoint with ptype checking off.

    Expects HTTP 200 and a single-row prediction.
    """
    # Seed torch before the model is built.
    torch.manual_seed(3)
    features, model = _build_torch_v()
    vetiver_model = VetiverModel(model, save_ptype=False, ptype_data=features)
    api = VetiverAPI(vetiver_model, check_ptype=False)

    reply = TestClient(api.app).post("/predict/", json="3.3")

    assert reply.status_code == 200, reply.text
    expected = {"prediction": [[-4.060722351074219]]}
    assert reply.json() == expected, reply.text
113+
87114

115+
def test_torch_predict_no_ptype_batch():
116+
torch.manual_seed(3)
88117
x_train, torch_model = _build_torch_v()
89118
v = VetiverModel(torch_model, save_ptype=False, ptype_data=x_train)
90119
v_api = VetiverAPI(v, check_ptype=False)
91120

92121
client = TestClient(v_api.app)
93-
data = '3.3'
122+
data = [["3.3"], ["3.3"]]
94123
response = client.post("/predict/", json=data)
95124
assert response.status_code == 200, response.text
125+
assert response.json() == {"prediction":[[-4.060722351074219],[-4.060722351074219]]}, response.text
96126

97127

98128
def test_torch_predict_no_ptype_error():
@@ -102,6 +132,6 @@ def test_torch_predict_no_ptype_error():
102132
v_api = VetiverAPI(v, check_ptype=False)
103133

104134
client = TestClient(v_api.app)
105-
data = 'bad'
135+
data = "bad"
106136
with pytest.raises(ValueError):
107137
client.post("/predict/", json=data)

vetiver/tests/test_sklearn.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ def test_predict_endpoint_ptype():
3131
assert response.status_code == 200, response.text
3232
assert response.json() == {"prediction": [44.47]}, response.json()
3333

34+
def test_predict_endpoint_ptype_batch():
    """POST two identical records to the ptype-checked sklearn endpoint.

    Expects HTTP 200 and one prediction per record.
    """
    np.random.seed(500)
    app = _start_application(save_ptype=True).app

    # Build each record separately so the payload holds two distinct dicts.
    payload = [{"B": 0, "C": 0, "D": 0} for _ in range(2)]
    reply = TestClient(app).post("/predict/", json=payload)

    assert reply.status_code == 200, reply.text
    assert reply.json() == {"prediction": [44.47, 44.47]}, reply.json()
41+
3442

3543
def test_predict_endpoint_ptype_error():
3644
np.random.seed(500)
@@ -48,10 +56,17 @@ def test_predict_endpoint_no_ptype():
4856
assert response.status_code == 200, response.text
4957
assert response.json() == {"prediction": [44.47]}, response.json()
5058

59+
def test_predict_endpoint_no_ptype_batch():
    """POST a two-row batch to the sklearn endpoint with no saved ptype.

    Expects HTTP 200 and one prediction per row.
    """
    np.random.seed(500)
    app = _start_application(save_ptype=False).app

    # Each row is a one-element list holding a comma-separated string.
    payload = [["0,0,0"] for _ in range(2)]
    reply = TestClient(app).post("/predict/", json=payload)

    assert reply.status_code == 200, reply.text
    assert reply.json() == {"prediction": [44.47, 44.47]}, reply.json()
5166

5267
def test_predict_endpoint_no_ptype_error():
5368
np.random.seed(500)
5469
client = TestClient(_start_application(save_ptype=False).app)
55-
data = ['hell0',9,32.0]
56-
with pytest.raises(AttributeError):
70+
data = {'hell0',9,32.0}
71+
with pytest.raises(TypeError):
5772
client.post("/predict/", json=data)

0 commit comments

Comments
 (0)