Skip to content

Commit de2c222

Browse files
committed
comments from review
1 parent 4e69722 commit de2c222

5 files changed

Lines changed: 45 additions & 36 deletions

File tree

vetiver/handlers/base.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ def __init__(
2020

2121

2222
@singledispatch
23-
def create_handler(model, ptype_data):
23+
def create_handler(model, prototype_data):
2424
"""check for model type to handle prediction
2525
2626
Parameters
2727
----------
2828
model: object
2929
Description of parameter `x`.
30-
ptype_data : object
30+
prototype_data : object
3131
An object with information (data) whose layout is to be determined.
3232
3333
Returns
@@ -63,7 +63,7 @@ class BaseHandler:
6363
----------
6464
model :
6565
a trained model
66-
ptype_data :
66+
prototype_data :
6767
An object with information (data) whose layout is to be determined.
6868
"""
6969

@@ -73,9 +73,9 @@ def __init_subclass__(cls, **kwargs):
7373
with suppress(AttributeError, NameError):
7474
create_handler.register(cls.model_class(), cls)
7575

76-
def __init__(self, model, ptype_data):
76+
def __init__(self, model, prototype_data):
7777
self.model = model
78-
self.ptype_data = ptype_data
78+
self.prototype_data = prototype_data
7979

8080
def describe(self):
8181
"""Create description for model"""
@@ -93,21 +93,21 @@ def create_meta(
9393

9494
return meta
9595

96-
def construct_ptype(self):
96+
def construct_prototype(self):
9797
"""Create data prototype for a model
9898
9999
Parameters
100100
----------
101-
ptype_data : pd.DataFrame, np.ndarray, or None
102-
Training data to create ptype
101+
prototype_data : pd.DataFrame, np.ndarray, or None
102+
Training data to create prototype
103103
104104
Returns
105105
-------
106-
ptype : pd.DataFrame or None
106+
prototype : pd.DataFrame or None
107107
Zero-row DataFrame for storing data types
108108
"""
109-
ptype = vetiver_create_ptype(self.ptype_data)
110-
return ptype
109+
prototype = vetiver_create_ptype(self.prototype_data)
110+
return prototype
111111

112112
def handler_startup():
113113
"""Include required packages for prediction
@@ -117,7 +117,7 @@ def handler_startup():
117117
"""
118118
...
119119

120-
def handler_predict(self, input_data, check_ptype):
120+
def handler_predict(self, input_data, check_prototype):
121121
"""Generates method for /predict endpoint in VetiverAPI
122122
123123
The `handler_predict` function executes at each API call. Use this
@@ -128,8 +128,8 @@ def handler_predict(self, input_data, check_ptype):
128128
----------
129129
input_data:
130130
Data used to generate prediction
131-
check_ptype:
132-
If type should be checked against `ptype` or not
131+
check_prototype:
132+
If type should be checked against `prototype` or not
133133
134134
Returns
135135
-------
@@ -144,8 +144,8 @@ def handler_predict(self, input_data, check_ptype):
144144

145145

146146
@create_handler.register
147-
def _(model: base.BaseHandler, ptype_data):
148-
if model.ptype_data is None and ptype_data is not None:
149-
model.ptype_data = ptype_data
147+
def _(model: base.BaseHandler, prototype_data):
148+
if model.prototype_data is None and prototype_data is not None:
149+
model.prototype_data = prototype_data
150150

151151
return model

vetiver/handlers/torch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def handler_predict(self, input_data, check_prototype):
5858
if not torch_exists:
5959
raise ImportError("Cannot import `torch`.")
6060
if check_prototype:
61-
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
61+
input_data = np.array(input_data, dtype=np.array(self.prototype_data).dtype)
6262
prediction = self.model(torch.from_numpy(input_data))
6363

6464
# do not check ptype

vetiver/server.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,22 @@ def __init__(
4343
self,
4444
model: VetiverModel,
4545
check_prototype: bool = True,
46-
check_ptype: bool = None,
4746
app_factory=FastAPI,
47+
**kwargs,
4848
) -> None:
4949
self.model = model
5050
self.app_factory = app_factory
5151
self.app = app_factory()
5252

53-
if check_ptype is not None:
54-
check_prototype = check_ptype
53+
if "check_ptype" in kwargs:
54+
check_prototype = kwargs.pop("check_ptype")
5555
warn(
5656
"argument for checking input data prototype has changed to "
5757
"check_prototype, from check_ptype",
5858
DeprecationWarning,
5959
stacklevel=2,
6060
)
61+
6162
self.check_prototype = check_prototype
6263

6364
self._init_app()
@@ -151,9 +152,6 @@ def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
151152
if not endpoint_name:
152153
endpoint_name = endpoint_fx.__name__
153154

154-
if hasattr(self.model, "ptype"):
155-
self.model.prototype = self.model.ptype
156-
157155
if self.check_prototype is True:
158156

159157
@self.app.post(urljoin("/", endpoint_name), name=endpoint_name)

vetiver/tests/test_build_vetiver_model.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
model = get_mock_model().fit(X_df, y)
1414

1515

16-
def test_vetiver_model_array_ptype():
17-
# build VetiverModel, no ptype
16+
def test_vetiver_model_array_prototype():
1817
vt1 = vt.VetiverModel(
1918
model=model,
2019
prototype_data=X_array,
@@ -29,8 +28,7 @@ def test_vetiver_model_array_ptype():
2928
assert list(vt1.prototype.__fields__.values())[0].type_ == int
3029

3130

32-
def test_vetiver_model_df_ptype():
33-
# build VetiverModel, df ptype_data
31+
def test_vetiver_model_df_prototype():
3432
vt2 = vt.VetiverModel(
3533
model=model,
3634
prototype_data=X_df,
@@ -45,7 +43,7 @@ def test_vetiver_model_df_ptype():
4543
assert list(vt2.prototype.__fields__.values())[0].type_ == int
4644

4745

48-
def test_vetiver_model_dict_ptype():
46+
def test_vetiver_model_dict_prototype():
4947
dict_data = {"B": 0, "C": 0, "D": 0}
5048
vt3 = vt.VetiverModel(
5149
model=model,
@@ -61,8 +59,7 @@ def test_vetiver_model_dict_ptype():
6159
assert list(vt3.prototype.__fields__.values())[0].type_ == int
6260

6361

64-
def test_vetiver_model_no_ptype():
65-
# build VetiverModel, no ptype
62+
def test_vetiver_model_no_prototype():
6663
vt4 = vt.VetiverModel(
6764
model=model,
6865
prototype_data=None,
@@ -76,6 +73,21 @@ def test_vetiver_model_no_ptype():
7673
assert vt4.prototype is None
7774

7875

76+
def test_vetiver_model_use_ptype():
77+
vt5 = vt.VetiverModel(
78+
model=model,
79+
ptype_data=X_df,
80+
model_name="model",
81+
versioned=None,
82+
description=None,
83+
metadata=None,
84+
)
85+
86+
assert vt5.model == model
87+
assert isinstance(vt5.prototype.construct(), pydantic.BaseModel)
88+
assert list(vt5.prototype.__fields__.values())[0].type_ == int
89+
90+
7991
def test_vetiver_model_from_pin():
8092

8193
v = vt.VetiverModel(

vetiver/vetiver_model.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,25 +66,24 @@ def __init__(
6666
model,
6767
model_name: str,
6868
prototype_data=None,
69-
ptype_data=None,
7069
versioned=None,
7170
description: str = None,
7271
metadata: dict = None,
7372
**kwargs
7473
):
75-
if ptype_data is not None:
76-
prototype_data = ptype_data
74+
if "ptype_data" in kwargs:
75+
prototype_data = kwargs.pop("ptype_data")
7776
warn(
7877
"argument for saving input data prototype has changed to "
79-
"save_prototype, from save_ptype",
78+
"prototype_data, from ptype_data",
8079
DeprecationWarning,
8180
stacklevel=2,
8281
)
8382

8483
translator = create_handler(model, prototype_data)
8584

8685
self.model = translator.model
87-
self.prototype = translator.construct_ptype()
86+
self.prototype = translator.construct_prototype()
8887
self.model_name = model_name
8988
self.description = description if description else translator.describe()
9089
self.versioned = versioned

0 commit comments

Comments
 (0)