Skip to content

Commit 35a709d

Browse files
committed
more scalable language
1 parent 2edf364 commit 35a709d

2 files changed

Lines changed: 13 additions & 11 deletions

File tree

vetiver/handlers/base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class InvalidModelError(Exception):
1313

1414
def __init__(
1515
self,
16-
message="The `model` argument must be a scikit-learn or torch model.",
16+
message="The `model` argument must be a supported or custom type.",
1717
):
1818
self.message = message
1919
super().__init__(self.message)
@@ -47,9 +47,9 @@ def create_handler(model, ptype_data):
4747
"""
4848

4949
raise InvalidModelError(
50-
"Model must be an sklearn or torch model, or a \
51-
custom handler must be used. See the docs for more info on custom handlers. \
52-
https://rstudio.github.io/vetiver-python/advancedusage/custom_handler.html"
50+
"Model must be a supported model type, or a "
51+
"custom handler must be used. See the docs for more info on custom handlers and"
52+
"supported types https://rstudio.github.io/vetiver-python/"
5353
)
5454

5555

@@ -88,13 +88,13 @@ def create_meta(
8888
url: str = None,
8989
required_pkgs: list = [],
9090
):
91-
"""Create metadata for sklearn model"""
91+
"""Create metadata for a model"""
9292
meta = _model_meta(user, version, url, required_pkgs)
9393

9494
return meta
9595

9696
def construct_ptype(self):
97-
"""Create data prototype for torch model
97+
"""Create data prototype for a model
9898
9999
Parameters
100100
----------

vetiver/handlers/statsmodels.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def describe(self):
2929
desc = f"Statsmodels {self.model.__class__} model."
3030
return desc
3131

32-
def construct_meta(
32+
def create_meta(
3333
user: list = None,
3434
version: str = None,
3535
url: str = None,
@@ -58,10 +58,12 @@ def handler_predict(self, input_data, check_ptype):
5858
prediction
5959
Prediction from model
6060
"""
61-
62-
if isinstance(input_data, (list, pd.DataFrame)):
63-
prediction = self.model.predict(input_data)
61+
if sm_exists:
62+
if isinstance(input_data, (list, pd.DataFrame)):
63+
prediction = self.model.predict(input_data)
64+
else:
65+
prediction = self.model.predict([input_data])
6466
else:
65-
prediction = self.model.predict([input_data])
67+
raise ImportError("Cannot import `statsmodels`")
6668

6769
return prediction

0 commit comments

Comments
 (0)