Skip to content

Commit 61fa10f

Browse files
committed
check_ptype to check_prototype
1 parent 0143698 commit 61fa10f

12 files changed

Lines changed: 48 additions & 38 deletions

vetiver/handlers/sklearn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def create_meta(
3333

3434
return meta
3535

36-
def handler_predict(self, input_data, check_ptype):
36+
def handler_predict(self, input_data, check_prototype):
3737
"""Generates method for /predict endpoint in VetiverAPI
3838
3939
The `handler_predict` function executes at each API call. Use this
@@ -51,7 +51,7 @@ def handler_predict(self, input_data, check_ptype):
5151
Prediction from model
5252
"""
5353

54-
if not check_ptype or isinstance(input_data, pd.DataFrame):
54+
if not check_prototype or isinstance(input_data, pd.DataFrame):
5555
prediction = self.model.predict(input_data)
5656
else:
5757
prediction = self.model.predict([input_data])

vetiver/handlers/statsmodels.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@ class StatsmodelsHandler(BaseHandler):
2121

2222
model_class = staticmethod(lambda: statsmodels.base.wrapper.ResultsWrapper)
2323

24-
def __init__(self, model, ptype_data):
25-
super().__init__(model, ptype_data)
26-
2724
def describe(self):
2825
"""Create description for statsmodels model"""
2926
desc = f"Statsmodels {self.model.__class__} model."
@@ -41,7 +38,7 @@ def create_meta(
4138

4239
return meta
4340

44-
def handler_predict(self, input_data, check_ptype):
41+
def handler_predict(self, input_data, check_prototype):
4542
"""Generates method for /predict endpoint in VetiverAPI
4643
4744
The `handler_predict` function executes at each API call. Use this

vetiver/handlers/torch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def create_meta(
3838

3939
return meta
4040

41-
def handler_predict(self, input_data, check_ptype):
41+
def handler_predict(self, input_data, check_prototype):
4242
"""Generates method for /predict endpoint in VetiverAPI
4343
4444
The `handler_predict` function executes at each API call. Use this
@@ -57,7 +57,7 @@ def handler_predict(self, input_data, check_ptype):
5757
"""
5858
if not torch_exists:
5959
raise ImportError("Cannot import `torch`.")
60-
if check_ptype:
60+
if check_prototype:
6161
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
6262
prediction = self.model(torch.from_numpy(input_data))
6363

vetiver/handlers/xgboost.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def create_meta(
3838

3939
return meta
4040

41-
def handler_predict(self, input_data, check_ptype):
41+
def handler_predict(self, input_data, check_prototype):
4242
"""Generates method for /predict endpoint in VetiverAPI
4343
4444
The `handler_predict` function executes at each API call. Use this

vetiver/server.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, List, Union, Any, Dict
1+
from typing import Any, Callable, Dict, List, Union
22
from urllib.parse import urljoin
33

44
import httpx
@@ -8,6 +8,7 @@
88
from fastapi import FastAPI, Request, testclient
99
from fastapi.openapi.utils import get_openapi
1010
from fastapi.responses import HTMLResponse, RedirectResponse
11+
from warnings import warn
1112

1213
from .utils import _jupyter_nb
1314
from .vetiver_model import VetiverModel
@@ -20,8 +21,10 @@ class VetiverAPI:
2021
----------
2122
model : VetiverModel
2223
Model to be deployed in API
23-
check_ptype : bool
24+
check_prototype : bool
2425
Determine if data prototype should be enforced
26+
check_ptype : bool
27+
Deprecated in favor of check_prototype
2528
app_factory :
2629
Type of API to be deployed
2730
@@ -31,22 +34,32 @@ class VetiverAPI:
3134
>>> X, y = vt.get_mock_data()
3235
>>> model = vt.get_mock_model().fit(X, y)
3336
>>> v = vt.VetiverModel(model = model, model_name = "my_model", prototype_data = X)
34-
>>> v_api = vt.VetiverAPI(model = v, check_ptype = True)
37+
>>> v_api = vt.VetiverAPI(model = v, check_prototype = True)
3538
"""
3639

3740
app = None
3841

3942
def __init__(
4043
self,
4144
model: VetiverModel,
42-
check_ptype: bool = True,
45+
check_prototype: bool = True,
46+
check_ptype: bool = None,
4347
app_factory=FastAPI,
4448
) -> None:
4549
self.model = model
46-
self.check_ptype = check_ptype
4750
self.app_factory = app_factory
4851
self.app = app_factory()
4952

53+
if check_ptype is not None:
54+
check_prototype = check_ptype
55+
warn(
56+
"argument for checking input data prototype has changed to "
57+
"check_prototype, from check_ptype",
58+
DeprecationWarning,
59+
stacklevel=2,
60+
)
61+
self.check_prototype = check_prototype
62+
5063
self._init_app()
5164

5265
def _init_app(self):
@@ -71,7 +84,7 @@ async def ping():
7184
return {"ping": "pong"}
7285

7386
self.vetiver_post(
74-
self.model.handler_predict, "predict", check_ptype=self.check_ptype
87+
self.model.handler_predict, "predict", check_prototype=self.check_prototype
7588
)
7689

7790
@app.get("/__docs__", response_class=HTMLResponse, include_in_schema=False)
@@ -130,15 +143,15 @@ def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
130143
>>> X, y = vt.get_mock_data()
131144
>>> model = vt.get_mock_model().fit(X, y)
132145
>>> v = vt.VetiverModel(model = model, model_name = "model", prototype_data = X)
133-
>>> v_api = vt.VetiverAPI(model = v, check_ptype = True)
146+
>>> v_api = vt.VetiverAPI(model = v, check_prototype = True)
134147
>>> def sum_values(x):
135148
... return x.sum()
136149
>>> v_api.vetiver_post(sum_values, "sums")
137150
"""
138151
if not endpoint_name:
139152
endpoint_name = endpoint_fx.__name__
140153

141-
if self.check_ptype is True:
154+
if self.check_prototype is True:
142155

143156
@self.app.post(urljoin("/", endpoint_name), name=endpoint_name)
144157
async def custom_endpoint(
@@ -179,7 +192,7 @@ def run(self, port: int = 8000, host: str = "127.0.0.1", **kw):
179192
>>> X, y = vt.get_mock_data()
180193
>>> model = vt.get_mock_model().fit(X, y)
181194
>>> v = vt.VetiverModel(model = model, model_name = "model", prototype_data = X)
182-
>>> v_api = vt.VetiverAPI(model = v, check_ptype = True)
195+
>>> v_api = vt.VetiverAPI(model = v, check_prototype = True)
183196
>>> v_api.run() # doctest: +SKIP
184197
"""
185198
_jupyter_nb()

vetiver/tests/test_add_endpoint.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ def sum_dict(x):
3434

3535

3636
@pytest.fixture
37-
def vetiver_client(vetiver_model): # With check_ptype=True
37+
def vetiver_client(vetiver_model): # With check_prototype=True
3838

39-
app = VetiverAPI(vetiver_model, check_ptype=True)
39+
app = VetiverAPI(vetiver_model, check_prototype=True)
4040
app.vetiver_post(sum_values, "sum")
4141

4242
app.app.root_path = "/sum"
@@ -46,9 +46,9 @@ def vetiver_client(vetiver_model): # With check_ptype=True
4646

4747

4848
@pytest.fixture
49-
def vetiver_client_check_ptype_false(vetiver_model): # With check_ptype=False
49+
def vetiver_client_check_ptype_false(vetiver_model): # With check_prototype=False
5050

51-
app = VetiverAPI(vetiver_model, check_ptype=False)
51+
app = VetiverAPI(vetiver_model, check_prototype=False)
5252
app.vetiver_post(sum_dict, "sum")
5353

5454
app.app.root_path = "/sum"

vetiver/tests/test_predict.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def vetiver_model():
2727

2828
@pytest.fixture
2929
def vetiver_client(vetiver_model): # With check_ptype=True
30-
app = VetiverAPI(vetiver_model, check_ptype=True)
30+
app = VetiverAPI(vetiver_model, check_prototype=True)
3131
app.app.root_path = "/predict"
3232
client = TestClient(app.app)
3333

@@ -36,7 +36,7 @@ def vetiver_client(vetiver_model): # With check_ptype=True
3636

3737
@pytest.fixture
3838
def vetiver_client_check_ptype_false(vetiver_model): # With check_ptype=False
39-
app = VetiverAPI(vetiver_model, check_ptype=False)
39+
app = VetiverAPI(vetiver_model, check_prototype=False)
4040
app.app.root_path = "/predict"
4141
client = TestClient(app.app)
4242

@@ -90,7 +90,7 @@ def test_predict_sklearn_type_error(data, vetiver_client):
9090

9191
def test_predict_server_error(vetiver_model):
9292
X, y = mock.get_mock_data()
93-
app = VetiverAPI(vetiver_model, check_ptype=True)
93+
app = VetiverAPI(vetiver_model, check_prototype=True)
9494
app.app.root_path = "/i_do_not_exists"
9595
client = TestClient(app.app)
9696

vetiver/tests/test_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def test_torch_predict_no_ptype_batch():
102102
torch.manual_seed(3)
103103
x_train, torch_model = _build_torch_v()
104104
v = VetiverModel(torch_model, model_name="torch")
105-
v_api = VetiverAPI(v, check_ptype=False)
105+
v_api = VetiverAPI(v, check_prototype=False)
106106

107107
client = TestClient(v_api.app)
108108
data = [[3.3], [3.3]]
@@ -117,7 +117,7 @@ def test_torch_predict_no_ptype():
117117
torch.manual_seed(3)
118118
x_train, torch_model = _build_torch_v()
119119
v = VetiverModel(torch_model, model_name="torch")
120-
v_api = VetiverAPI(v, check_ptype=False)
120+
v_api = VetiverAPI(v, check_prototype=False)
121121

122122
client = TestClient(v_api.app)
123123
data = [[3.3]]

vetiver/tests/test_sklearn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def _start_application(save_ptype: bool = True):
1515
description="A regression model for testing purposes",
1616
)
1717

18-
app = VetiverAPI(v, check_ptype=save_ptype)
18+
app = VetiverAPI(v, check_prototype=save_ptype)
1919

2020
return app
2121

vetiver/tests/test_statsmodels.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,17 @@ def sm_model():
2424

2525

2626
@pytest.fixture
27-
def vetiver_client(sm_model): # With check_ptype=True
28-
app = vetiver.VetiverAPI(sm_model, check_ptype=True)
27+
def vetiver_client(sm_model): # With check_prototype=True
28+
app = vetiver.VetiverAPI(sm_model, check_prototype=True)
2929
app.app.root_path = "/predict"
3030
client = TestClient(app.app)
3131

3232
return client
3333

3434

3535
@pytest.fixture
36-
def vetiver_client_check_ptype_false(sm_model): # With check_ptype=True
37-
app = vetiver.VetiverAPI(sm_model, check_ptype=False)
36+
def vetiver_client_check_ptype_false(sm_model): # With check_prototype=True
37+
app = vetiver.VetiverAPI(sm_model, check_prototype=False)
3838
app.app.root_path = "/predict"
3939
client = TestClient(app.app)
4040

0 commit comments

Comments
 (0)