Skip to content

Commit 22405bd

Browse files
authored
Merge pull request #100 from rstudio/statsmodels
implement statsmodels handler
2 parents afe610e + 2cd6976 commit 22405bd

7 files changed

Lines changed: 148 additions & 8 deletions

File tree

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
- name: Install dependencies
3333
run: |
3434
python -m pip install --upgrade pip
35-
python -m pip install -e .[dev,torch]
35+
python -m pip install -e .[dev,torch,statsmodels]
3636
- name: Run Tests
3737
run: |
3838
pytest -m 'not rsc_test' --cov --cov-report xml

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,6 @@ dev =
4848

4949
torch =
5050
torch
51+
52+
statsmodels =
53+
statsmodels

vetiver/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .handlers.base import BaseHandler, create_handler, InvalidModelError # noqa
1515
from .handlers.sklearn import SKLearnHandler # noqa
1616
from .handlers.torch import TorchHandler # noqa
17+
from .handlers.statsmodels import StatsmodelsHandler # noqa
1718
from .rsconnect import deploy_rsconnect # noqa
1819
from .monitor import compute_metrics, pin_metrics, plot_metrics, _rolling_df # noqa
1920

vetiver/handlers/base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class InvalidModelError(Exception):
1313

1414
def __init__(
1515
self,
16-
message="The `model` argument must be a scikit-learn or torch model.",
16+
message="The `model` argument must be a supported or custom type.",
1717
):
1818
self.message = message
1919
super().__init__(self.message)
@@ -47,9 +47,9 @@ def create_handler(model, ptype_data):
4747
"""
4848

4949
raise InvalidModelError(
50-
"Model must be an sklearn or torch model, or a \
51-
custom handler must be used. See the docs for more info on custom handlers. \
52-
https://rstudio.github.io/vetiver-python/advancedusage/custom_handler.html"
50+
"Model must be a supported model type, or a "
51+
"custom handler must be used. See the docs for more info on custom handlers and "
52+
"supported types https://rstudio.github.io/vetiver-python/"
5353
)
5454

5555

@@ -88,13 +88,13 @@ def create_meta(
8888
url: str = None,
8989
required_pkgs: list = [],
9090
):
91-
"""Create metadata for sklearn model"""
91+
"""Create metadata for a model"""
9292
meta = _model_meta(user, version, url, required_pkgs)
9393

9494
return meta
9595

9696
def construct_ptype(self):
97-
"""Create data prototype for torch model
97+
"""Create data prototype for a model
9898
9999
Parameters
100100
----------

vetiver/handlers/statsmodels.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import pandas as pd
2+
3+
from ..meta import _model_meta
4+
from .base import BaseHandler
5+
6+
sm_exists = True
7+
try:
8+
import statsmodels.api
9+
except ImportError:
10+
sm_exists = False
11+
12+
13+
class StatsmodelsHandler(BaseHandler):
    """Handler class for creating VetiverModels with statsmodels.

    Parameters
    ----------
    model : statsmodels
        a trained and fit statsmodels model
    ptype_data :
        Forwarded unchanged to ``BaseHandler.__init__``.
    """

    # The lambda defers the ``statsmodels`` attribute lookup until it is
    # called, so defining this class does not fail when the optional
    # ``statsmodels`` import above did not succeed.
    model_class = staticmethod(lambda: statsmodels.base.wrapper.ResultsWrapper)

    def __init__(self, model, ptype_data):
        super().__init__(model, ptype_data)

    def describe(self):
        """Create description for statsmodels model"""
        return f"Statsmodels {self.model.__class__} model."

    def create_meta(
        self,
        user: list = None,
        version: str = None,
        url: str = None,
        required_pkgs: list = [],
    ):
        """Create metadata for statsmodel

        ``"statsmodels"`` is always appended to ``required_pkgs`` before the
        values are handed to ``_model_meta``.

        Parameters
        ----------
        user : list
            Passed through to ``_model_meta``.
        version : str
            Passed through to ``_model_meta``.
        url : str
            Passed through to ``_model_meta``.
        required_pkgs : list
            Extra package names; ``"statsmodels"`` is added to these.

        Returns
        -------
        meta
            Whatever ``_model_meta`` returns for the given values.
        """
        # ``+`` builds a new list, so the shared default list is never mutated.
        required_pkgs = required_pkgs + ["statsmodels"]
        return _model_meta(user, version, url, required_pkgs)

    def handler_predict(self, input_data, check_ptype):
        """Generates method for /predict endpoint in VetiverAPI

        The `handler_predict` function executes at each API call. Use this
        function for calling `predict()` and any other tasks that must be executed
        at each API call.

        Parameters
        ----------
        input_data:
            Test data
        check_ptype:
            Accepted for interface compatibility; this handler does not
            consult it.

        Returns
        -------
        prediction
            Prediction from model

        Raises
        ------
        ImportError
            If `statsmodels` could not be imported when this module loaded.
        """
        if not sm_exists:
            raise ImportError("Cannot import `statsmodels`")

        # Lists and DataFrames are passed straight through; any other input
        # is wrapped in a one-element list before predicting.
        if isinstance(input_data, (list, pd.DataFrame)):
            return self.model.predict(input_data)
        return self.model.predict([input_data])

vetiver/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ async def prediction(
8282

8383
@app.post("/predict")
8484
async def prediction(input_data: Request):
85-
8685
y = await input_data.json()
86+
8787
prediction = self.model.handler_predict(y, check_ptype=self.check_ptype)
8888

8989
return {"prediction": prediction.tolist()}

vetiver/tests/test_statsmodels.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import pytest
2+
3+
sm = pytest.importorskip("statsmodels.api", reason="statsmodels library not installed")
4+
5+
statsmodels = pytest.importorskip(
6+
"statsmodels", reason="statsmodels library not installed"
7+
)
8+
9+
import numpy as np # noqa
10+
import pandas as pd # noqa
11+
from fastapi.testclient import TestClient # noqa
12+
13+
import vetiver # noqa
14+
15+
16+
@pytest.fixture
def build_sm():
    """Return a VetiverModel named "glm" wrapping a GLM fit on vetiver's mock data."""
    features, target = vetiver.get_mock_data()
    fitted_glm = sm.GLM(target, features).fit()
    return vetiver.VetiverModel(fitted_glm, "glm", features)
