Skip to content

Commit abb17c0

Browse files
authored
enh: support multiple types of prediction (#233)
* predict_proba support for sklearn
* add SklearnPredictionTypes
* clean up vetiver_post
* lint
* update tests
* generalize model calls
* seed for mock model
* move seed
* allow approx values for differences in version/arch
1 parent 4ba9969 commit abb17c0

13 files changed

Lines changed: 221 additions & 105 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ repos:
1010
types:
1111
- python
1212
args:
13-
- "--max-line-length=90"
13+
- "--max-line-length=100"
1414
- id: trailing-whitespace
1515
- id: end-of-file-fixer
1616
- id: check-yaml

vetiver/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
) # noqa
1111
from .vetiver_model import VetiverModel # noqa
1212
from .server import VetiverAPI, vetiver_endpoint, predict # noqa
13-
from .mock import get_mock_data, get_mock_model # noqa
13+
from .mock import get_mock_data, get_mock_model, get_mtcars_model # noqa
1414
from .pin_read_write import vetiver_pin_write # noqa
1515
from .attach_pkgs import load_pkgs, get_board_pkgs # noqa
1616
from .meta import VetiverMeta # noqa

vetiver/handlers/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def handler_startup():
121121
"""
122122
...
123123

124-
def handler_predict(self, input_data, check_prototype):
124+
def handler_predict(self, input_data, check_prototype, **kw):
125125
"""Generates method for /predict endpoint in VetiverAPI
126126
127127
The `handler_predict` function executes at each API call. Use this

vetiver/handlers/sklearn.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class SKLearnHandler(BaseHandler):
1616
model_class = staticmethod(lambda: sklearn.base.BaseEstimator)
1717
pip_name = "scikit-learn"
1818

19-
def handler_predict(self, input_data, check_prototype):
19+
def handler_predict(self, input_data, check_prototype: bool, **kw):
2020
"""
2121
Generates method for /predict endpoint in VetiverAPI
2222
@@ -28,16 +28,22 @@ def handler_predict(self, input_data, check_prototype):
2828
----------
2929
input_data:
3030
Test data
31+
check_prototype: bool
32+
prediction_type: str
33+
Type of prediction to make. One of "predict", "predict_proba",
34+
or "predict_log_proba". Default is "predict".
3135
3236
Returns
3337
-------
3438
prediction:
3539
Prediction from model
3640
"""
41+
prediction_type = kw.get("prediction_type", "predict")
3742

38-
if not check_prototype or isinstance(input_data, pd.DataFrame):
39-
prediction = self.model.predict(input_data)
40-
else:
41-
prediction = self.model.predict([input_data])
43+
input_data = (
44+
[input_data]
45+
if check_prototype and not isinstance(input_data, pd.DataFrame)
46+
else input_data
47+
)
4248

43-
return prediction.tolist()
49+
return getattr(self.model, prediction_type)(input_data).tolist()

vetiver/handlers/spacy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def construct_prototype(self):
5353

5454
return prototype
5555

56-
def handler_predict(self, input_data, check_prototype):
56+
def handler_predict(self, input_data, check_prototype, **kw):
5757
"""
5858
Generates method for /predict endpoint in VetiverAPI
5959

vetiver/handlers/statsmodels.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class StatsmodelsHandler(BaseHandler):
2222
if sm_exists:
2323
pip_name = "statsmodels"
2424

25-
def handler_predict(self, input_data, check_prototype):
25+
def handler_predict(self, input_data, check_prototype, **kw):
2626
"""
2727
Generates method for /predict endpoint in VetiverAPI
2828
@@ -43,9 +43,7 @@ def handler_predict(self, input_data, check_prototype):
4343
if not sm_exists:
4444
raise ImportError("Cannot import `statsmodels`")
4545

46-
if isinstance(input_data, (list, pd.DataFrame)):
47-
prediction = self.model.predict(input_data)
48-
else:
49-
prediction = self.model.predict([input_data])
50-
51-
return prediction.tolist()
46+
input_data = (
47+
input_data if isinstance(input_data, (list, pd.DataFrame)) else [input_data]
48+
)
49+
return self.model.predict(input_data).tolist()

vetiver/handlers/torch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class TorchHandler(BaseHandler):
2222
if torch_exists:
2323
pip_name = "torch"
2424

25-
def handler_predict(self, input_data, check_prototype):
25+
def handler_predict(self, input_data, check_prototype, **kw):
2626
"""
2727
Generates method for /predict endpoint in VetiverAPI
2828

vetiver/handlers/xgboost.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class XGBoostHandler(BaseHandler):
2222
if xgb_exists:
2323
pip_name = "xgboost"
2424

25-
def handler_predict(self, input_data, check_prototype):
25+
def handler_predict(self, input_data, check_prototype, **kw):
2626
"""
2727
Generates method for /predict endpoint in VetiverAPI
2828

vetiver/mock.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
from sklearn.dummy import DummyRegressor
21
import pandas as pd
32
import numpy as np
43

4+
from sklearn.dummy import DummyRegressor
5+
from sklearn.linear_model import LogisticRegression
6+
7+
from .data import mtcars
8+
59

610
def get_mock_data():
711
"""Create mock data for testing
@@ -26,5 +30,17 @@ def get_mock_model():
2630
model : sklearn.dummy.DummyRegressor
2731
Arbitrary model for testing purposes
2832
"""
29-
model = DummyRegressor()
30-
return model
33+
return DummyRegressor()
34+
35+
36+
def get_mtcars_model():
37+
"""Create mock model for testing
38+
39+
Returns
40+
-------
41+
model : sklearn.linear_model.LogisticRegression
42+
Arbitrary model for testing purposes
43+
"""
44+
return LogisticRegression(max_iter=1000, random_state=500).fit(
45+
mtcars.drop(columns="cyl"), mtcars["cyl"]
46+
)

