Skip to content

Commit 612d957

Browse files
authored
Merge pull request #138 from rstudio/prototype
update `ptype_data` to `prototype_data`
2 parents fea51a5 + 6403b57 commit 612d957

30 files changed

Lines changed: 280 additions & 184 deletions

.github/workflows/tests.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,12 @@ jobs:
7878
python-version: 3.8
7979
- name: Install dependencies
8080
run: |
81-
python -m pip install --upgrade pip
82-
python -m pip install -e ".[dev]"
83-
- name: Run Docker
81+
python -m pip install ".[dev]"
82+
python -m pip install --upgrade git+https://github.com/rstudio/vetiver-python@${{ github.sha }}
83+
- name: run Docker
8484
run: |
8585
python script/setup-docker/docker.py
86+
pip freeze > vetiver_requirements.txt
8687
docker build -t mock .
8788
docker run -d -v $PWD/pinsboard:/vetiver/pinsboard -p 8080:8080 mock
8889
sleep 5

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ from vetiver import mock, VetiverModel
4646
X, y = mock.get_mock_data()
4747
model = mock.get_mock_model().fit(X, y)
4848

49-
v = VetiverModel(model, save_ptype=True, ptype_data=X, model_name='mock_model')
49+
v = VetiverModel(model, model_name='mock_model', prototype_data=X)
5050
```
5151

5252
You can **version** and **share** your `VetiverModel()` by choosing a [pins](https://rstudio.github.io/pins-python/) "board" for it, including a local folder, RStudio Connect, Amazon S3, and more.
@@ -63,7 +63,7 @@ You can **deploy** your pinned `VetiverModel()` using `VetiverAPI()`, an extensi
6363

6464
```python
6565
from vetiver import VetiverAPI
66-
app = VetiverAPI(v, check_ptype = True)
66+
app = VetiverAPI(v, check_prototype = True)
6767
```
6868

6969
To start a server using this object, use `app.run(port = 8080)` or your port of choice.

vetiver/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Change to import.metadata when minimum python>=3.8
33
from importlib_metadata import version as _version
44

5-
from .ptype import * # noqa
5+
from .prototype import * # noqa
66
from .vetiver_model import VetiverModel # noqa
77
from .server import VetiverAPI, vetiver_endpoint, predict # noqa
88
from .mock import get_mock_data, get_mock_model # noqa
@@ -19,6 +19,7 @@
1919
from .rsconnect import deploy_rsconnect # noqa
2020
from .monitor import compute_metrics, pin_metrics, plot_metrics, _rolling_df # noqa
2121
from .model_card import model_card # noqa
22+
from .types import create_prototype, Prototype # noqa
2223

2324
__author__ = "Isabel Zimmerman <isabel.zimmerman@rstudio.com>"
2425
__all__ = []

vetiver/handlers/base.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from functools import singledispatch
33
from contextlib import suppress
44

5-
from ..ptype import vetiver_create_ptype
5+
from ..prototype import vetiver_create_prototype
66
from ..meta import _model_meta
77

88

@@ -20,14 +20,14 @@ def __init__(
2020

2121

2222
@singledispatch
23-
def create_handler(model, ptype_data):
23+
def create_handler(model, prototype_data):
2424
"""check for model type to handle prediction
2525
2626
Parameters
2727
----------
2828
model: object
2929
Description of parameter `x`.
30-
ptype_data : object
30+
prototype_data : object
3131
An object with information (data) whose layout is to be determined.
3232
3333
Returns
@@ -63,7 +63,7 @@ class BaseHandler:
6363
----------
6464
model :
6565
a trained model
66-
ptype_data :
66+
prototype_data :
6767
An object with information (data) whose layout is to be determined.
6868
"""
6969

@@ -73,9 +73,9 @@ def __init_subclass__(cls, **kwargs):
7373
with suppress(AttributeError, NameError):
7474
create_handler.register(cls.model_class(), cls)
7575

76-
def __init__(self, model, ptype_data):
76+
def __init__(self, model, prototype_data):
7777
self.model = model
78-
self.ptype_data = ptype_data
78+
self.prototype_data = prototype_data
7979

8080
def describe(self):
8181
"""Create description for model"""
@@ -93,21 +93,21 @@ def create_meta(
9393

9494
return meta
9595

96-
def construct_ptype(self):
96+
def construct_prototype(self):
9797
"""Create data prototype for a model
9898
9999
Parameters
100100
----------
101-
ptype_data : pd.DataFrame, np.ndarray, or None
102-
Training data to create ptype
101+
prototype_data : pd.DataFrame, np.ndarray, or None
102+
Training data to create prototype
103103
104104
Returns
105105
-------
106-
ptype : pd.DataFrame or None
106+
prototype : pd.DataFrame or None
107107
Zero-row DataFrame for storing data types
108108
"""
109-
ptype = vetiver_create_ptype(self.ptype_data)
110-
return ptype
109+
prototype = vetiver_create_prototype(self.prototype_data)
110+
return prototype
111111

112112
def handler_startup():
113113
"""Include required packages for prediction
@@ -117,7 +117,7 @@ def handler_startup():
117117
"""
118118
...
119119

120-
def handler_predict(self, input_data, check_ptype):
120+
def handler_predict(self, input_data, check_prototype):
121121
"""Generates method for /predict endpoint in VetiverAPI
122122
123123
The `handler_predict` function executes at each API call. Use this
@@ -128,8 +128,8 @@ def handler_predict(self, input_data, check_ptype):
128128
----------
129129
input_data:
130130
Data used to generate prediction
131-
check_ptype:
132-
If type should be checked against `ptype` or not
131+
check_prototype:
132+
If type should be checked against `prototype` or not
133133
134134
Returns
135135
-------
@@ -144,8 +144,8 @@ def handler_predict(self, input_data, check_ptype):
144144

145145

146146
@create_handler.register
147-
def _(model: base.BaseHandler, ptype_data):
148-
if model.ptype_data is None and ptype_data is not None:
149-
model.ptype_data = ptype_data
147+
def _(model: base.BaseHandler, prototype_data):
148+
if model.prototype_data is None and prototype_data is not None:
149+
model.prototype_data = prototype_data
150150

151151
return model

vetiver/handlers/sklearn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def create_meta(
3333

3434
return meta
3535

36-
def handler_predict(self, input_data, check_ptype):
36+
def handler_predict(self, input_data, check_prototype):
3737
"""Generates method for /predict endpoint in VetiverAPI
3838
3939
The `handler_predict` function executes at each API call. Use this
@@ -51,7 +51,7 @@ def handler_predict(self, input_data, check_ptype):
5151
Prediction from model
5252
"""
5353

54-
if not check_ptype or isinstance(input_data, pd.DataFrame):
54+
if not check_prototype or isinstance(input_data, pd.DataFrame):
5555
prediction = self.model.predict(input_data)
5656
else:
5757
prediction = self.model.predict([input_data])

vetiver/handlers/statsmodels.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@ class StatsmodelsHandler(BaseHandler):
2121

2222
model_class = staticmethod(lambda: statsmodels.base.wrapper.ResultsWrapper)
2323

24-
def __init__(self, model, ptype_data):
25-
super().__init__(model, ptype_data)
26-
2724
def describe(self):
2825
"""Create description for statsmodels model"""
2926
desc = f"Statsmodels {self.model.__class__} model."
@@ -41,7 +38,7 @@ def create_meta(
4138

4239
return meta
4340

44-
def handler_predict(self, input_data, check_ptype):
41+
def handler_predict(self, input_data, check_prototype):
4542
"""Generates method for /predict endpoint in VetiverAPI
4643
4744
The `handler_predict` function executes at each API call. Use this

vetiver/handlers/torch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def create_meta(
3838

3939
return meta
4040

41-
def handler_predict(self, input_data, check_ptype):
41+
def handler_predict(self, input_data, check_prototype):
4242
"""Generates method for /predict endpoint in VetiverAPI
4343
4444
The `handler_predict` function executes at each API call. Use this
@@ -57,8 +57,8 @@ def handler_predict(self, input_data, check_ptype):
5757
"""
5858
if not torch_exists:
5959
raise ImportError("Cannot import `torch`.")
60-
if check_ptype:
61-
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
60+
if check_prototype:
61+
input_data = np.array(input_data, dtype=np.array(self.prototype_data).dtype)
6262
prediction = self.model(torch.from_numpy(input_data))
6363

6464
# do not check ptype

vetiver/handlers/xgboost.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def create_meta(
3838

3939
return meta
4040

41-
def handler_predict(self, input_data, check_ptype):
41+
def handler_predict(self, input_data, check_prototype):
4242
"""Generates method for /predict endpoint in VetiverAPI
4343
4444
The `handler_predict` function executes at each API call. Use this

vetiver/pin_read_write.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool = True):
3838
>>> model_board = board_temp(versioned = True, allow_pickle_read = True)
3939
>>> X, y = vetiver.get_mock_data()
4040
>>> model = vetiver.get_mock_model().fit(X, y)
41-
>>> v = vetiver.VetiverModel(model = model, model_name = "my_model", ptype_data = X)
41+
>>> v = vetiver.VetiverModel(model, "my_model", prototype_data = X)
4242
>>> vetiver.vetiver_pin_write(model_board, v)
4343
"""
4444
if not board.allow_pickle_read:
@@ -51,14 +51,18 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool = True):
5151
"with vetiver.model_card()",
5252
)
5353

54+
# convert older model's ptype to prototype
55+
if hasattr(model, "ptype"):
56+
model.prototype = model.ptype
57+
5458
board.pin_write(
5559
model.model,
5660
name=model.model_name,
5761
type="joblib",
5862
description=model.description,
5963
metadata={
6064
"required_pkgs": model.metadata.get("required_pkgs"),
61-
"ptype": None if model.ptype is None else model.ptype().json(),
65+
"prototype": None if model.prototype is None else model.prototype().json(),
6266
},
6367
versioned=versioned,
6468
)

0 commit comments

Comments
 (0)