24+
25+
26+
def test_vetiver_build(build_sm):
    """Predicting on one all-zero record returns a single row whose value is 0.0."""
    client = TestClient(vetiver.VetiverAPI(build_sm).app)
    record = [{"B": 0, "C": 0, "D": 0}]

    result = vetiver.predict(endpoint=client, data=record)

    assert result.iloc[0, 0] == 0.0
    assert len(result) == 1
35+
36+
37+
def test_batch(build_sm):
    """Predicting on a 100-row DataFrame returns 100 rows."""
    client = TestClient(vetiver.VetiverAPI(build_sm).app)
    random_values = np.random.randint(0, 100, size=(100, 4))
    batch = pd.DataFrame(random_values, columns=list("ABCD"))

    result = vetiver.predict(endpoint=client, data=batch)

    assert len(result) == 100
45+
46+
47+
def test_no_ptype(build_sm):
    """With check_ptype=False, a bare list of zeros yields a single 0.0 prediction."""
    client = TestClient(vetiver.VetiverAPI(build_sm, check_ptype=False).app)
    raw_values = [0, 0, 0]

    result = vetiver.predict(endpoint=client, data=raw_values)

    assert result.iloc[0, 0] == 0.0
    assert len(result) == 1
56+
57+
58+
def test_serialize(build_sm):
    """Writing the model to a temporary board and reading "glm" back gives a GLMResultsWrapper."""
    import pins

    expected_type = statsmodels.genmod.generalized_linear_model.GLMResultsWrapper

    board = pins.board_temp(allow_pickle_read=True)
    vetiver.vetiver_pin_write(board=board, model=build_sm)
    assert isinstance(board.pin_read("glm"), expected_type)
    board.pin_delete("glm")

0 commit comments

Comments
 (0)