Skip to content

Commit 8c8e5ed

Browse files
authored
Merge pull request #57 from isabelizimm/dev-sd-handler
handlers to single dispatch
2 parents ce9b5b8 + 700fc64 commit 8c8e5ed

12 files changed

Lines changed: 328 additions & 100 deletions

File tree

.github/workflows/docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ jobs:
2020
- name: Install dependencies
2121
run: |
2222
python -m pip install --upgrade pip
23-
python -m pip install -e .[dev]
23+
python -m pip install -e .[dev,torch]
2424
- name: build docs
2525
run: |
2626
cd docs
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Custom Handlers
2+
3+
There are two different ways that vetiver supports flexible handling for models that do not work automatically with the vetiver framework. The first way is with [new model types,](#new-model-type) where there is no current implementation for the type of model you would like to deploy. The second way is when you would like to implement a current handler, but [in a different way](#different-model-implementation). In either case, you *must* make a custom handler from the base `VetiverHandler`. A minimal custom handler could look like the following:
4+
5+
```python
6+
from vetiver.handlers.base import VetiverHandler
7+
8+
class SampleCustomHandler(VetiverHandler):
9+
def __init__(model, ptype_data):
10+
super().__init__(model, ptype_data)
11+
12+
def handler_predict(self, input_data, check_ptype):
13+
"""
14+
handler_predict defines how to make predictions from your model
15+
"""
16+
# your code here
17+
```
18+
19+
## New model type
20+
If your model type is not supported by vetiver, you should create and then register the handler using [single dispatch](https://docs.python.org/3/library/functools.html#functools.singledispatch). Once the new type is registered, you are able to use `VetiverModel()` as normal. Here is a template for such a function:
21+
22+
```python
23+
from vetiver.handlers._interface import create_handler
24+
25+
@create_handler.register
26+
def _(model: {_model_type}, ptype_data):
27+
return SampleCustomHandler(model, ptype_data)
28+
29+
VetiverModel(your_model, "your_model")
30+
```
31+
32+
If your model is a common type, please consider [submitting a pull request](https://github.com/rstudio/vetiver-python/pulls).
33+
34+
## Different model implementation
35+
If your model's prediction function is different than vetiver's, you should create a custom handler with a `handler_predict` method to make predictions. Then, initialize your handler with your model, and pass the object into `VetiverModel`.
36+
37+
```python
38+
new_model = SampleCustomHandler(your_model, your_ptype_data)
39+
40+
VetiverModel(new_model, "your_model")
41+
```

docs/source/conf.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
# add these directories to sys.path here. If the directory is relative to the
1111
# documentation root, use os.path.abspath to make it absolute, like shown here.
1212
#
13-
# import os
14-
# import sys
15-
# sys.path.insert(0, os.path.abspath('.'))
13+
import os
14+
import sys
15+
sys.path.insert(0, os.path.abspath('.'))
16+
from vetiver import __version__
1617

1718

1819
# -- Project information -----------------------------------------------------
@@ -22,7 +23,7 @@
2223
author = "Isabel Zimmerman"
2324

2425
# The full version, including alpha/beta/rc tags
25-
release = "0.1.3"
26+
release = __version__
2627

2728

2829
# -- General configuration ---------------------------------------------------
@@ -64,6 +65,13 @@
6465

6566
}
6667

68+
source_suffix = {
69+
'.rst': 'restructuredtext',
70+
'.txt': 'markdown',
71+
'.md': 'markdown',
72+
}
73+
myst_heading_anchors = 2
74+
6775
html_logo = "../figures/logo.png"
6876
html_favicon = "../figures/logo.png"
6977

docs/source/index.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@ Version
4242
Deploy
4343
==================
4444

45-
.. currentmodule:: vetiver
46-
4745
.. autosummary::
4846
:toctree: reference/
4947
:caption: Deploy
@@ -56,3 +54,10 @@ Deploy
5654
~load_pkgs
5755
~vetiver_write_app
5856
~vetiver_write_docker
57+
58+
Advanced Usage
59+
==================
60+
.. toctree::
61+
advancedusage/custom_handler.md
62+
:caption: Advanced Usage
63+

vetiver/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .meta import *
1515
from .write_docker import *
1616
from .write_fastapi import *
17-
from .handlers._interface import *
18-
from .handlers.sklearn_vt import *
19-
from .handlers.pytorch_vt import *
17+
from .handlers._interface import create_handler, InvalidModelError
18+
from .handlers.base import VetiverHandler
19+
from .handlers.sklearn import SKLearnHandler
20+
from .handlers.torch import TorchHandler

vetiver/handlers/_interface.py

Lines changed: 82 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,96 @@
1-
from vetiver.handlers import pytorch_vt, sklearn_vt
2-
import sklearn
1+
from vetiver.handlers import torch, sklearn, base
2+
from functools import singledispatch
33

4-
torch_exists = True
5-
try:
6-
import torch
7-
except ImportError:
8-
torch_exists = False
4+
from vetiver.ptype import vetiver_create_ptype
95

10-
def create_translator(model, ptype_data):
6+
class InvalidModelError(Exception):
7+
"""
8+
Throw an error if `model` is not
9+
from scikit-learn or torch
10+
"""
11+
12+
def __init__(
13+
self,
14+
message="The `model` argument must be a scikit-learn or torch model.",
15+
):
16+
self.message = message
17+
super().__init__(self.message)
18+
19+
CREATE_PTYPE_TPL = """\
20+
Failed to create a handler from model of \
21+
type {_model_type}. If your model is not one of \
22+
(scikit-learn, torch), you should create and register \
23+
the handler. Here is a template for such a function: \
24+
from vetiver.handlers._interface import create_handler
25+
from vetiver.handlers.base import VetiverHandler
26+
27+
class CustomTemplateHandler(VetiverHandler):
28+
def __init__(model, ptype_data):
29+
super().__init__(model, ptype_data)
30+
31+
def vetiver_create_meta(
32+
user: list = None,
33+
version: str = None,
34+
url: str = None,
35+
required_pkgs: list = []):
36+
\"""
37+
Create metadata for model. This method should include the required
38+
packages necessary to create a prediction.
39+
\"""
40+
required_pkgs = required_pkgs + ["name_of_modeling_package"]
41+
meta = vetiver_meta(user, version, url, required_pkgs)
42+
43+
return meta
44+
45+
def handler_predict(self, input_data, check_ptype):
46+
\"""
47+
handler_predict should define how to make predictions from your model
48+
\"""
49+
...
50+
51+
@vetiver_create_ptype.register
52+
def _(model: {_model_type}, ptype_data):
53+
return CustomTemplateHandler(model, ptype_data)
54+
55+
If your datatype is a common type, please consider submitting \
56+
a pull request.
57+
"""
58+
59+
@singledispatch
60+
def create_handler(model, ptype_data):
1161
"""check for model type to handle prediction
1262
1363
Parameters
1464
----------
15-
model
65+
model: object
1666
Description of parameter `x`.
67+
ptype_data : object
68+
An object with information (data) whose layout is to be determined.
1769
1870
Returns
1971
-------
20-
pytorch_vt.TorchHandler or sklearn_vt.SKLearnHandler
72+
handler
2173
Handler class for specified model type
74+
75+
Examples
76+
--------
77+
>>> import vetiver
78+
>>> X, y = vetiver.mock.get_mock_data()
79+
>>> model = vetiver.mock.get_mock_model()
80+
>>> handler = vetiver.create_handler(model, X)
81+
>>> handler.describe()
82+
"Scikit-learn <class 'sklearn.dummy.DummyRegressor'> model"
2283
"""
23-
if torch_exists:
24-
if isinstance(model, torch.nn.Module):
25-
return pytorch_vt.TorchHandler(model, ptype_data)
2684

27-
if isinstance(model, sklearn.base.BaseEstimator):
28-
return sklearn_vt.SKLearnHandler(model, ptype_data)
85+
raise InvalidModelError(message=CREATE_PTYPE_TPL.format(_model_type=type(model)))
86+
87+
create_handler.register(sklearn.SKLearnHandler.base_class, sklearn.SKLearnHandler)
88+
89+
create_handler.register(torch.TorchHandler.base_class, torch.TorchHandler)
2990

30-
else:
31-
raise NotImplementedError
91+
@create_handler.register
92+
def _(model:base.VetiverHandler, ptype_data):
93+
if model.ptype_data is None and ptype_data is not None:
94+
model.ptype_data = ptype_data
95+
96+
return model

vetiver/handlers/base.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from ..ptype import vetiver_create_ptype
2+
from ..meta import vetiver_meta
3+
from abc import ABCMeta
4+
5+
class VetiverHandler(metaclass=ABCMeta):
6+
"""Base handler class for creating VetiverModel of different type.
7+
8+
Parameters
9+
----------
10+
model :
11+
a trained model
12+
ptype_data :
13+
An object with information (data) whose layout is to be determined.
14+
"""
15+
16+
def __init__(self, model, ptype_data):
17+
self.model = model
18+
self.ptype_data = ptype_data
19+
20+
21+
def describe(self):
22+
"""Create description for model"""
23+
desc = f"{self.model.__class__} model"
24+
return desc
25+
26+
def create_meta(
27+
user: list = None,
28+
version: str = None,
29+
url: str = None,
30+
required_pkgs: list = [],
31+
):
32+
"""Create metadata for sklearn model"""
33+
meta = vetiver_meta(user, version, url, required_pkgs)
34+
35+
return meta
36+
37+
def construct_ptype(self):
38+
"""Create data prototype for torch model
39+
40+
Parameters
41+
----------
42+
ptype_data : pd.DataFrame, np.ndarray, or None
43+
Training data to create ptype
44+
45+
Returns
46+
-------
47+
ptype : pd.DataFrame or None
48+
Zero-row DataFrame for storing data types
49+
"""
50+
ptype = vetiver_create_ptype(self.ptype_data)
51+
return ptype
52+
53+
def handler_startup():
54+
"""Include required packages for prediction
55+
56+
The `handler_startup` function executes when the API starts. Use this
57+
function for tasks like loading packages.
58+
"""
59+
...
60+
61+
62+
def handler_predict(self, input_data, check_ptype):
63+
"""Generates method for /predict endpoint in VetiverAPI
64+
65+
The `handler_predict` function executes at each API call. Use this
66+
function for calling `predict()` and any other tasks that must be executed
67+
at each API call.
68+
69+
Parameters
70+
----------
71+
input_data:
72+
Data used to generate prediction
73+
check_ptype:
74+
If type should be checked against `ptype` or not
75+
76+
Returns
77+
-------
78+
prediction
79+
Prediction from model
80+
"""
81+
...
Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,29 @@
1-
from ..ptype import vetiver_create_ptype
21
import pandas as pd
3-
from ..meta import vetiver_meta
42
import sklearn
53

6-
class SKLearnHandler:
4+
from .base import VetiverHandler
5+
from ..meta import vetiver_meta
6+
7+
class SKLearnHandler(VetiverHandler):
78
"""Handler class for creating VetiverModels with sklearn.
89
910
Parameters
1011
----------
1112
model : sklearn.base.BaseEstimator
1213
a trained sklearn model
1314
"""
14-
15+
base_class = sklearn.base.BaseEstimator
16+
1517
def __init__(self, model, ptype_data):
16-
self.model = model
17-
self.ptype_data = ptype_data
18+
super().__init__(model, ptype_data)
1819

19-
def create_description(self):
20+
def describe(self):
2021
"""Create description for sklearn model
2122
"""
2223
desc = f"Scikit-learn {self.model.__class__} model"
2324
return desc
2425

25-
def vetiver_create_meta(
26+
def construct_meta(
2627
user: list = None,
2728
version: str = None,
2829
url: str = None,
@@ -34,30 +35,6 @@ def vetiver_create_meta(
3435

3536
return meta
3637

37-
def ptype(self):
38-
"""Create data prototype for torch model
39-
40-
Parameters
41-
----------
42-
ptype_data : pd.DataFrame, np.ndarray, or None
43-
Training data to create ptype
44-
45-
Returns
46-
-------
47-
ptype : pd.DataFrame or None
48-
Zero-row DataFrame for storing data types
49-
"""
50-
ptype = vetiver_create_ptype(self.ptype_data)
51-
return ptype
52-
53-
def handler_startup():
54-
"""Include required packages for prediction
55-
56-
The `handler_startup` function executes when the API starts. Use this
57-
function for tasks like loading packages.
58-
"""
59-
...
60-
6138

6239
def handler_predict(self, input_data, check_ptype):
6340
"""Generates method for /predict endpoint in VetiverAPI

0 commit comments

Comments
 (0)