Skip to content

Commit b8902b6

Browse files
has2k1isabelizimm
authored andcommitted
Remove save_ptype as a parameter and as a property
save_ptype is still used in the test helper functions, but it is not part of the API anymore.
1 parent e0878f9 commit b8902b6

16 files changed

Lines changed: 37 additions & 82 deletions

examples/coffeeratings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
lr_fit = LinearRegression().fit(X_train, y_train)
1818

1919
# create vetiver model
20-
v = vetiver.VetiverModel(lr_fit, save_ptype = True, ptype_data=X_train, model_name = "v")
20+
v = vetiver.VetiverModel(lr_fit, ptype_data=X_train, model_name = "v")
2121

2222
# version model via pin
2323
from pins import board_folder

vetiver/handlers/_interface.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
except ImportError:
88
torch_exists = False
99

10-
def create_translator(model, ptype_data, save_ptype):
10+
def create_translator(model, ptype_data):
1111
"""check for model type to handle prediction
1212
1313
Parameters
@@ -22,10 +22,10 @@ def create_translator(model, ptype_data, save_ptype):
2222
"""
2323
if torch_exists:
2424
if isinstance(model, torch.nn.Module):
25-
return pytorch_vt.TorchHandler(model, ptype_data, save_ptype)
25+
return pytorch_vt.TorchHandler(model, ptype_data)
2626

2727
if isinstance(model, sklearn.base.BaseEstimator):
28-
return sklearn_vt.SKLearnHandler(model, ptype_data, save_ptype)
28+
return sklearn_vt.SKLearnHandler(model, ptype_data)
2929

3030
else:
3131
raise NotImplementedError

vetiver/handlers/pytorch_vt.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,9 @@ class TorchHandler:
1717
model : nn.Module
1818
a trained torch model
1919
"""
20-
def __init__(self, model, ptype_data, save_ptype):
20+
def __init__(self, model, ptype_data):
2121
self.model = model
2222
self.ptype_data = ptype_data
23-
self.save_ptype = save_ptype
2423

2524
def create_description(self):
2625
"""Create description for torch model
@@ -48,14 +47,13 @@ def ptype(self):
4847
----------
4948
ptype_data : pd.DataFrame, np.ndarray, or None
5049
Training data to create ptype
51-
save_ptype : bool
5250
5351
Returns
5452
-------
5553
ptype : pd.DataFrame or None
5654
Zero-row DataFrame for storing data types
5755
"""
58-
ptype = vetiver_create_ptype(self.ptype_data, self.save_ptype)
56+
ptype = vetiver_create_ptype(self.ptype_data)
5957

6058
return ptype
6159

vetiver/handlers/sklearn_vt.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,9 @@ class SKLearnHandler:
1212
a trained sklearn model
1313
"""
1414

15-
def __init__(self, model, ptype_data, save_ptype):
15+
def __init__(self, model, ptype_data):
1616
self.model = model
1717
self.ptype_data = ptype_data
18-
self.save_ptype = save_ptype
1918

2019
def create_description(self):
2120
"""Create description for sklearn model
@@ -42,14 +41,13 @@ def ptype(self):
4241
----------
4342
ptype_data : pd.DataFrame, np.ndarray, or None
4443
Training data to create ptype
45-
save_ptype : bool
4644
4745
Returns
4846
-------
4947
ptype : pd.DataFrame or None
5048
Zero-row DataFrame for storing data types
5149
"""
52-
ptype = vetiver_create_ptype(self.ptype_data, self.save_ptype)
50+
ptype = vetiver_create_ptype(self.ptype_data)
5351
return ptype
5452

5553
def handler_startup():

vetiver/pin_read_write.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool=True):
2929
type = "joblib",
3030
description = model.description,
3131
metadata = {"required_pkgs": model.metadata.get("required_pkgs"),
32-
"save_ptype": model.save_ptype,
3332
"ptype": None if model.ptype == None else model.ptype().json()},
3433
versioned=versioned
3534
)
@@ -79,7 +78,6 @@ def vetiver_pin_read(board, name: str, version: str = None) -> VetiverModel:
7978
url = meta.user.get("url"), # None all the time, besides Connect
8079
required_pkgs = meta.user.get("required_pkgs")
8180
),
82-
save_ptype=meta.user.get("save_ptype"),
8381
ptype_data = json.loads(meta.user.get("ptype")) if meta.user.get("ptype") else None,
8482
versioned = True
8583
)

