Skip to content

Commit 0dbcf40

Browse files
committed
parameterize sklearn tests
1 parent c2b7f8b commit 0dbcf40

1 file changed

Lines changed: 16 additions & 25 deletions

File tree

vetiver/tests/test_sklearn.py

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
from vetiver import mock, VetiverModel, VetiverAPI
99
from vetiver.server import predict
1010

11+
np.random.seed(500)
12+
X, y = mock.get_mock_data()
13+
1114

1215
@pytest.fixture
1316
def vetiver_model():
1417
np.random.seed(500)
15-
X, y = mock.get_mock_data()
1618
model = mock.get_mock_model().fit(X, y)
1719
v = VetiverModel(
1820
model=model,
@@ -43,41 +45,30 @@ def vetiver_client_check_ptype_false(vetiver_model): # With check_ptype=False
4345
return client
4446

4547

46-
def test_predict_sklearn_dict_ptype(vetiver_client):
47-
data = {"B": 0, "C": 0, "D": 0}
48+
@pytest.mark.parametrize(
49+
"data,length",
50+
[({"B": 0, "C": 0, "D": 0}, 1), (pd.Series(data=[0, 0, 0]), 1), (X, 100)],
51+
)
52+
def test_predict_sklearn_ptype(data, length, vetiver_client):
4853

4954
response = predict(endpoint=vetiver_client, data=data)
5055

5156
assert isinstance(response, pd.DataFrame), response
5257
assert response.iloc[0, 0] == 44.47
53-
assert len(response) == 1
58+
assert len(response) == length
5459

5560

56-
def test_predict_sklearn_no_ptype(vetiver_client_check_ptype_false):
61+
@pytest.mark.parametrize(
62+
"data,length",
63+
[({"B": 0, "C": 0, "D": 0}, 1), (pd.Series(data=[0, 0, 0]), 1), (X, 100)],
64+
)
65+
def test_predict_sklearn_no_ptype(data, length, vetiver_client_check_ptype_false):
5766
X, y = mock.get_mock_data()
58-
response = predict(endpoint=vetiver_client_check_ptype_false, data=X)
59-
60-
assert isinstance(response, pd.DataFrame), response
61-
assert response.iloc[0, 0] == 44.47
62-
assert len(response) == 100
63-
64-
65-
def test_predict_sklearn_df_check_ptype(vetiver_client):
66-
X, y = mock.get_mock_data()
67-
response = predict(endpoint=vetiver_client, data=X)
68-
69-
assert isinstance(response, pd.DataFrame), response
70-
assert response.iloc[0, 0] == 44.47
71-
assert len(response) == 100
72-
73-
74-
def test_predict_sklearn_series_check_ptype(vetiver_client):
75-
ser = pd.Series(data=[0, 0, 0])
76-
response = predict(endpoint=vetiver_client, data=ser)
67+
response = predict(endpoint=vetiver_client_check_ptype_false, data=data)
7768

7869
assert isinstance(response, pd.DataFrame), response
7970
assert response.iloc[0, 0] == 44.47
80-
assert len(response) == 1
71+
assert len(response) == length
8172

8273

8374
@pytest.mark.parametrize("data", [(0), 0, 0.0, "0"])

0 commit comments

Comments
 (0)