|
1 | 1 | import pytest |
2 | | - |
3 | | -import numpy as np |
4 | 2 | import pandas as pd |
5 | | -from fastapi.testclient import TestClient |
6 | | - |
7 | | -from vetiver import mock, VetiverModel, VetiverAPI |
8 | | -from vetiver.helpers import api_data_to_frame |
9 | | -import vetiver |
| 3 | +from vetiver import mock, VetiverModel |
10 | 4 |
|
11 | 5 |
|
12 | | -@pytest.fixture |
13 | | -def vetiver_model(): |
14 | | - np.random.seed(500) |
| 6 | +@pytest.fixture() |
| 7 | +def model(): |
15 | 8 | X, y = mock.get_mock_data() |
16 | | - model = mock.get_mock_model().fit(X, y) |
17 | | - v = VetiverModel( |
18 | | - model=model, |
19 | | - prototype_data=X, |
20 | | - model_name="my_model", |
21 | | - versioned=None, |
22 | | - description="A regression model for testing purposes", |
23 | | - ) |
24 | | - |
25 | | - return v |
26 | | - |
27 | | - |
28 | | -def sum_values(x): |
29 | | - return x.sum().to_list() |
30 | | - |
31 | | - |
32 | | -def sum_values_no_prototype(x): |
33 | | - return api_data_to_frame(x).sum().to_list() |
34 | | - |
| 9 | + model = mock.get_mock_model() |
35 | 10 |
|
36 | | -@pytest.fixture |
37 | | -def vetiver_client(vetiver_model): # With check_prototype=True |
38 | | - |
39 | | - app = VetiverAPI(vetiver_model, check_prototype=True) |
40 | | - app.vetiver_post(sum_values, "sum") |
41 | | - |
42 | | - app.app.root_path = "/sum" |
43 | | - client = TestClient(app.app) |
44 | | - |
45 | | - return client |
| 11 | + return VetiverModel(model.fit(X, y), "model", prototype_data=X) |
46 | 12 |
|
47 | 13 |
|
48 | 14 | @pytest.fixture |
49 | | -def vetiver_client_check_ptype_false(vetiver_model): # With check_prototype=False |
50 | | - |
51 | | - app = VetiverAPI(vetiver_model, check_prototype=False) |
52 | | - app.vetiver_post(sum_values_no_prototype, "sum") |
| 15 | +def data() -> pd.DataFrame: |
| 16 | + return pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]}) |
53 | 17 |
|
54 | | - app.app.root_path = "/sum" |
55 | | - client = TestClient(app.app) |
56 | 18 |
|
57 | | - return client |
58 | | - |
59 | | - |
60 | | -def test_endpoint_adds_ptype(vetiver_client): |
61 | | - |
62 | | - data = pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]}) |
63 | | - response = vetiver.predict(endpoint=vetiver_client, data=data) |
| 19 | +def test_endpoint_adds(client, data): |
| 20 | + response = client.post("/sum/", data=data.to_json(orient="records")) |
64 | 21 |
|
65 | | - assert isinstance(response, pd.DataFrame) |
66 | | - assert response.to_json() == '{"sum":{"0":3,"1":6,"2":9}}', response.to_json() |
| 22 | + assert response.status_code == 200 |
| 23 | + assert response.json() == {"sum": [3, 6, 9]} |
67 | 24 |
|
68 | 25 |
|
69 | | -def test_endpoint_adds_no_ptype(vetiver_client_check_ptype_false): |
| 26 | +def test_endpoint_adds_no_prototype(client_no_prototype, data): |
70 | 27 |
|
71 | 28 | data = pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]}) |
72 | | - response = vetiver.predict(endpoint=vetiver_client_check_ptype_false, data=data) |
| 29 | + response = client_no_prototype.post("/sum/", data=data.to_json(orient="records")) |
73 | 30 |
|
74 | | - assert isinstance(response, pd.DataFrame) |
75 | | - assert response.to_json() == '{"sum":{"0":3,"1":6,"2":9}}', response.to_json() |
| 31 | + assert response.status_code == 200 |
| 32 | + assert response.json() == {"sum": [3, 6, 9]} |
0 commit comments