vetiver/ptype.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,12 @@ def __init__(
2222

2323
class InvalidPTypeError(Exception):
2424
"""
25-
Throw an error if `save_ptype` is not
26-
True, False, or data.frame
25+
Throw an error if ptype cannot be recognised
2726
"""
2827

2928
def __init__(
3029
self,
31-
message="The `ptype_data` argument must be a pandas.DataFrame, a pydantic BaseModel, np.ndarray, or `save_ptype` must be FALSE.",
30+
message="`ptype_data` must be a pd.DataFrame, a pydantic BaseModel or np.ndarray",
3231
):
3332
self.message = message
3433
super().__init__(self.message)
@@ -55,7 +54,7 @@ def _(data: {_data_type}):
5554
"""
5655

5756
@singledispatch
58-
def vetiver_create_ptype(data, save_ptype):
57+
def vetiver_create_ptype(data):
5958
"""Create zero row structure to save data types
6059
6160
Parameters
@@ -69,20 +68,20 @@ def vetiver_create_ptype(data, save_ptype):
6968
Data prototype
7069
7170
"""
72-
msg = CREATE_PTYPE_TPL.format(_data_type=type(data))
73-
msg = ""
74-
raise InvalidPTypeError(message=msg)
71+
raise InvalidPTypeError(
72+
message=CREATE_PTYPE_TPL.format(_data_type=type(data))
73+
)
7574

7675

7776
@vetiver_create_ptype.register
78-
def _vetiver_create_ptype(data: pd.DataFrame, save_ptype):
77+
def _vetiver_create_ptype(data: pd.DataFrame):
7978
dict_data = data.iloc[1, :].to_dict()
8079
ptype = create_model("ptype", **dict_data)
8180
return ptype
8281

8382

8483
@vetiver_create_ptype.register
85-
def _vetiver_create_ptype(data: np.ndarray, save_ptype):
84+
def _vetiver_create_ptype(data: np.ndarray):
8685
dict_data = dict(enumerate(data[1], 0))
8786
# pydantic requires strings as indicies
8887
dict_data = {f"{key}": value.item() for key, value in dict_data.items()}
@@ -91,15 +90,15 @@ def _vetiver_create_ptype(data: np.ndarray, save_ptype):
9190

9291

9392
@vetiver_create_ptype.register
94-
def _vetiver_create_ptype(data: dict, save_ptype):
93+
def _vetiver_create_ptype(data: dict):
9594
return create_model("ptype", **data)
9695

9796

9897
@vetiver_create_ptype.register
99-
def _vetiver_create_ptype(data: BaseModel, save_ptype):
98+
def _vetiver_create_ptype(data: BaseModel):
10099
return data
101100

102101

103102
@vetiver_create_ptype.register
104-
def _vetiver_create_ptype(data: NoneType, save_ptype):
103+
def _vetiver_create_ptype(data: NoneType):
105104
return None

vetiver/tests/test_add_endpoint.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ def _start_application(check_ptype):
99
model = mock.get_mock_model().fit(X, y)
1010
v = VetiverModel(
1111
model=model,
12-
save_ptype=True,
1312
ptype_data=X,
1413
model_name="my_model",
1514
versioned=None,
@@ -43,4 +42,4 @@ def test_endpoint_adds_no_ptype():
4342
data = [0,0,0]
4443
response = client.post("/sum/", json=data)
4544
assert response.status_code == 200, response.text
46-
assert response.json() == {"sum": 0}, response.json()
45+
assert response.json() == {"sum": 0}, response.json()

vetiver/tests/test_build_api.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ def _build_v():
99
model = mock.get_mock_model().fit(X, y)
1010
v = VetiverModel(
1111
model=model,
12-
save_ptype=True,
1312
ptype_data=X,
1413
model_name="my_model",
1514
versioned=None,

vetiver/tests/test_build_vetiver_model.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ def test_vetiver_model_array_ptype():
1616
# build VetiverModel, no ptype
1717
vt1 = VetiverModel(
1818
model=model,
19-
save_ptype=True,
2019
ptype_data=X_array,
2120
model_name="iris",
2221
versioned=None,
@@ -32,7 +31,6 @@ def test_vetiver_model_df_ptype():
3231
# build VetiverModel, df ptype_data
3332
vt2 = VetiverModel(
3433
model=model,
35-
save_ptype=True,
3634
ptype_data=X_df,
3735
model_name="iris",
3836
versioned=None,
@@ -48,7 +46,6 @@ def test_vetiver_model_dict_ptype():
4846
dict_data = {"B": 0, "C": 0, "D": 0}
4947
vt3 = VetiverModel(
5048
model=model,
51-
save_ptype=True,
5249
ptype_data=dict_data,
5350
model_name="iris",
5451
versioned=None,
@@ -64,8 +61,7 @@ def test_vetiver_model_no_ptype():
6461
# build VetiverModel, no ptype
6562
vt4 = VetiverModel(
6663
model=model,
67-
save_ptype=False,
68-
ptype_data=X_df,
64+
ptype_data=None,
6965
model_name="iris",
7066
versioned=None,
7167
description=None,
@@ -74,17 +70,3 @@ def test_vetiver_model_no_ptype():
7470

7571
assert vt4.model == model
7672
assert vt4.ptype == None
77-
78-
79-
def test_vetiver_model_error():
80-
with pytest.raises(AttributeError):
81-
VetiverModel(
82-
model=model,
83-
save_ptype=True,
84-
ptype_data=None,
85-
model_name="iris",
86-
versioned=None,
87-
description=None,
88-
metadata=None,
89-
)
90-

vetiver/tests/test_no_handler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ def test_not_implemented_error():
88
with pytest.raises(NotImplementedError):
99
VetiverModel(
1010
model=y,
11-
save_ptype=True,
1211
ptype_data=X,
1312
model_name="my_model",
1413
versioned=None,

0 commit comments

Comments
 (0)