|
8 | 8 | from vetiver import mock, VetiverModel, VetiverAPI |
9 | 9 | from vetiver.server import predict |
10 | 10 |
|
| 11 | +np.random.seed(500) |
| 12 | +X, y = mock.get_mock_data() |
| 13 | + |
11 | 14 |
|
12 | 15 | @pytest.fixture |
13 | 16 | def vetiver_model(): |
14 | 17 | np.random.seed(500) |
15 | | - X, y = mock.get_mock_data() |
16 | 18 | model = mock.get_mock_model().fit(X, y) |
17 | 19 | v = VetiverModel( |
18 | 20 | model=model, |
@@ -43,41 +45,30 @@ def vetiver_client_check_ptype_false(vetiver_model): # With check_ptype=False |
43 | 45 | return client |
44 | 46 |
|
45 | 47 |
|
46 | | -def test_predict_sklearn_dict_ptype(vetiver_client): |
47 | | - data = {"B": 0, "C": 0, "D": 0} |
| 48 | +@pytest.mark.parametrize( |
| 49 | + "data,length", |
| 50 | + [({"B": 0, "C": 0, "D": 0}, 1), (pd.Series(data=[0, 0, 0]), 1), (X, 100)], |
| 51 | +) |
| 52 | +def test_predict_sklearn_ptype(data, length, vetiver_client): |
48 | 53 |
|
49 | 54 | response = predict(endpoint=vetiver_client, data=data) |
50 | 55 |
|
51 | 56 | assert isinstance(response, pd.DataFrame), response |
52 | 57 | assert response.iloc[0, 0] == 44.47 |
53 | | - assert len(response) == 1 |
| 58 | + assert len(response) == length |
54 | 59 |
|
55 | 60 |
|
56 | | -def test_predict_sklearn_no_ptype(vetiver_client_check_ptype_false): |
| 61 | +@pytest.mark.parametrize( |
| 62 | + "data,length", |
| 63 | + [({"B": 0, "C": 0, "D": 0}, 1), (pd.Series(data=[0, 0, 0]), 1), (X, 100)], |
| 64 | +) |
| 65 | +def test_predict_sklearn_no_ptype(data, length, vetiver_client_check_ptype_false): |
57 | 66 | X, y = mock.get_mock_data() |
58 | | - response = predict(endpoint=vetiver_client_check_ptype_false, data=X) |
59 | | - |
60 | | - assert isinstance(response, pd.DataFrame), response |
61 | | - assert response.iloc[0, 0] == 44.47 |
62 | | - assert len(response) == 100 |
63 | | - |
64 | | - |
65 | | -def test_predict_sklearn_df_check_ptype(vetiver_client): |
66 | | - X, y = mock.get_mock_data() |
67 | | - response = predict(endpoint=vetiver_client, data=X) |
68 | | - |
69 | | - assert isinstance(response, pd.DataFrame), response |
70 | | - assert response.iloc[0, 0] == 44.47 |
71 | | - assert len(response) == 100 |
72 | | - |
73 | | - |
74 | | -def test_predict_sklearn_series_check_ptype(vetiver_client): |
75 | | - ser = pd.Series(data=[0, 0, 0]) |
76 | | - response = predict(endpoint=vetiver_client, data=ser) |
| 67 | + response = predict(endpoint=vetiver_client_check_ptype_false, data=data) |
77 | 68 |
|
78 | 69 | assert isinstance(response, pd.DataFrame), response |
79 | 70 | assert response.iloc[0, 0] == 44.47 |
80 | | - assert len(response) == 1 |
| 71 | + assert len(response) == length |
81 | 72 |
|
82 | 73 |
|
83 | 74 | @pytest.mark.parametrize("data", [(0), 0, 0.0, "0"]) |
|
0 commit comments