Skip to content

Commit 4ac2c89

Browse files
committed
batch for torch
1 parent 1705bdb commit 4ac2c89

2 files changed

Lines changed: 30 additions & 8 deletions

File tree

vetiver/handlers/pytorch_vt.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,16 @@ def handler_predict(self, input_data, check_ptype):
7272
if check_ptype == True:
7373
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
7474
prediction = self.model(torch.from_numpy(input_data))
75+
76+
# do not check ptype
7577
else:
76-
input_data = input_data.split(",") # user delimiter ?
78+
batch = True
79+
if not isinstance(input_data, list):
80+
batch = False
81+
input_data = input_data.split(",") # user delimiter ?
7782
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
78-
reshape_data = input_data.reshape(1, -1)
79-
prediction = self.model(torch.from_numpy(reshape_data))
83+
if not batch:
84+
input_data = input_data.reshape(1, -1)
85+
prediction = self.model(torch.from_numpy(input_data))
8086

8187
return prediction

vetiver/tests/test_pytorch.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1+
from cgitb import reset
12
import pytest
23

34
from vetiver.vetiver_model import VetiverModel
45
from vetiver import VetiverAPI
56
from fastapi.testclient import TestClient
67

8+
import torch
79
import torch.nn as nn
810
import numpy as np
911

10-
np.random.seed(500)
11-
1212

1313
def _build_torch_v():
1414

@@ -59,7 +59,7 @@ def test_vetiver_build():
5959

6060

6161
def test_torch_predict_ptype():
62-
62+
torch.manual_seed(3)
6363
x_train, torch_model = _build_torch_v()
6464
v = VetiverModel(torch_model, save_ptype=True, ptype_data=x_train)
6565
v_api = VetiverAPI(v)
@@ -69,10 +69,11 @@ def test_torch_predict_ptype():
6969
response = client.post("/predict/", json=data)
7070

7171
assert response.status_code == 200, response.text
72+
assert response.json() == {"prediction":[-4.060722351074219]}, response.text
7273

7374

7475
def test_torch_predict_ptype_batch():
75-
76+
torch.manual_seed(3)
7677
x_train, torch_model = _build_torch_v()
7778
v = VetiverModel(torch_model, save_ptype=True, ptype_data=x_train)
7879
v_api = VetiverAPI(v)
@@ -82,6 +83,7 @@ def test_torch_predict_ptype_batch():
8283
response = client.post("/predict/", json=data)
8384

8485
assert response.status_code == 200, response.text
86+
assert response.json() == {"prediction":[[-4.060722351074219],[-4.060722351074219]]}, response.text
8587

8688

8789
def test_torch_predict_ptype_error():
@@ -98,7 +100,7 @@ def test_torch_predict_ptype_error():
98100

99101

100102
def test_torch_predict_no_ptype():
101-
103+
torch.manual_seed(3)
102104
x_train, torch_model = _build_torch_v()
103105
v = VetiverModel(torch_model, save_ptype=False, ptype_data=x_train)
104106
v_api = VetiverAPI(v, check_ptype=False)
@@ -107,6 +109,20 @@ def test_torch_predict_no_ptype():
107109
data = "3.3"
108110
response = client.post("/predict/", json=data)
109111
assert response.status_code == 200, response.text
112+
assert response.json() == {"prediction":[[-4.060722351074219]]}, response.text
113+
114+
115+
def test_torch_predict_no_ptype_batch():
116+
torch.manual_seed(3)
117+
x_train, torch_model = _build_torch_v()
118+
v = VetiverModel(torch_model, save_ptype=False, ptype_data=x_train)
119+
v_api = VetiverAPI(v, check_ptype=False)
120+
121+
client = TestClient(v_api.app)
122+
data = [["3.3"], ["3.3"]]
123+
response = client.post("/predict/", json=data)
124+
assert response.status_code == 200, response.text
125+
assert response.json() == {"prediction":[[-4.060722351074219],[-4.060722351074219]]}, response.text
110126

111127

112128
def test_torch_predict_no_ptype_error():

0 commit comments

Comments
 (0)