Skip to content

Commit 46aefc7

Browse files
committed
Update PyTorch handler to accept input with no ptype
1 parent 2b7fd70 commit 46aefc7

3 files changed

Lines changed: 43 additions & 56 deletions

File tree

vetiver/handlers/pytorch_vt.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def handler_startup():
6565
"""
6666
...
6767

68-
def handler_predict(self, input_data, check_ptype):
68+
def handler_predict(self, input_data, check_ptype, **kw):
6969
"""Generates method for /predict endpoint in VetiverAPI
7070
7171
The `handler_predict` function executes at each API call. Use this
@@ -88,15 +88,10 @@ def handler_predict(self, input_data, check_ptype):
8888
prediction = self.model(torch.from_numpy(input_data))
8989

9090
# do not check ptype
91-
else:
92-
batch = True
93-
if not isinstance(input_data, list):
94-
batch = False
95-
input_data = input_data.split(",") # user delimiter ?
96-
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
97-
if not batch:
98-
input_data = input_data.reshape(1, -1)
99-
prediction = self.model(torch.from_numpy(input_data))
91+
else:
92+
input_data = torch.tensor(input_data)
93+
prediction = self.model(input_data)
94+
10095
else:
10196
raise ImportError("Cannot import `torch`.")
10297

vetiver/server.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from fastapi import FastAPI, Request
22
from fastapi.responses import HTMLResponse, RedirectResponse
3-
from fastapi.staticfiles import StaticFiles
43
from fastapi.openapi.utils import get_openapi
54
from fastapi import testclient
65

76
import uvicorn
87
import requests
98
import pandas as pd
10-
from typing import Callable, Optional, Union, List
9+
import numpy as np
10+
from typing import Callable, Union, List
1111

1212
from .vetiver_model import VetiverModel
1313
from .utils import _jupyter_nb
@@ -53,8 +53,37 @@ def docs_redirect():
5353
async def ping():
5454
return {"ping": "pong"}
5555

56+
if self.check_ptype is True:
57+
58+
@app.post("/predict/")
59+
async def prediction(
60+
input_data: Union[self.model.ptype, List[self.model.ptype]]
61+
):
62+
if isinstance(input_data, List):
63+
served_data = _batch_data(input_data)
64+
else:
65+
served_data = _prepare_data(input_data)
66+
67+
y = self.model.handler_predict(
68+
served_data, check_ptype=self.check_ptype
69+
)
70+
71+
return {"prediction": y.tolist()}
72+
73+
elif self.check_ptype is False:
74+
75+
@app.post("/predict/")
76+
async def prediction(input_data: Request):
77+
78+
y = await input_data.json()
79+
prediction = self.model.handler_predict(y, check_ptype=self.check_ptype)
80+
81+
return {"prediction": prediction.tolist()}
82+
else:
83+
raise ValueError("cannot determine `check_ptype`")
84+
5685
@app.get("/__docs__", response_class=HTMLResponse, include_in_schema=False)
57-
async def rapidoc_pg():
86+
async def rapidoc():
5887
return f"""
5988
<!doctype html>
6089
<html>
@@ -82,33 +111,6 @@ async def rapidoc_pg():
82111
</html>
83112
"""
84113

85-
if self.check_ptype == True:
86-
87-
@app.post("/predict/")
88-
async def prediction(
89-
input_data: Union[self.model.ptype, List[self.model.ptype]]
90-
):
91-
92-
if isinstance(input_data, List):
93-
served_data = _batch_data(input_data)
94-
else:
95-
served_data = _prepare_data(input_data)
96-
97-
y = self.model.handler_predict(
98-
served_data, check_ptype=self.check_ptype
99-
)
100-
101-
return {"prediction": y.tolist()}
102-
103-
else:
104-
105-
@app.post("/predict/")
106-
async def prediction(input_data: Request):
107-
y = await input_data.json()
108-
prediction = self.model.handler_predict(y, check_ptype=self.check_ptype)
109-
110-
return {"prediction": prediction.tolist()}
111-
112114
return app
113115

114116
def vetiver_post(

vetiver/tests/test_pytorch.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,17 @@ def test_torch_predict_ptype_error():
9999
assert response.status_code == 422, response.text # value is not a valid float
100100

101101

102-
def test_torch_predict_no_ptype():
102+
def test_torch_predict_no_ptype_error():
103103
torch.manual_seed(3)
104104
x_train, torch_model = _build_torch_v()
105-
v = VetiverModel(torch_model, model_name = "torch", ptype_data=x_train)
105+
v = VetiverModel(torch_model, model_name = "torch")
106106
v_api = VetiverAPI(v, check_ptype=False)
107107

108108
client = TestClient(v_api.app)
109-
data = "3.3"
109+
data =[[3.3], [3.3]]
110110
response = client.post("/predict/", json=data)
111111
assert response.status_code == 200, response.text
112-
assert response.json() == {"prediction":[[-4.060722351074219]]}, response.text
112+
assert response.json() == {"prediction":[[-4.060722351074219],[-4.060722351074219]]}, response.text
113113

114114

115115
def test_torch_predict_no_ptype_batch():
@@ -119,18 +119,8 @@ def test_torch_predict_no_ptype_batch():
119119
v_api = VetiverAPI(v, check_ptype=False)
120120

121121
client = TestClient(v_api.app)
122-
data = [["3.3"], ["3.3"]]
122+
data = [[3.3]]
123123
response = client.post("/predict/", json=data)
124124
assert response.status_code == 200, response.text
125-
assert response.json() == {"prediction":[[-4.060722351074219],[-4.060722351074219]]}, response.text
126-
127-
128-
def test_torch_predict_no_ptype_error():
129-
130-
x_train, torch_model = _build_torch_v()
131-
v = VetiverModel(torch_model, model_name = "torch")
132-
v_api = VetiverAPI(v, check_ptype=False)
133-
134-
client = TestClient(v_api.app)
135-
data = "bad"
125+
assert response.json() == {"prediction":[[-4.060722351074219]]}, response.text
136126

0 commit comments

Comments (0)