Skip to content

Commit 005ff97

Browse files
committed
test serialization
1 parent 2eb1c43 commit 005ff97

3 files changed

Lines changed: 34 additions & 118 deletions

File tree

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,6 @@ dev =
4848

4949
torch =
5050
torch
51+
52+
statsmodels =
53+
statsmodels

vetiver/handlers/statsmodels.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@ class StatsmodelsHandler(BaseHandler):
1919
a trained sklearn model
2020
"""
2121

22-
model_class = staticmethod(
23-
lambda: statsmodels.regression.linear_model.RegressionResultsWrapper
24-
)
22+
model_class = staticmethod(lambda: statsmodels.base.wrapper.ResultsWrapper)
2523

2624
def __init__(self, model, ptype_data):
2725
super().__init__(model, ptype_data)

vetiver/tests/test_statsmodels.py

Lines changed: 30 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -1,128 +1,43 @@
1-
# import pytest
1+
import pytest
22

3-
# sm = pytest.importorskip("statsmodels", reason="statsmodels library not installed")
3+
sm = pytest.importorskip("statsmodels.api", reason="statsmodels library not installed")
44

5-
# import numpy as np # noqa
6-
# from fastapi.testclient import TestClient # noqa
5+
statsmodels = pytest.importorskip(
6+
"statsmodels", reason="statsmodels library not installed"
7+
)
78

8-
# from vetiver.vetiver_model import VetiverModel # noqa
9-
# from vetiver import VetiverAPI # noqa
9+
import numpy as np # noqa
10+
from fastapi.testclient import TestClient # noqa
1011

12+
import vetiver # noqa
1113

12-
# def _build_sm():
1314

14-
# input_size = 1
15-
# output_size = 1
15+
@pytest.fixture
16+
def build_sm():
1617

17-
# x_train = np.array(
18-
# [
19-
# [3.3],
20-
# [4.4],
21-
# [5.5],
22-
# [6.71],
23-
# [6.93],
24-
# [4.168],
25-
# [9.779],
26-
# [6.182],
27-
# [7.59],
28-
# [2.167],
29-
# [7.042],
30-
# [10.791],
31-
# [5.313],
32-
# [7.997],
33-
# [3.1],
34-
# ],
35-
# dtype=np.float32,
36-
# )
18+
X, y = vetiver.get_mock_data()
19+
glm = sm.GLM(y, X).fit()
3720

38-
# torch_model = sm.nn.Linear(input_size, output_size)
39-
# return x_train, torch_model
21+
v = vetiver.VetiverModel(glm, "glm", X)
22+
return v
4023

4124

42-
# def test_vetiver_build():
25+
def test_vetiver_build(build_sm):
26+
api = vetiver.VetiverAPI(build_sm)
27+
client = TestClient(api.app)
28+
data = {"B": 0, "C": 0, "D": 0}
29+
response = client.post("/predict", json=data)
30+
assert response.status_code == 200, response.text
31+
assert response.json() == {"prediction": [0.0]}, response.json()
4332

44-
# x_train, torch_model = _build_sm()
4533

46-
# vt2 = VetiverModel(
47-
# model=torch_model,
48-
# ptype_data=x_train,
49-
# model_name="torch",
50-
# versioned=None,
51-
# description=None,
52-
# metadata=None,
53-
# )
34+
def test_serialize(build_sm):
35+
import pins
5436

55-
# assert vt2.model == torch_model
56-
57-
58-
# def test_sm_predict_ptype():
59-
# torch.manual_seed(3)
60-
# x_train, torch_model = _build_sm()
61-
# v = VetiverModel(torch_model, model_name="torch", ptype_data=x_train)
62-
# v_api = VetiverAPI(v)
63-
64-
# client = TestClient(v_api.app)
65-
# data = {"0": 3.3}
66-
# response = client.post("/predict", json=data)
67-
68-
# assert response.status_code == 200, response.text
69-
# assert response.json() == {"prediction": [-4.060722351074219]}, response.text
70-
71-
72-
# def test_sm_predict_ptype_batch():
73-
74-
# x_train, torch_model = _build_sm()
75-
# v = VetiverModel(torch_model, model_name="torch", ptype_data=x_train)
76-
# v_api = VetiverAPI(v)
77-
78-
# client = TestClient(v_api.app)
79-
# data = [{"0": 3.3}, {"0": 3.3}]
80-
# response = client.post("/predict", json=data)
81-
82-
# assert response.status_code == 200, response.text
83-
# assert response.json() == {
84-
# "prediction": [[-4.060722351074219], [-4.060722351074219]]
85-
# }, response.text
86-
87-
88-
# def test_sm_predict_ptype_error():
89-
90-
# x_train, torch_model = _build_sm()
91-
# v = VetiverModel(torch_model, model_name="torch", ptype_data=x_train)
92-
# v_api = VetiverAPI(v)
93-
94-
# client = TestClient(v_api.app)
95-
# data = {"0": "bad"}
96-
# response = client.post("/predict", json=data)
97-
98-
# assert response.status_code == 422, response.text # value is not a valid float
99-
100-
101-
# def test_sm_predict_no_ptype_batch():
102-
103-
# x_train, torch_model = _build_sm()
104-
# v = VetiverModel(torch_model, model_name="torch")
105-
# v_api = VetiverAPI(v, check_ptype=False)
106-
107-
# client = TestClient(v_api.app)
108-
# data = [[3.3], [3.3]]
109-
# response = client.post("/predict", json=data)
110-
# assert response.status_code == 200, response.text
111-
# assert response.json() == {
112-
# "prediction": [[-4.060722351074219], [-4.060722351074219]]
113-
# }, response.text
114-
115-
116-
# def test_sm_predict_no_ptype():
117-
118-
# x_train, torch_model = _build_sm()
119-
# v = VetiverModel(torch_model, model_name="torch")
120-
# v_api = VetiverAPI(v, check_ptype=False)
121-
122-
# client = TestClient(v_api.app)
123-
# data = [[3.3]]
124-
# response = client.post("/predict", json=data)
125-
# assert response.status_code == 200, response.text
126-
# assert response.json() == {"prediction": [[-4.060722351074219]]}, response.text
127-
128-
# def test_pin_sm():
37+
board = pins.board_temp(allow_pickle_read=True)
38+
vetiver.vetiver_pin_write(board=board, model=build_sm)
39+
assert isinstance(
40+
board.pin_read("glm"),
41+
statsmodels.genmod.generalized_linear_model.GLMResultsWrapper,
42+
)
43+
board.pin_delete("glm")

0 commit comments

Comments
 (0)