Skip to content

Commit c7ec28f

Browse files
isabelizimmmachow
andcommitted
xgboost start
Co-authored-by: Michael Chow <chowmichedu@gmail.com>
1 parent 22405bd commit c7ec28f

4 files changed

Lines changed: 128 additions & 3 deletions

File tree

.github/workflows/tests.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
- name: Install dependencies
3333
run: |
3434
python -m pip install --upgrade pip
35-
python -m pip install -e .[dev,torch,statsmodels]
35+
python -m pip install -e .[dev,torch,statsmodels,xgboost]
3636
- name: Run Tests
3737
run: |
3838
pytest -m 'not rsc_test' --cov --cov-report xml
@@ -65,8 +65,8 @@ jobs:
6565
run: |
6666
pytest vetiver -m 'rsc_test'
6767
68-
test-no-torch:
69-
name: "Test no-torch"
68+
test-no-extras:
69+
name: "Test no exra ml frameworks"
7070
runs-on: ubuntu-latest
7171
steps:
7272
- uses: actions/checkout@v2

setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,6 @@ torch =
5151

5252
statsmodels =
5353
statsmodels
54+
55+
xgboost =
56+
xgboost

vetiver/handlers/xgboost.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import pandas as pd
2+
3+
from ..meta import _model_meta
4+
from .base import BaseHandler
5+
6+
xgb_exists = True
7+
try:
8+
import xgboost
9+
except ImportError:
10+
xgb_exists = False
11+
12+
13+
class XGBoostHandler(BaseHandler):
14+
"""Handler class for creating VetiverModels with statsmodels.
15+
16+
Parameters
17+
----------
18+
model : statsmodels
19+
a trained and fit statsmodels model
20+
"""
21+
22+
model_class = staticmethod(lambda: xgboost.Booster)
23+
24+
def __init__(self, model, ptype_data):
25+
super().__init__(model, ptype_data)
26+
27+
def describe(self):
28+
"""Create description for xgboost model"""
29+
desc = f"Statsmodels {self.model.__class__} model."
30+
return desc
31+
32+
def create_meta(
33+
user: list = None,
34+
version: str = None,
35+
url: str = None,
36+
required_pkgs: list = [],
37+
):
38+
"""Create metadata for statsmodel"""
39+
required_pkgs = required_pkgs + ["xgboost"]
40+
meta = _model_meta(user, version, url, required_pkgs)
41+
42+
return meta
43+
44+
def handler_predict(self, input_data, check_ptype):
45+
"""Generates method for /predict endpoint in VetiverAPI
46+
47+
The `handler_predict` function executes at each API call. Use this
48+
function for calling `predict()` and any other tasks that must be executed
49+
at each API call.
50+
51+
Parameters
52+
----------
53+
input_data:
54+
Test data
55+
56+
Returns
57+
-------
58+
prediction
59+
Prediction from model
60+
"""
61+
62+
if xgb_exists:
63+
if not isinstance(input_data, xgboost.DMatrix):
64+
if isinstance(input_data, pd.DataFrame):
65+
input_data = xgboost.DMatrix(input_data)
66+
else:
67+
input_data = xgboost.DMatrix(
68+
input_data, label=self.model.feature_names
69+
)
70+
71+
prediction = self.model.predict(input_data)
72+
else:
73+
raise ImportError("Cannot import `xgboost`")
74+
75+
return prediction

vetiver/tests/test_xgboost.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import pytest
2+
3+
xgb = pytest.importorskip("xgboost", reason="xgboost library not installed")
4+
5+
from vetiver.data import mtcars # noqa
6+
from vetiver.handlers.xgboost import XGBoostHandler # noqa
7+
8+
9+
@pytest.fixture
10+
def fit():
11+
# read in data
12+
dtrain = xgb.DMatrix(mtcars.drop(columns="mpg"), label=mtcars["mpg"])
13+
# specify parameters via map
14+
param = {"max_depth": 2, "eta": 1, "objective": "reg:squarederror"}
15+
num_round = 2
16+
fit = xgb.train(param, dtrain, num_round)
17+
18+
return fit
19+
20+
21+
@pytest.fixture
22+
def handler(fit):
23+
return XGBoostHandler(fit, None)
24+
25+
26+
def test_handler_xgboost_predict_dmatrix(handler):
27+
dtest = xgb.DMatrix(mtcars.drop(columns="mpg"))
28+
handler.handler_predict(dtest, True)
29+
30+
31+
def test_handler_xgboost_predict_df(handler):
32+
dtest = mtcars.drop(columns="mpg")
33+
handler.handler_predict(dtest, True)
34+
35+
36+
@pytest.mark.xfail
37+
def test_handler_xgboost_predict_str(handler):
38+
# TODO: prediction from a string")
39+
dtest = mtcars.drop(columns="mpg")
40+
handler.handler_predict(dtest, False)
41+
42+
43+
@pytest.mark.xfail
44+
def test_handler_xgboost_predict_list(handler):
45+
# TODO: prediction from a serialized JSON list
46+
row = mtcars.drop(columns="mpg").iloc[0, :].tolist()
47+
handler.handler_predict([row], False)

0 commit comments

Comments
 (0)