Skip to content

Commit 6d0e199

Browse files
authored
refactor tests (#209)
* refactor tests * parameterize ptype tests
1 parent 4e372b0 commit 6d0e199

14 files changed

Lines changed: 186 additions & 379 deletions

vetiver/data/__init__.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
1-
from importlib_resources import files as _files
2-
3-
sources = {
4-
"mtcars": _files("vetiver") / "data/mtcars.csv",
5-
"chicago": _files("vetiver") / "data/chicago.csv",
6-
"sacramento": _files("vetiver") / "data/sacramento.csv",
7-
}
1+
__all__ = [
2+
"mtcars",
3+
"chicago",
4+
"sacramento",
5+
]
86

97

108
def __dir__():
11-
return list(sources)
9+
return __all__
1210

1311

14-
def __getattr__(k):
12+
def _load_data_csv(name):
1513
import pandas as pd
14+
import pkg_resources
15+
16+
fname = pkg_resources.resource_filename("vetiver.data", f"{name}.csv")
17+
return pd.read_csv(fname)
18+
19+
20+
def __getattr__(name):
21+
if name not in __all__:
22+
raise AttributeError(f"No dataset named: {name}")
1623

17-
f_path = sources.get("mtcars")
18-
if k == "chicago":
19-
f_path = sources.get("chicago")
20-
elif k == "sacramento":
21-
f_path = sources.get("sacramento")
22-
return pd.read_csv(f_path)
24+
return _load_data_csv(name)

vetiver/server.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import uvicorn
66
import logging
77
import pandas as pd
8-
from fastapi import FastAPI, Request, testclient
8+
from fastapi import FastAPI, Request
99
from fastapi.exceptions import RequestValidationError
1010
from fastapi.openapi.utils import get_openapi
1111
from fastapi.responses import HTMLResponse, RedirectResponse, PlainTextResponse
@@ -321,9 +321,8 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw) -> pd.Da
321321
>>> endpoint = vetiver.vetiver_endpoint(url='http://127.0.0.1:8000/predict')
322322
>>> vetiver.predict(endpoint, X) # doctest: +SKIP
323323
"""
324-
if isinstance(endpoint, testclient.TestClient):
325-
requester = endpoint
326-
endpoint = requester.app.root_path
324+
if "test_client" in kw:
325+
requester = kw.pop("test_client")
327326
else:
328327
requester = requests
329328

vetiver/tests/__init__.py

Whitespace-only changes.

vetiver/tests/conftest.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import pytest
2+
from vetiver import VetiverModel, VetiverAPI
3+
from vetiver.helpers import api_data_to_frame
4+
from starlette.testclient import TestClient
5+
6+
7+
def sum_values(x):
8+
return x.sum().to_list()
9+
10+
11+
def sum_values_no_prototype(x):
12+
return api_data_to_frame(x).sum().to_list()
13+
14+
15+
@pytest.fixture
16+
def client(model: VetiverModel) -> TestClient:
17+
app = VetiverAPI(model, check_prototype=True)
18+
app.vetiver_post(sum_values, "sum")
19+
client = TestClient(app.app)
20+
21+
return client
22+
23+
24+
@pytest.fixture
25+
def client_no_prototype(model: VetiverModel) -> TestClient:
26+
app = VetiverAPI(model, check_prototype=False)
27+
app.vetiver_post(sum_values_no_prototype, "sum")
28+
client = TestClient(app.app)
29+
30+
return client

vetiver/tests/test_add_endpoint.py

Lines changed: 15 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,32 @@
11
import pytest
2-
3-
import numpy as np
42
import pandas as pd
5-
from fastapi.testclient import TestClient
6-
7-
from vetiver import mock, VetiverModel, VetiverAPI
8-
from vetiver.helpers import api_data_to_frame
9-
import vetiver
3+
from vetiver import mock, VetiverModel
104

115

12-
@pytest.fixture
13-
def vetiver_model():
14-
np.random.seed(500)
6+
@pytest.fixture()
7+
def model():
158
X, y = mock.get_mock_data()
16-
model = mock.get_mock_model().fit(X, y)
17-
v = VetiverModel(
18-
model=model,
19-
prototype_data=X,
20-
model_name="my_model",
21-
versioned=None,
22-
description="A regression model for testing purposes",
23-
)
24-
25-
return v
26-
27-
28-
def sum_values(x):
29-
return x.sum().to_list()
30-
31-
32-
def sum_values_no_prototype(x):
33-
return api_data_to_frame(x).sum().to_list()
34-
9+
model = mock.get_mock_model()
3510

36-
@pytest.fixture
37-
def vetiver_client(vetiver_model): # With check_prototype=True
38-
39-
app = VetiverAPI(vetiver_model, check_prototype=True)
40-
app.vetiver_post(sum_values, "sum")
41-
42-
app.app.root_path = "/sum"
43-
client = TestClient(app.app)
44-
45-
return client
11+
return VetiverModel(model.fit(X, y), "model", prototype_data=X)
4612

4713

4814
@pytest.fixture
49-
def vetiver_client_check_ptype_false(vetiver_model): # With check_prototype=False
50-
51-
app = VetiverAPI(vetiver_model, check_prototype=False)
52-
app.vetiver_post(sum_values_no_prototype, "sum")
15+
def data() -> pd.DataFrame:
16+
return pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]})
5317

54-
app.app.root_path = "/sum"
55-
client = TestClient(app.app)
5618

57-
return client
58-
59-
60-
def test_endpoint_adds_ptype(vetiver_client):
61-
62-
data = pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]})
63-
response = vetiver.predict(endpoint=vetiver_client, data=data)
19+
def test_endpoint_adds(client, data):
20+
response = client.post("/sum/", data=data.to_json(orient="records"))
6421

65-
assert isinstance(response, pd.DataFrame)
66-
assert response.to_json() == '{"sum":{"0":3,"1":6,"2":9}}', response.to_json()
22+
assert response.status_code == 200
23+
assert response.json() == {"sum": [3, 6, 9]}
6724

6825

69-
def test_endpoint_adds_no_ptype(vetiver_client_check_ptype_false):
26+
def test_endpoint_adds_no_prototype(client_no_prototype, data):
7027

7128
data = pd.DataFrame({"B": [1, 1, 1], "C": [2, 2, 2], "D": [3, 3, 3]})
72-
response = vetiver.predict(endpoint=vetiver_client_check_ptype_false, data=data)
29+
response = client_no_prototype.post("/sum/", data=data.to_json(orient="records"))
7330

74-
assert isinstance(response, pd.DataFrame)
75-
assert response.to_json() == '{"sum":{"0":3,"1":6,"2":9}}', response.to_json()
31+
assert response.status_code == 200
32+
assert response.json() == {"sum": [3, 6, 9]}

vetiver/tests/test_build_api.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

0 commit comments

Comments
 (0)