|
1 | | -# import pytest |
| 1 | +import pytest |
2 | 2 |
|
3 | | -# sm = pytest.importorskip("statsmodels", reason="statsmodels library not installed") |
| 3 | +sm = pytest.importorskip("statsmodels.api", reason="statsmodels library not installed") |
4 | 4 |
|
5 | | -# import numpy as np # noqa |
6 | | -# from fastapi.testclient import TestClient # noqa |
| 5 | +statsmodels = pytest.importorskip( |
| 6 | + "statsmodels", reason="statsmodels library not installed" |
| 7 | +) |
7 | 8 |
|
8 | | -# from vetiver.vetiver_model import VetiverModel # noqa |
9 | | -# from vetiver import VetiverAPI # noqa |
| 9 | +import numpy as np # noqa |
| 10 | +from fastapi.testclient import TestClient # noqa |
10 | 11 |
|
| 12 | +import vetiver # noqa |
11 | 13 |
|
12 | | -# def _build_sm(): |
13 | 14 |
|
14 | | -# input_size = 1 |
15 | | -# output_size = 1 |
| 15 | +@pytest.fixture |
| 16 | +def build_sm(): |
16 | 17 |
|
17 | | -# x_train = np.array( |
18 | | -# [ |
19 | | -# [3.3], |
20 | | -# [4.4], |
21 | | -# [5.5], |
22 | | -# [6.71], |
23 | | -# [6.93], |
24 | | -# [4.168], |
25 | | -# [9.779], |
26 | | -# [6.182], |
27 | | -# [7.59], |
28 | | -# [2.167], |
29 | | -# [7.042], |
30 | | -# [10.791], |
31 | | -# [5.313], |
32 | | -# [7.997], |
33 | | -# [3.1], |
34 | | -# ], |
35 | | -# dtype=np.float32, |
36 | | -# ) |
| 18 | + X, y = vetiver.get_mock_data() |
| 19 | + glm = sm.GLM(y, X).fit() |
37 | 20 |
|
38 | | -# torch_model = sm.nn.Linear(input_size, output_size) |
39 | | -# return x_train, torch_model |
| 21 | + v = vetiver.VetiverModel(glm, "glm", X) |
| 22 | + return v |
40 | 23 |
|
41 | 24 |
|
42 | | -# def test_vetiver_build(): |
| 25 | +def test_vetiver_build(build_sm): |
| 26 | + api = vetiver.VetiverAPI(build_sm) |
| 27 | + client = TestClient(api.app) |
| 28 | + data = {"B": 0, "C": 0, "D": 0} |
| 29 | + response = client.post("/predict", json=data) |
| 30 | + assert response.status_code == 200, response.text |
| 31 | + assert response.json() == {"prediction": [0.0]}, response.json() |
43 | 32 |
|
44 | | -# x_train, torch_model = _build_sm() |
45 | 33 |
|
46 | | -# vt2 = VetiverModel( |
47 | | -# model=torch_model, |
48 | | -# ptype_data=x_train, |
49 | | -# model_name="torch", |
50 | | -# versioned=None, |
51 | | -# description=None, |
52 | | -# metadata=None, |
53 | | -# ) |
| 34 | +def test_serialize(build_sm): |
| 35 | + import pins |
54 | 36 |
|
55 | | -# assert vt2.model == torch_model |
56 | | - |
57 | | - |
58 | | -# def test_sm_predict_ptype(): |
59 | | -# torch.manual_seed(3) |
60 | | -# x_train, torch_model = _build_sm() |
61 | | -# v = VetiverModel(torch_model, model_name="torch", ptype_data=x_train) |
62 | | -# v_api = VetiverAPI(v) |
63 | | - |
64 | | -# client = TestClient(v_api.app) |
65 | | -# data = {"0": 3.3} |
66 | | -# response = client.post("/predict", json=data) |
67 | | - |
68 | | -# assert response.status_code == 200, response.text |
69 | | -# assert response.json() == {"prediction": [-4.060722351074219]}, response.text |
70 | | - |
71 | | - |
72 | | -# def test_sm_predict_ptype_batch(): |
73 | | - |
74 | | -# x_train, torch_model = _build_sm() |
75 | | -# v = VetiverModel(torch_model, model_name="torch", ptype_data=x_train) |
76 | | -# v_api = VetiverAPI(v) |
77 | | - |
78 | | -# client = TestClient(v_api.app) |
79 | | -# data = [{"0": 3.3}, {"0": 3.3}] |
80 | | -# response = client.post("/predict", json=data) |
81 | | - |
82 | | -# assert response.status_code == 200, response.text |
83 | | -# assert response.json() == { |
84 | | -# "prediction": [[-4.060722351074219], [-4.060722351074219]] |
85 | | -# }, response.text |
86 | | - |
87 | | - |
88 | | -# def test_sm_predict_ptype_error(): |
89 | | - |
90 | | -# x_train, torch_model = _build_sm() |
91 | | -# v = VetiverModel(torch_model, model_name="torch", ptype_data=x_train) |
92 | | -# v_api = VetiverAPI(v) |
93 | | - |
94 | | -# client = TestClient(v_api.app) |
95 | | -# data = {"0": "bad"} |
96 | | -# response = client.post("/predict", json=data) |
97 | | - |
98 | | -# assert response.status_code == 422, response.text # value is not a valid float |
99 | | - |
100 | | - |
101 | | -# def test_sm_predict_no_ptype_batch(): |
102 | | - |
103 | | -# x_train, torch_model = _build_sm() |
104 | | -# v = VetiverModel(torch_model, model_name="torch") |
105 | | -# v_api = VetiverAPI(v, check_ptype=False) |
106 | | - |
107 | | -# client = TestClient(v_api.app) |
108 | | -# data = [[3.3], [3.3]] |
109 | | -# response = client.post("/predict", json=data) |
110 | | -# assert response.status_code == 200, response.text |
111 | | -# assert response.json() == { |
112 | | -# "prediction": [[-4.060722351074219], [-4.060722351074219]] |
113 | | -# }, response.text |
114 | | - |
115 | | - |
116 | | -# def test_sm_predict_no_ptype(): |
117 | | - |
118 | | -# x_train, torch_model = _build_sm() |
119 | | -# v = VetiverModel(torch_model, model_name="torch") |
120 | | -# v_api = VetiverAPI(v, check_ptype=False) |
121 | | - |
122 | | -# client = TestClient(v_api.app) |
123 | | -# data = [[3.3]] |
124 | | -# response = client.post("/predict", json=data) |
125 | | -# assert response.status_code == 200, response.text |
126 | | -# assert response.json() == {"prediction": [[-4.060722351074219]]}, response.text |
127 | | - |
128 | | -# def test_pin_sm(): |
| 37 | + board = pins.board_temp(allow_pickle_read=True) |
| 38 | + vetiver.vetiver_pin_write(board=board, model=build_sm) |
| 39 | + assert isinstance( |
| 40 | + board.pin_read("glm"), |
| 41 | + statsmodels.genmod.generalized_linear_model.GLMResultsWrapper, |
| 42 | + ) |
| 43 | + board.pin_delete("glm") |
0 commit comments