Skip to content

Commit 2c83fa4

Browse files
committed
refactor metadata w prototypes
1 parent 748624e commit 2c83fa4

6 files changed

Lines changed: 18 additions & 12 deletions

File tree

docs/source/advancedusage/custom_handler.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ class CustomHandler(BaseHandler):
1212
super().__init__(model, ptype_data)
1313

1414
model_type = staticmethod(lambda: newmodeltype)
15-
pkg = sklearn # modeling package
1615
pip_name = "scikit-learn" # pkg name on pip, used for tracking pkg versions
1716

1817
def handler_predict(self, input_data, check_ptype: bool):

script/setup-docker/docker.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,4 @@
1313

1414
vetiver.vetiver_pin_write(board, v)
1515

16-
1716
vetiver.prepare_docker(board, "mymodel")

vetiver/handlers/sklearn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class SKLearnHandler(BaseHandler):
1616
model_class = staticmethod(lambda: sklearn.base.BaseEstimator)
1717
pip_name = "scikit-learn"
1818

19-
def handler_predict(self, input_data, check_ptype):
19+
def handler_predict(self, input_data, check_prototype):
2020
"""Generates method for /predict endpoint in VetiverAPI
2121
2222
The `handler_predict` function executes at each API call. Use this

vetiver/pin_read_write.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool = True):
5555
# convert older model's ptype to prototype
5656
if hasattr(model, "ptype"):
5757
model.prototype = model.ptype
58+
delattr(model, "ptype")
5859
# metadata is dict
5960
if isinstance(model.metadata, dict):
6061
model.metadata = VetiverMeta.from_dict(model.metadata)
@@ -68,7 +69,7 @@ def vetiver_pin_write(board, model: VetiverModel, versioned: bool = True):
6869
"user": model.metadata.user,
6970
"vetiver_meta": {
7071
"required_pkgs": model.metadata.required_pkgs,
71-
"ptype": None if model.prototype is None else model.prototype().json(),
72+
"prototype": None if not model.prototype else model.prototype().json(),
7273
},
7374
},
7475
versioned=versioned,

vetiver/tests/test_build_vetiver_model.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def test_vetiver_model_basemodel_prototype():
7575
prototype_data=m,
7676
model_name="model",
7777
versioned=False,
78-
description=None
78+
description=None,
7979
)
8080

8181
assert vt4.model == model
@@ -107,7 +107,7 @@ def test_vetiver_model_use_ptype():
107107
)
108108

109109
assert vt5.model == model
110-
assert vt5.ptype is None
110+
assert vt5.prototype is None
111111
assert vt5.metadata == VetiverMeta(
112112
user={"test": 123},
113113
version=None,
@@ -134,8 +134,8 @@ def test_vetiver_model_from_pin():
134134
assert isinstance(v2, vt.VetiverModel)
135135
assert isinstance(v2.model, sklearn.base.BaseEstimator)
136136
assert isinstance(v2.prototype.construct(), pydantic.BaseModel)
137-
assert v2.metadata.get("user") == {"test": 123}
138-
assert v2.metadata.get("version") is not None
137+
assert v2.metadata.user == {"test": 123}
138+
assert v2.metadata.version is not None
139139
assert v2.metadata.required_pkgs == ["scikit-learn"]
140-
140+
141141
board.pin_delete("model")

vetiver/vetiver_model.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,18 @@ def from_pin(cls, board, name: str, version: str = None):
100100
meta = board.pin_meta(name, version)
101101

102102
if "vetiver_meta" in meta.user:
103-
ptype = meta.user.get("vetiver_meta").get("prototype")
104-
required_pkgs = meta.user.get("vetiver_meta").get("required_pkgs")
103+
get_prototype = meta.user.get("vetiver_meta").get("prototype", None)
104+
required_pkgs = meta.user.get("vetiver_meta").get("required_pkgs", None)
105105
meta.user.pop("vetiver_meta")
106106
else:
107-
ptype = meta.user.get("ptype", None)
107+
# ptype = meta.user.get("ptype", None)
108+
109+
get_prototype = meta.user.get("ptype")
110+
# elif meta.user.get("prototype"):
111+
# get_prototype = meta.user.get("prototype")
112+
# else:
113+
# get_prototype = None
114+
108115
required_pkgs = meta.user.get("required_pkgs")
109116

110117
return cls(

0 commit comments

Comments
 (0)