Skip to content

Commit dbd6495

Browse files
committed
suggestions from review
1 parent d660102 commit dbd6495

7 files changed

Lines changed: 55 additions & 46 deletions

File tree

docs/source/advancedusage/custom_handler.md

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,45 @@
11
# Custom Handlers
22

3-
There are two different ways that vetiver supports flexible handling for models that do not work automatically with the vetiver framework. The first way is with [new model types,](#new-model-type) where there is no current implementation for the type of model you would like to deploy. The second way is when you would like to implement a current handler, but [in a different way](#different-model-implementation). In either case, you *must* make a custom handler from the base `VetiverHandler`. A minimal custom handler could look like the following:
3+
There are two different ways that vetiver supports flexible handling for models that do not work automatically with the vetiver framework. The first way is with new model types where there is no current implementation for the type of model you would like to deploy. The second way is when you would like to implement a current handler, but in a different way. In either case, you should create a custom handler from vetiver's `BaseHandler()`. At a minimum, you must give the type of your model via `model_type` how predictions should be made, via the method `handler_predict()`. Then, initialize your handler with your model, and pass the object into `VetiverModel`.
4+
5+
This example shows a custom handler of `newmodeltype` type.
46

57
```python
6-
from vetiver.handlers.base import VetiverHandler
8+
from vetiver.handlers.base import BaseHandler
79

8-
class SampleCustomHandler(VetiverHandler):
10+
class CustomHandler(BaseHandler):
911
def __init__(model, ptype_data):
1012
super().__init__(model, ptype_data)
1113

12-
def handler_predict(self, input_data, check_ptype):
14+
model_type = staticmethod(lambda: newmodeltype)
15+
16+
def handler_predict(self, input_data, check_ptype: bool):
1317
"""
14-
handler_predict defines how to make predictions from your model
18+
Generates method for /predict endpoint in VetiverAPI
19+
20+
The `handler_predict` function executes at each API call. Use this
21+
function for calling `predict()` and any other tasks that must be executed at each API call.
22+
23+
Parameters
24+
----------
25+
input_data:
26+
Test data
27+
check_ptype: bool
28+
Whether the ptype should be enforced
29+
30+
Returns
31+
-------
32+
prediction
33+
Prediction from model
1534
"""
1635
# your code here
17-
```
36+
prediction = model.fancy_new_predict(input_data)
1837

19-
## New model type
20-
If your model type is not supported by vetiver, you should create and then register the handler using [single dispatch](https://docs.python.org/3/library/functools.html#functools.singledispatch). Once the new type is registered, you are able to use `VetiverModel()` as normal. Here is a template for such a function:
21-
22-
```python
23-
from vetiver.handlers._interface import create_handler
38+
return prediction
2439

25-
@create_handler.register
26-
def _(model: {_model_type}, ptype_data):
27-
return SampleCustomHandler(model, ptype_data)
40+
new_model = CustomHandler(model, ptype_data)
2841

29-
VetiverModel(your_model, "your_model")
42+
VetiverModel(new_model, "integer_model")
3043
```
3144

3245
If your model is a common type, please consider [submitting a pull request](https://github.com/rstudio/vetiver-python/pulls).
33-
34-
## Different model implementation
35-
If your model's prediction function is different than vetiver's, you should create a custom handler with a `handler_predict` method to make predictions. Then, initialize your handler with your model, and pass the object into `VetiverModel`.
36-
37-
```python
38-
new_model = SampleCustomHandler(your_model, your_ptype_data)
39-
40-
VetiverModel(new_model, "your_model")
41-
```

vetiver/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .meta import * # noqa
1515
from .write_docker import write_docker # noqa
1616
from .write_fastapi import write_app # noqa
17-
from .handlers.base import VetiverHandler, create_handler, InvalidModelError # noqa
17+
from .handlers.base import BaseHandler, create_handler, InvalidModelError # noqa
1818
from .handlers.sklearn import SKLearnHandler # noqa
1919
from .handlers.torch import TorchHandler # noqa
2020
from .rsconnect import deploy_rsconnect # noqa

vetiver/handlers/base.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@ def create_handler(model, ptype_data):
5353
)
5454

5555

56-
class VetiverHandler:
56+
# BaseHandler uses create_handler to register subclasses based on model_class
57+
58+
59+
class BaseHandler:
5760
"""Base handler class for creating VetiverModel of different type.
5861
5962
Parameters
@@ -68,11 +71,7 @@ class VetiverHandler:
6871
def __init_subclass__(cls, **kwargs):
6972
super().__init_subclass__(**kwargs)
7073
with suppress(AttributeError, NameError):
71-
create_handler.register(cls.base_class(), cls)
72-
73-
# def __new__(cls, value=None):
74-
# implementation_cls = create_handler.registry[type(value)]
75-
# return super().__new__(implementation_cls)
74+
create_handler.register(cls.model_class(), cls)
7675

7776
def __init__(self, model, ptype_data):
7877
self.model = model
@@ -140,8 +139,12 @@ def handler_predict(self, input_data, check_ptype):
140139
...
141140

142141

142+
# BaseHandler for subclassing, Handler for new model types
143+
Handler = BaseHandler
144+
145+
143146
@create_handler.register
144-
def _(model: base.VetiverHandler, ptype_data):
147+
def _(model: base.BaseHandler, ptype_data):
145148
if model.ptype_data is None and ptype_data is not None:
146149
model.ptype_data = ptype_data
147150

vetiver/handlers/sklearn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
import sklearn
33

44
from ..meta import _model_meta
5-
from .base import VetiverHandler
5+
from .base import BaseHandler
66

77

8-
class SKLearnHandler(VetiverHandler):
8+
class SKLearnHandler(BaseHandler):
99
"""Handler class for creating VetiverModels with sklearn.
1010
1111
Parameters
@@ -14,7 +14,7 @@ class SKLearnHandler(VetiverHandler):
1414
a trained sklearn model
1515
"""
1616

17-
base_class = staticmethod(lambda: sklearn.base.BaseEstimator)
17+
model_class = staticmethod(lambda: sklearn.base.BaseEstimator)
1818

1919
def __init__(self, model, ptype_data):
2020
super().__init__(model, ptype_data)

vetiver/handlers/torch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22

33
from ..meta import _model_meta
4-
from .base import VetiverHandler
4+
from .base import BaseHandler
55

66
torch_exists = True
77
try:
@@ -10,7 +10,7 @@
1010
torch_exists = False
1111

1212

13-
class TorchHandler(VetiverHandler):
13+
class TorchHandler(BaseHandler):
1414
"""Handler class for creating VetiverModels with torch.
1515
1616
Parameters
@@ -19,7 +19,7 @@ class TorchHandler(VetiverHandler):
1919
a trained torch model
2020
"""
2121

22-
base_class = staticmethod(lambda: torch.nn.Module)
22+
model_class = staticmethod(lambda: torch.nn.Module)
2323

2424
def __init__(self, model, ptype_data):
2525
super().__init__(model, ptype_data)

vetiver/tests/test_custom_handler.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
import pydantic
33
import pandas as pd
44

5-
from vetiver import mock, VetiverModel, VetiverHandler
5+
from vetiver import mock, VetiverModel, BaseHandler
66

77

8-
class CustomHandler(VetiverHandler):
8+
class CustomHandler(BaseHandler):
99
def __init__(self, model, ptype_data):
1010
super().__init__(model, ptype_data)
1111

12+
model_type = staticmethod(lambda: sklearn.dummy.DummyRegressor)
13+
1214
def handler_predict(self, input_data, check_ptype):
1315
if check_ptype is True:
1416
if isinstance(input_data, pd.DataFrame):

vetiver/tests/test_pytorch.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from vetiver.vetiver_model import VetiverModel
2-
from vetiver import VetiverAPI
3-
from fastapi.testclient import TestClient
41
import pytest
5-
import numpy as np
62

7-
torch = pytest.importorskip("torch")
3+
torch = pytest.importorskip("torch", reason="torch library not installed")
84

9-
import torch # noqa
5+
import numpy as np # noqa
6+
from fastapi.testclient import TestClient # noqa
7+
8+
from vetiver.vetiver_model import VetiverModel # noqa
9+
from vetiver import VetiverAPI # noqa
1010

1111

1212
def _build_torch_v():

0 commit comments

Comments
 (0)