Skip to content

Commit a25fb07

Browse files
authored
Merge pull request #87 from isabelizimm/handlers-update
handlers to register themselves
2 parents f37b86b + 927e80d commit a25fb07

10 files changed

Lines changed: 143 additions & 149 deletions

File tree

.github/workflows/tests.yml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,20 @@ jobs:
6464
- name: Run tests
6565
run: |
6666
pytest vetiver -m 'rsc_test'
67+
68+
test-no-torch:
69+
name: "Test no-torch"
70+
runs-on: ubuntu-latest
71+
steps:
72+
- uses: actions/checkout@v2
73+
- uses: actions/setup-python@v2
74+
with:
75+
python-version: 3.8
76+
- name: Install dependencies
77+
run: |
78+
python -m pip install --upgrade pip
79+
python -m pip install -e .[dev]
80+
81+
- name: Run tests
82+
run: |
83+
pytest vetiver/tests/test_sklearn.py

docs/source/advancedusage/custom_handler.md

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,45 @@
11
# Custom Handlers
22

3-
There are two different ways that vetiver supports flexible handling for models that do not work automatically with the vetiver framework. The first way is with [new model types,](#new-model-type) where there is no current implementation for the type of model you would like to deploy. The second way is when you would like to implement a current handler, but [in a different way](#different-model-implementation). In either case, you *must* make a custom handler from the base `VetiverHandler`. A minimal custom handler could look like the following:
3+
There are two different ways that vetiver supports flexible handling for models that do not work automatically with the vetiver framework. The first way is with new model types where there is no current implementation for the type of model you would like to deploy. The second way is when you would like to implement a current handler, but in a different way. In either case, you should create a custom handler from vetiver's `BaseHandler()`. At a minimum, you must give the type of your model via `model_type` how predictions should be made, via the method `handler_predict()`. Then, initialize your handler with your model, and pass the object into `VetiverModel`.
4+
5+
This example shows a custom handler of `newmodeltype` type.
46

57
```python
6-
from vetiver.handlers.base import VetiverHandler
8+
from vetiver.handlers.base import BaseHandler
79

8-
class SampleCustomHandler(VetiverHandler):
10+
class CustomHandler(BaseHandler):
911
def __init__(model, ptype_data):
1012
super().__init__(model, ptype_data)
1113

12-
def handler_predict(self, input_data, check_ptype):
14+
model_type = staticmethod(lambda: newmodeltype)
15+
16+
def handler_predict(self, input_data, check_ptype: bool):
1317
"""
14-
handler_predict defines how to make predictions from your model
18+
Generates method for /predict endpoint in VetiverAPI
19+
20+
The `handler_predict` function executes at each API call. Use this
21+
function for calling `predict()` and any other tasks that must be executed at each API call.
22+
23+
Parameters
24+
----------
25+
input_data:
26+
Test data
27+
check_ptype: bool
28+
Whether the ptype should be enforced
29+
30+
Returns
31+
-------
32+
prediction
33+
Prediction from model
1534
"""
1635
# your code here
17-
```
36+
prediction = model.fancy_new_predict(input_data)
1837

19-
## New model type
20-
If your model type is not supported by vetiver, you should create and then register the handler using [single dispatch](https://docs.python.org/3/library/functools.html#functools.singledispatch). Once the new type is registered, you are able to use `VetiverModel()` as normal. Here is a template for such a function:
21-
22-
```python
23-
from vetiver.handlers._interface import create_handler
38+
return prediction
2439

25-
@create_handler.register
26-
def _(model: {_model_type}, ptype_data):
27-
return SampleCustomHandler(model, ptype_data)
40+
new_model = CustomHandler(model, ptype_data)
2841

29-
VetiverModel(your_model, "your_model")
42+
VetiverModel(new_model, "custom_model")
3043
```
3144

3245
If your model is a common type, please consider [submitting a pull request](https://github.com/rstudio/vetiver-python/pulls).
33-
34-
## Different model implementation
35-
If your model's prediction function is different than vetiver's, you should create a custom handler with a `handler_predict` method to make predictions. Then, initialize your handler with your model, and pass the object into `VetiverModel`.
36-
37-
```python
38-
new_model = SampleCustomHandler(your_model, your_ptype_data)
39-
40-
VetiverModel(new_model, "your_model")
41-
```

vetiver/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
from .meta import * # noqa
1212
from .write_docker import write_docker # noqa
1313
from .write_fastapi import write_app # noqa
14-
from .handlers._interface import create_handler, InvalidModelError # noqa
15-
from .handlers.base import VetiverHandler # noqa
14+
from .handlers.base import BaseHandler, create_handler, InvalidModelError # noqa
1615
from .handlers.sklearn import SKLearnHandler # noqa
1716
from .handlers.torch import TorchHandler # noqa
1817
from .rsconnect import deploy_rsconnect # noqa

vetiver/handlers/_interface.py

Lines changed: 0 additions & 99 deletions
This file was deleted.

vetiver/handlers/base.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,62 @@
1-
from abc import ABCMeta
1+
from vetiver.handlers import base
2+
from functools import singledispatch
3+
from contextlib import suppress
24

35
from ..ptype import vetiver_create_ptype
46
from ..meta import _model_meta
57

68

7-
class VetiverHandler(metaclass=ABCMeta):
9+
class InvalidModelError(Exception):
10+
"""
11+
Throw an error if `model` is not registered.
12+
"""
13+
14+
def __init__(
15+
self,
16+
message="The `model` argument must be a scikit-learn or torch model.",
17+
):
18+
self.message = message
19+
super().__init__(self.message)
20+
21+
22+
@singledispatch
23+
def create_handler(model, ptype_data):
24+
"""check for model type to handle prediction
25+
26+
Parameters
27+
----------
28+
model: object
29+
Description of parameter `x`.
30+
ptype_data : object
31+
An object with information (data) whose layout is to be determined.
32+
33+
Returns
34+
-------
35+
handler
36+
Handler class for specified model type
37+
38+
39+
Examples
40+
--------
41+
>>> import vetiver
42+
>>> X, y = vetiver.mock.get_mock_data()
43+
>>> model = vetiver.mock.get_mock_model()
44+
>>> handler = vetiver.create_handler(model, X)
45+
>>> handler.describe()
46+
"Scikit-learn <class 'sklearn.dummy.DummyRegressor'> model"
47+
"""
48+
49+
raise InvalidModelError(
50+
"Model must be an sklearn or torch model, or a \
51+
custom handler must be used. See the docs for more info on custom handlers. \
52+
https://rstudio.github.io/vetiver-python/advancedusage/custom_handler.html"
53+
)
54+
55+
56+
# BaseHandler uses create_handler to register subclasses based on model_class
57+
58+
59+
class BaseHandler:
860
"""Base handler class for creating VetiverModel of different type.
961
1062
Parameters
@@ -15,6 +67,12 @@ class VetiverHandler(metaclass=ABCMeta):
1567
An object with information (data) whose layout is to be determined.
1668
"""
1769

70+
@classmethod
71+
def __init_subclass__(cls, **kwargs):
72+
super().__init_subclass__(**kwargs)
73+
with suppress(AttributeError, NameError):
74+
create_handler.register(cls.model_class(), cls)
75+
1876
def __init__(self, model, ptype_data):
1977
self.model = model
2078
self.ptype_data = ptype_data
@@ -79,3 +137,15 @@ def handler_predict(self, input_data, check_ptype):
79137
Prediction from model
80138
"""
81139
...
140+
141+
142+
# BaseHandler for subclassing, Handler for new model types
143+
Handler = BaseHandler
144+
145+
146+
@create_handler.register
147+
def _(model: base.BaseHandler, ptype_data):
148+
if model.ptype_data is None and ptype_data is not None:
149+
model.ptype_data = ptype_data
150+
151+
return model

vetiver/handlers/sklearn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
import sklearn
33

44
from ..meta import _model_meta
5-
from .base import VetiverHandler
5+
from .base import BaseHandler
66

77

8-
class SKLearnHandler(VetiverHandler):
8+
class SKLearnHandler(BaseHandler):
99
"""Handler class for creating VetiverModels with sklearn.
1010
1111
Parameters
@@ -14,7 +14,7 @@ class SKLearnHandler(VetiverHandler):
1414
a trained sklearn model
1515
"""
1616

17-
base_class = sklearn.base.BaseEstimator
17+
model_class = staticmethod(lambda: sklearn.base.BaseEstimator)
1818

1919
def __init__(self, model, ptype_data):
2020
super().__init__(model, ptype_data)
@@ -54,7 +54,7 @@ def handler_predict(self, input_data, check_ptype):
5454
Prediction from model
5555
"""
5656

57-
if check_ptype == True:
57+
if check_ptype:
5858
if isinstance(input_data, pd.DataFrame):
5959
prediction = self.model.predict(input_data)
6060
else:

vetiver/handlers/torch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22

33
from ..meta import _model_meta
4-
from .base import VetiverHandler
4+
from .base import BaseHandler
55

66
torch_exists = True
77
try:
@@ -10,7 +10,7 @@
1010
torch_exists = False
1111

1212

13-
class TorchHandler(VetiverHandler):
13+
class TorchHandler(BaseHandler):
1414
"""Handler class for creating VetiverModels with torch.
1515
1616
Parameters
@@ -19,7 +19,7 @@ class TorchHandler(VetiverHandler):
1919
a trained torch model
2020
"""
2121

22-
base_class = torch.nn.Module
22+
model_class = staticmethod(lambda: torch.nn.Module)
2323

2424
def __init__(self, model, ptype_data):
2525
super().__init__(model, ptype_data)
@@ -59,7 +59,7 @@ def handler_predict(self, input_data, check_ptype):
5959
Prediction from model
6060
"""
6161
if torch_exists:
62-
if check_ptype == True:
62+
if check_ptype:
6363
input_data = np.array(input_data, dtype=np.array(self.ptype_data).dtype)
6464
prediction = self.model(torch.from_numpy(input_data))
6565

vetiver/tests/test_custom_handler.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
import pydantic
33
import pandas as pd
44

5-
from vetiver import mock, VetiverModel, VetiverHandler
5+
from vetiver import mock, VetiverModel, BaseHandler
66

77

8-
class CustomHandler(VetiverHandler):
8+
class CustomHandler(BaseHandler):
99
def __init__(self, model, ptype_data):
1010
super().__init__(model, ptype_data)
1111

12+
model_type = staticmethod(lambda: sklearn.dummy.DummyRegressor)
13+
1214
def handler_predict(self, input_data, check_ptype):
1315
if check_ptype is True:
1416
if isinstance(input_data, pd.DataFrame):

0 commit comments

Comments
 (0)