Skip to content

Commit 4e69722

Browse files
committed
updates from review
1 parent 16a9228 commit 4e69722

4 files changed

Lines changed: 17 additions & 14 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ from vetiver import mock, VetiverModel
4343
X, y = mock.get_mock_data()
4444
model = mock.get_mock_model().fit(X, y)
4545

46-
v = VetiverModel(model, save_ptype=True, ptype_data=X, model_name='mock_model')
46+
v = VetiverModel(model, model_name='mock_model', prototype_data=X)
4747
```
4848

4949
You can **version** and **share** your `VetiverModel()` by choosing a [pins](https://rstudio.github.io/pins-python/) "board" for it, including a local folder, RStudio Connect, Amazon S3, and more.
@@ -60,7 +60,7 @@ You can **deploy** your pinned `VetiverModel()` using `VetiverAPI()`, an extensi
6060

6161
```python
6262
from vetiver import VetiverAPI
63-
app = VetiverAPI(v, check_ptype = True)
63+
app = VetiverAPI(v, check_prototype = True)
6464
```
6565

6666
To start a server using this object, use `app.run(port = 8080)` or your port of choice.

vetiver/pin_read_write.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,7 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool = True):
6262
description=model.description,
6363
metadata={
6464
"required_pkgs": model.metadata.get("required_pkgs"),
65-
"ptype": None
66-
if model.prototype is None
67-
else model.prototype().json(), # ptype_change
65+
"prototype": None if model.prototype is None else model.prototype().json(),
6866
},
6967
versioned=versioned,
7068
)

vetiver/tests/test_sklearn.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@
44
import pytest
55

66

7-
def _start_application(save_ptype: bool = True):
7+
def _start_application(save_prototype: bool = True):
88
X, y = mock.get_mock_data()
99
model = mock.get_mock_model().fit(X, y)
1010
v = VetiverModel(
1111
model=model,
12-
prototype_data=X if save_ptype else None,
12+
prototype_data=X if save_prototype else None,
1313
model_name="my_model",
1414
versioned=None,
1515
description="A regression model for testing purposes",
1616
)
1717

18-
app = VetiverAPI(v, check_prototype=save_ptype)
18+
app = VetiverAPI(v, check_prototype=save_prototype)
1919

2020
return app
2121

@@ -48,7 +48,7 @@ def test_predict_endpoint_ptype_error():
4848

4949
def test_predict_endpoint_no_ptype():
5050
np.random.seed(500)
51-
client = TestClient(_start_application(save_ptype=False).app)
51+
client = TestClient(_start_application(save_prototype=False).app)
5252
data = [{"B": 0, "C": 0, "D": 0}]
5353
response = client.post("/predict", json=data)
5454
assert response.status_code == 200, response.text
@@ -57,7 +57,7 @@ def test_predict_endpoint_no_ptype():
5757

5858
def test_predict_endpoint_no_ptype_batch():
5959
np.random.seed(500)
60-
client = TestClient(_start_application(save_ptype=False).app)
60+
client = TestClient(_start_application(save_prototype=False).app)
6161
data = [[0, 0, 0], [0, 0, 0]]
6262
response = client.post("/predict", json=data)
6363
assert response.status_code == 200, response.text
@@ -66,7 +66,7 @@ def test_predict_endpoint_no_ptype_batch():
6666

6767
def test_predict_endpoint_no_ptype_error():
6868
np.random.seed(500)
69-
client = TestClient(_start_application(save_ptype=False).app)
69+
client = TestClient(_start_application(save_prototype=False).app)
7070
data = {"hell0", 9, 32.0}
7171
with pytest.raises(TypeError):
7272
client.post("/predict", json=data)

vetiver/vetiver_model.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,13 @@ def from_pin(cls, board, name: str, version: str = None):
101101
model = board.pin_read(name, version)
102102
meta = board.pin_meta(name, version)
103103

104+
if meta.user.get("ptype"):
105+
get_prototype = meta.user.get("ptype")
106+
elif meta.user.get("prototype"):
107+
get_prototype = meta.user.get("prototype")
108+
else:
109+
get_prototype = None
110+
104111
return cls(
105112
model=model,
106113
model_name=name,
@@ -111,8 +118,6 @@ def from_pin(cls, board, name: str, version: str = None):
111118
url=meta.local.get("url"), # None all the time, besides Connect
112119
required_pkgs=meta.user.get("required_pkgs"),
113120
),
114-
prototype_data=json.loads(meta.user.get("ptype"))
115-
if meta.user.get("ptype")
116-
else None,
121+
prototype_data=json.loads(get_prototype) if get_prototype else None,
117122
versioned=True,
118123
)

0 commit comments

Comments
 (0)