Skip to content

Commit 3d2d800

Browse files
committed
updates from review
1 parent abb1e9f commit 3d2d800

5 files changed

Lines changed: 43 additions & 49 deletions

File tree

docs/source/index.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ You can use vetiver with:
55

66
- `scikit-learn <https://scikit-learn.org/stable/>`_
77
- `pytorch <https://pytorch.org/>`_
8+
- `statsmodels <https://www.statsmodels.org/>`_
9+
- `xgboost <https://xgboost.readthedocs.io/>`_
810

911
You can install the released version of vetiver from `PyPI <https://pypi.org/project/vetiver/>`_:
1012

@@ -65,6 +67,18 @@ Monitor
6567
~pin_metrics
6668
~plot_metrics
6769

70+
Model Handlers
71+
==================
72+
.. autosummary::
73+
:toctree: reference/
74+
:caption: Monitor
75+
76+
~BaseHandler
77+
~SKLearnHandler
78+
~TorchHandler
79+
~StatsmodelsHandler
80+
~XGBoostHandler
81+
6882
Advanced Usage
6983
==================
7084
.. toctree::

vetiver/handlers/sklearn.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,12 @@ class SKLearnHandler(BaseHandler):
1616

1717
model_class = staticmethod(lambda: sklearn.base.BaseEstimator)
1818

19-
def __init__(self, model, ptype_data):
20-
super().__init__(model, ptype_data)
21-
2219
def describe(self):
2320
"""Create description for sklearn model"""
2421
desc = f"Scikit-learn {self.model.__class__} model"
2522
return desc
2623

27-
def construct_meta(
24+
def create_meta(
2825
user: list = None,
2926
version: str = None,
3027
url: str = None,
@@ -54,17 +51,9 @@ def handler_predict(self, input_data, check_ptype):
5451
Prediction from model
5552
"""
5653

57-
if check_ptype:
58-
if isinstance(input_data, pd.DataFrame):
59-
prediction = self.model.predict(input_data)
60-
else:
61-
prediction = self.model.predict([input_data])
62-
63-
# do not check ptype
64-
else:
65-
if not isinstance(input_data, list):
66-
input_data = [input_data.split(",")] # user delimiter ?
67-
54+
if not check_ptype or isinstance(input_data, pd.DataFrame):
6855
prediction = self.model.predict(input_data)
56+
else:
57+
prediction = self.model.predict([input_data])
6958

7059
return prediction

vetiver/handlers/statsmodels.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
class StatsmodelsHandler(BaseHandler):
1414
"""Handler class for creating VetiverModels with statsmodels.
1515
16-
Parameters
16+
Methods
1717
----------
1818
model : statsmodels
1919
a trained and fit statsmodels model
@@ -58,12 +58,12 @@ def handler_predict(self, input_data, check_ptype):
5858
prediction
5959
Prediction from model
6060
"""
61-
if sm_exists:
62-
if isinstance(input_data, (list, pd.DataFrame)):
63-
prediction = self.model.predict(input_data)
64-
else:
65-
prediction = self.model.predict([input_data])
66-
else:
61+
if not sm_exists:
6762
raise ImportError("Cannot import `statsmodels`")
6863

64+
if isinstance(input_data, (list, pd.DataFrame)):
65+
prediction = self.model.predict(input_data)
66+
else:
67+
prediction = self.model.predict([input_data])
68+
6969
return prediction

vetiver/handlers/torch.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@ class TorchHandler(BaseHandler):
2121

2222
model_class = staticmethod(lambda: torch.nn.Module)
2323

24-
def __init__(self, model, ptype_data):
25-
super().__init__(model, ptype_data)
26-
2724
def describe(self):
2825
"""Create description for torch model"""
2926
desc = f"Pytorch model of type {type(self.model)}"
@@ -58,17 +55,15 @@ def handler_predict(self, input_data, check_ptype):
5855
prediction
5956
Prediction from model
6057
"""
61-
if torch_exists:
62-
if check_ptype:
63-
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
64-
prediction = self.model(torch.from_numpy(input_data))
65-
66-
# do not check ptype
67-
else:
68-
input_data = torch.tensor(input_data)
69-
prediction = self.model(input_data)
58+
if not torch_exists:
59+
raise ImportError("Cannot import `torch`.")
60+
if check_ptype:
61+
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
62+
prediction = self.model(torch.from_numpy(input_data))
7063

64+
# do not check ptype
7165
else:
72-
raise ImportError("Cannot import `torch`.")
66+
input_data = torch.tensor(input_data)
67+
prediction = self.model(input_data)
7368

7469
return prediction

vetiver/handlers/xgboost.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@ class XGBoostHandler(BaseHandler):
2121

2222
model_class = staticmethod(lambda: xgboost.Booster)
2323

24-
def __init__(self, model, ptype_data):
25-
super().__init__(model, ptype_data)
26-
2724
def describe(self):
2825
"""Create description for xgboost model"""
2926
desc = f"XGBoost {self.model.__class__} model."
@@ -59,17 +56,16 @@ def handler_predict(self, input_data, check_ptype):
5956
Prediction from model
6057
"""
6158

62-
if xgb_exists:
63-
if not isinstance(input_data, xgboost.DMatrix):
64-
if not isinstance(input_data, pd.DataFrame):
65-
try:
66-
input_data = pd.DataFrame(input_data)
67-
except ValueError:
68-
raise (f"Expected a dict or DataFrame, got {type(input_data)}")
69-
input_data = xgboost.DMatrix(input_data)
70-
71-
prediction = self.model.predict(input_data)
72-
else:
59+
if not xgb_exists:
7360
raise ImportError("Cannot import `xgboost`")
7461

62+
if not isinstance(input_data, pd.DataFrame):
63+
try:
64+
input_data = pd.DataFrame(input_data)
65+
except ValueError:
66+
raise (f"Expected a dict or DataFrame, got {type(input_data)}")
67+
input_data = xgboost.DMatrix(input_data)
68+
69+
prediction = self.model.predict(input_data)
70+
7571
return prediction

0 commit comments

Comments
 (0)