vetiver/server.py

Lines changed: 67 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
from fastapi.exceptions import RequestValidationError
1616
from fastapi.openapi.utils import get_openapi
1717
from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse
18-
1918
from .helpers import api_data_to_frame, response_to_frame
19+
from .handlers.sklearn import SKLearnHandler
2020
from .meta import VetiverMeta
2121
from .utils import _jupyter_nb, get_workbench_path
2222
from .vetiver_model import VetiverModel
23+
from .types import SklearnPredictionTypes
2324

2425

2526
class VetiverAPI:
@@ -111,7 +112,6 @@ async def startup_event():
111112

112113
@app.get("/", include_in_schema=False)
113114
def docs_redirect():
114-
115115
redirect = "__docs__"
116116

117117
return RedirectResponse(redirect)
@@ -200,65 +200,94 @@ async def validation_exception_handler(request, exc):
200200

201201
return app
202202

203-
def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
204-
"""Create new POST endpoint that is aware of model input data
203+
def vetiver_post(
204+
self,
205+
endpoint_fx: Union[Callable, SklearnPredictionTypes],
206+
endpoint_name: str = None,
207+
**kw,
208+
):
209+
"""Define a new POST endpoint that utilizes the model's input data.
205210
206211
Parameters
207212
----------
208-
endpoint_fx : typing.Callable
209-
Custom function to be run at endpoint
213+
endpoint_fx
214+
: Union[typing.Callable, Literal["predict", "predict_proba", "predict_log_proba"]]
215+
A callable function that specifies the custom logic to execute when the
216+
endpoint is called. This function should take input data (e.g., a DataFrame
217+
or dictionary) and return the desired output (e.g., predictions or transformed
218+
data). For scikit-learn models, endpoint_fx can also be one of "predict",
219+
"predict_proba", or "predict_log_proba" if the model supports these methods.
220+
210221
endpoint_name : str
211-
Name of endpoint
222+
The name of the endpoint to be created.
212223
213224
Examples
214225
-------
215-
```{python}
226+
```python
216227
from vetiver import mock, VetiverModel, VetiverAPI
217228
X, y = mock.get_mock_data()
218229
model = mock.get_mock_model().fit(X, y)
219230
220-
v = VetiverModel(model = model, model_name = "model", prototype_data = X)
221-
v_api = VetiverAPI(model = v, check_prototype = True)
231+
v = VetiverModel(model=model, model_name="model", prototype_data=X)
232+
v_api = VetiverAPI(model=v, check_prototype=True)
222233
223234
def sum_values(x):
224235
return x.sum()
236+
225237
v_api.vetiver_post(sum_values, "sums")
226238
```
227239
"""
228-
if not endpoint_name:
229-
endpoint_name = endpoint_fx.__name__
230240

231-
if endpoint_fx.__doc__ is not None:
232-
api_desc = dedent(endpoint_fx.__doc__)
233-
else:
234-
api_desc = None
241+
if not isinstance(endpoint_fx, Callable):
242+
if endpoint_fx not in ["predict", "predict_proba", "predict_log_proba"]:
243+
raise ValueError(
244+
f"""
245+
Prediction type {endpoint_fx} not available.
246+
Available prediction types: {SklearnPredictionTypes}
247+
"""
248+
)
249+
if not isinstance(self.model.handler_predict.__self__, SKLearnHandler):
250+
raise ValueError(
251+
"""
252+
The 'endpoint_fx' parameter can only be a
253+
string when using scikit-learn models.
254+
"""
255+
)
256+
self.vetiver_post(
257+
self.model.handler_predict,
258+
endpoint_fx,
259+
check_prototype=self.check_prototype,
260+
prediction_type=endpoint_fx,
261+
)
262+
return
235263

236-
if self.check_prototype is True:
264+
endpoint_name = endpoint_name or endpoint_fx.__name__
265+
endpoint_doc = dedent(endpoint_fx.__doc__) if endpoint_fx.__doc__ else None
237266

238-
@self.app.post(
239-
urljoin("/", endpoint_name),
240-
name=endpoint_name,
241-
description=api_desc,
242-
)
243-
async def custom_endpoint(input_data: List[self.model.prototype]):
244-
_to_frame = api_data_to_frame(input_data)
245-
predictions = endpoint_fx(_to_frame, **kw)
246-
if isinstance(predictions, List):
247-
return {endpoint_name: predictions}
248-
else:
249-
return predictions
267+
# this must be split up this way to preserve the correct type hints for
268+
# the input_data schema validation via Pydantic + FastAPI
269+
input_data_type = (
270+
List[self.model.prototype] if self.check_prototype else Request
271+
)
250272

251-
else:
273+
@self.app.post(
274+
urljoin("/", endpoint_name),
275+
name=endpoint_name,
276+
description=endpoint_doc,
277+
)
278+
async def custom_endpoint(input_data: input_data_type):
252279

253-
@self.app.post(urljoin("/", endpoint_name))
254-
async def custom_endpoint(input_data: Request):
255-
served_data = await input_data.json()
256-
predictions = endpoint_fx(served_data, **kw)
280+
served_data = (
281+
api_data_to_frame(input_data)
282+
if self.check_prototype
283+
else await input_data.json()
284+
)
285+
predictions = endpoint_fx(served_data, **kw)
257286

258-
if isinstance(predictions, List):
259-
return {endpoint_name: predictions}
260-
else:
261-
return predictions
287+
if isinstance(predictions, List):
288+
return {endpoint_name: predictions}
289+
else:
290+
return predictions
262291

263292
def run(self, port: int = 8000, host: str = "127.0.0.1", quiet_open=False, **kw):
264293
"""

0 commit comments

Comments
 (0)