Skip to content

Commit 5e98300

Browse files
committed
github actions tweaks
1 parent aebf597 commit 5e98300

4 files changed

Lines changed: 22 additions & 7 deletions

File tree

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
- name: Install dependencies
3333
run: |
3434
python -m pip install --upgrade pip
35-
python -m pip install -e .[dev,torch]
35+
python -m pip install -e .[dev,torch,statsmodels]
3636
- name: Run Tests
3737
run: |
3838
pytest -m 'not rsc_test' --cov --cov-report xml

vetiver/handlers/statsmodels.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ def handler_predict(self, input_data, check_ptype):
5858
prediction
5959
Prediction from model
6060
"""
61-
61+
if not check_ptype:
62+
input_data = pd.DataFrame(input_data)
6263
if isinstance(input_data, pd.DataFrame):
6364
prediction = self.model.predict(input_data)
6465
else:

vetiver/server.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,7 @@ async def prediction(
8282

8383
@app.post("/predict")
8484
async def prediction(input_data: Request):
85-
86-
y = await input_data.json()
85+
y = await input_data.body()
8786
prediction = self.model.handler_predict(y, check_ptype=self.check_ptype)
8887

8988
return {"prediction": prediction.tolist()}

vetiver/tests/test_statsmodels.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,24 @@ def test_vetiver_build(build_sm):
2626
api = vetiver.VetiverAPI(build_sm)
2727
client = TestClient(api.app)
2828
data = {"B": 0, "C": 0, "D": 0}
29-
response = client.post("/predict", json=data)
30-
assert response.status_code == 200, response.text
31-
assert response.json() == {"prediction": [0.0]}, response.json()
29+
30+
response = vetiver.predict(endpoint=client, data=data)
31+
32+
assert response.iloc[0, 0] == 0.0
33+
assert len(response) == 1
34+
35+
36+
def test_no_ptype(build_sm):
37+
v = build_sm
38+
v.ptype = None
39+
api = vetiver.VetiverAPI(v)
40+
client = TestClient(api.app)
41+
data = 0, 0, 0
42+
43+
response = vetiver.predict(endpoint=client, data=data)
44+
45+
assert response.iloc[0, 0] == 0.0
46+
assert len(response) == 1
3247

3348

3449
def test_serialize(build_sm):

0 commit comments

Comments
 (0)