Skip to content

Commit 8a35e0f

Browse files
committed
handle dict, df inputs
1 parent 8a9ad41 commit 8a35e0f

3 files changed

Lines changed: 54 additions & 28 deletions

File tree

vetiver/handlers/xgboost.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,12 @@ def handler_predict(self, input_data, check_ptype):
6161

6262
if xgb_exists:
6363
if not isinstance(input_data, xgboost.DMatrix):
64-
if isinstance(input_data, pd.DataFrame):
65-
input_data = xgboost.DMatrix(input_data)
66-
else:
67-
input_data = xgboost.DMatrix(
68-
input_data, label=self.model.feature_names
69-
)
64+
if not isinstance(input_data, pd.DataFrame):
65+
try:
66+
input_data = pd.DataFrame(input_data)
67+
except ValueError:
68+
raise (f"Expected a dict or DataFrame, got {type(input_data)}")
69+
input_data = xgboost.DMatrix(input_data)
7070

7171
prediction = self.model.predict(input_data)
7272
else:

vetiver/server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw):
227227
elif isinstance(data, dict):
228228
response = requester.post(endpoint, json=data, **kw)
229229
else:
230+
# TODO: Check in on JSON serialization of DMatrix for XGBoost
230231
try:
231232
response = requester.post(endpoint, json=data, **kw)
232233
except TypeError:

vetiver/tests/test_xgboost.py

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,44 +4,69 @@
44

55
from vetiver.data import mtcars # noqa
66
from vetiver.handlers.xgboost import XGBoostHandler # noqa
7+
import numpy as np # noqa
8+
from fastapi.testclient import TestClient # noqa
9+
10+
import vetiver # noqa
711

812

913
@pytest.fixture
10-
def fit():
14+
def build_xgb():
1115
# read in data
1216
dtrain = xgb.DMatrix(mtcars.drop(columns="mpg"), label=mtcars["mpg"])
1317
# specify parameters via map
14-
param = {"max_depth": 2, "eta": 1, "objective": "reg:squarederror"}
18+
param = {
19+
"max_depth": 2,
20+
"eta": 1,
21+
"objective": "reg:squarederror",
22+
"random_state": 0,
23+
}
1524
num_round = 2
1625
fit = xgb.train(param, dtrain, num_round)
1726

18-
return fit
27+
return vetiver.VetiverModel(fit, "xgb", mtcars.drop(columns="mpg"))
1928

2029

21-
@pytest.fixture
22-
def handler(fit):
23-
return XGBoostHandler(fit, None)
30+
def test_vetiver_build(build_xgb):
31+
api = vetiver.VetiverAPI(build_xgb)
32+
client = TestClient(api.app)
33+
data = mtcars.head(1).drop(columns="mpg")
34+
35+
response = vetiver.predict(endpoint=client, data=data)
36+
37+
assert response.iloc[0, 0] == 21.064373016357422
38+
assert len(response) == 1
39+
40+
41+
def test_batch(build_xgb):
42+
api = vetiver.VetiverAPI(build_xgb)
43+
client = TestClient(api.app)
44+
data = mtcars.head(3).drop(columns="mpg")
45+
46+
response = vetiver.predict(endpoint=client, data=data)
2447

48+
assert response.iloc[0, 0] == 21.064373016357422
49+
assert len(response) == 3
2550

26-
def test_handler_xgboost_predict_dmatrix(handler):
27-
dtest = xgb.DMatrix(mtcars.drop(columns="mpg"))
28-
handler.handler_predict(dtest, True)
2951

52+
def test_no_ptype(build_xgb):
53+
api = vetiver.VetiverAPI(build_xgb, check_ptype=False)
54+
client = TestClient(api.app)
55+
data = mtcars.head(1).drop(columns="mpg")
3056

31-
def test_handler_xgboost_predict_df(handler):
32-
dtest = mtcars.drop(columns="mpg")
33-
handler.handler_predict(dtest, True)
57+
response = vetiver.predict(endpoint=client, data=data)
3458

59+
assert response.iloc[0, 0] == 21.064373016357422
60+
assert len(response) == 1
3561

36-
@pytest.mark.xfail
37-
def test_handler_xgboost_predict_str(handler):
38-
# TODO: prediction from a string")
39-
dtest = mtcars.drop(columns="mpg")
40-
handler.handler_predict(dtest, False)
4162

63+
def test_serialize(build_xgb):
64+
import pins
4265

43-
@pytest.mark.xfail
44-
def test_handler_xgboost_predict_list(handler):
45-
# TODO: prediction from a serialized JSON list
46-
row = mtcars.drop(columns="mpg").iloc[0, :].tolist()
47-
handler.handler_predict([row], False)
66+
board = pins.board_temp(allow_pickle_read=True)
67+
vetiver.vetiver_pin_write(board=board, model=build_xgb)
68+
assert isinstance(
69+
board.pin_read("xgb"),
70+
xgb.Booster,
71+
)
72+
board.pin_delete("xgb")

0 commit comments

Comments
 (0)