Skip to content

Commit 2eb1c43

Browse files
committed
initial statsmodels support
1 parent 64e84e1 commit 2eb1c43

3 files changed

Lines changed: 206 additions & 0 deletions

File tree

vetiver/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .handlers.base import BaseHandler, create_handler, InvalidModelError # noqa
1515
from .handlers.sklearn import SKLearnHandler # noqa
1616
from .handlers.torch import TorchHandler # noqa
17+
from .handlers.statsmodels import StatsmodelsHandler # noqa
1718
from .rsconnect import deploy_rsconnect # noqa
1819
from .monitor import compute_metrics, pin_metrics, plot_metrics, _rolling_df # noqa
1920

vetiver/handlers/statsmodels.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import pandas as pd
2+
3+
from ..meta import _model_meta
4+
from .base import BaseHandler
5+
6+
sm_exists = True
7+
try:
8+
import statsmodels
9+
except ImportError:
10+
sm_exists = False
11+
12+
13+
class StatsmodelsHandler(BaseHandler):
14+
"""Handler class for creating VetiverModels with sklearn.
15+
16+
Parameters
17+
----------
18+
model : statsmodels
19+
a trained sklearn model
20+
"""
21+
22+
model_class = staticmethod(
23+
lambda: statsmodels.regression.linear_model.RegressionResultsWrapper
24+
)
25+
26+
def __init__(self, model, ptype_data):
27+
super().__init__(model, ptype_data)
28+
29+
def describe(self):
30+
"""Create description for sklearn model"""
31+
desc = f"Statsmodels {self.model.__class__} model."
32+
return desc
33+
34+
def construct_meta(
35+
user: list = None,
36+
version: str = None,
37+
url: str = None,
38+
required_pkgs: list = [],
39+
):
40+
"""Create metadata for sklearn model"""
41+
required_pkgs = required_pkgs + ["statsmodels"]
42+
meta = _model_meta(user, version, url, required_pkgs)
43+
44+
return meta
45+
46+
def handler_predict(self, input_data, check_ptype):
47+
"""Generates method for /predict endpoint in VetiverAPI
48+
49+
The `handler_predict` function executes at each API call. Use this
50+
function for calling `predict()` and any other tasks that must be executed
51+
at each API call.
52+
53+
Parameters
54+
----------
55+
input_data:
56+
Test data
57+
58+
Returns
59+
-------
60+
prediction
61+
Prediction from model
62+
"""
63+
64+
if check_ptype:
65+
if isinstance(input_data, pd.DataFrame):
66+
prediction = self.model.predict(input_data)
67+
else:
68+
prediction = self.model.predict([input_data])
69+
70+
# do not check ptype
71+
else:
72+
if not isinstance(input_data, list):
73+
input_data = [input_data.split(",")] # user delimiter ?
74+
75+
prediction = self.model.predict(input_data)
76+
77+
return prediction

vetiver/tests/test_statsmodels.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# import pytest
2+
3+
# sm = pytest.importorskip("statsmodels", reason="statsmodels library not installed")
4+
5+
# import numpy as np # noqa
6+
# from fastapi.testclient import TestClient # noqa
7+
8+
# from vetiver.vetiver_model import VetiverModel # noqa
9+
# from vetiver import VetiverAPI # noqa
10+
11+
12+
# def _build_sm():
13+
14+
# input_size = 1
15+
# output_size = 1
16+
17+
# x_train = np.array(
18+
# [
19+
# [3.3],
20+
# [4.4],
21+
# [5.5],
22+
# [6.71],
23+
# [6.93],
24+
# [4.168],
25+
# [9.779],
26+
# [6.182],
27+
# [7.59],
28+
# [2.167],
29+
# [7.042],
30+
# [10.791],
31+
# [5.313],
32+
# [7.997],
33+
# [3.1],
34+
# ],
35+
# dtype=np.float32,
36+
# )
37+
38+
# torch_model = sm.nn.Linear(input_size, output_size)
39+
# return x_train, torch_model
40+
41+
42+
# def test_vetiver_build():
43+
44+
# x_train, torch_model = _build_sm()
45+
46+
# vt2 = VetiverModel(
47+
# model=torch_model,
48+
# ptype_data=x_train,
49+
# model_name="torch",
50+
# versioned=None,
51+
# description=None,
52+
# metadata=None,
53+
# )
54+
55+
# assert vt2.model == torch_model
56+
57+
58+
# def test_sm_predict_ptype():
59+
# torch.manual_seed(3)
60+
# x_train, torch_model = _build_sm()
61+
# v = VetiverModel(torch_model, model_name="torch", ptype_data=x_train)
62+
# v_api = VetiverAPI(v)
63+
64+
# client = TestClient(v_api.app)
65+
# data = {"0": 3.3}
66+
# response = client.post("/predict", json=data)
67+
68+
# assert response.status_code == 200, response.text
69+
# assert response.json() == {"prediction": [-4.060722351074219]}, response.text
70+
71+
72+
# def test_sm_predict_ptype_batch():
73+
74+
# x_train, torch_model = _build_sm()
75+
# v = VetiverModel(torch_model, model_name="torch", ptype_data=x_train)
76+
# v_api = VetiverAPI(v)
77+
78+
# client = TestClient(v_api.app)
79+
# data = [{"0": 3.3}, {"0": 3.3}]
80+
# response = client.post("/predict", json=data)
81+
82+
# assert response.status_code == 200, response.text
83+
# assert response.json() == {
84+
# "prediction": [[-4.060722351074219], [-4.060722351074219]]
85+
# }, response.text
86+
87+
88+
# def test_sm_predict_ptype_error():
89+
90+
# x_train, torch_model = _build_sm()
91+
# v = VetiverModel(torch_model, model_name="torch", ptype_data=x_train)
92+
# v_api = VetiverAPI(v)
93+
94+
# client = TestClient(v_api.app)
95+
# data = {"0": "bad"}
96+
# response = client.post("/predict", json=data)
97+
98+
# assert response.status_code == 422, response.text # value is not a valid float
99+
100+
101+
# def test_sm_predict_no_ptype_batch():
102+
103+
# x_train, torch_model = _build_sm()
104+
# v = VetiverModel(torch_model, model_name="torch")
105+
# v_api = VetiverAPI(v, check_ptype=False)
106+
107+
# client = TestClient(v_api.app)
108+
# data = [[3.3], [3.3]]
109+
# response = client.post("/predict", json=data)
110+
# assert response.status_code == 200, response.text
111+
# assert response.json() == {
112+
# "prediction": [[-4.060722351074219], [-4.060722351074219]]
113+
# }, response.text
114+
115+
116+
# def test_sm_predict_no_ptype():
117+
118+
# x_train, torch_model = _build_sm()
119+
# v = VetiverModel(torch_model, model_name="torch")
120+
# v_api = VetiverAPI(v, check_ptype=False)
121+
122+
# client = TestClient(v_api.app)
123+
# data = [[3.3]]
124+
# response = client.post("/predict", json=data)
125+
# assert response.status_code == 200, response.text
126+
# assert response.json() == {"prediction": [[-4.060722351074219]]}, response.text
127+
128+
# def test_pin_sm():

0 commit comments

Comments
 (0)