Skip to content

Commit f233463

Browse files
committed
handlers to single dispatch
1 parent c050d49 commit f233463

7 files changed

Lines changed: 174 additions & 24 deletions

File tree

vetiver/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,5 @@
1515
from .write_docker import *
1616
from .write_fastapi import *
1717
from .handlers._interface import *
18-
from .handlers.sklearn_vt import *
19-
from .handlers.pytorch_vt import *
18+
from .handlers.scikitlearn import *
19+
from .handlers.pytorch import *

vetiver/handlers/_interface.py

Lines changed: 81 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,100 @@
1-
from vetiver.handlers import pytorch_vt, sklearn_vt
2-
import sklearn
1+
from typing import Any
2+
from vetiver.handlers import pytorch, scikitlearn
3+
from functools import singledispatch
4+
import sklearn
35

46
torch_exists = True
57
try:
68
import torch
79
except ImportError:
810
torch_exists = False
911

10-
def create_translator(model, ptype_data):
12+
class InvalidModelError(Exception):
13+
"""
14+
Throw an error if `model` is not
15+
from scikit-learn or torch
16+
"""
17+
18+
def __init__(
19+
self,
20+
message="The `model` argument must be a scikit-learn or torch model.",
21+
):
22+
self.message = message
23+
super().__init__(self.message)
24+
25+
CREATE_PTYPE_TPL = """\
26+
Failed to create a handler from model of \
27+
type {_model_type}. If your model is not one of \
28+
(scikit-learn, torch), you should create and register \
29+
the handler. Here is a template for such a function: \
30+
from pydantic import create_model
31+
from vetiver.handlers._interface import create_handler
32+
from vetiver.handlers.base import VetiverHandler
33+
34+
class CustomTemplateHandler(VetiverHandler):
35+
def __init__(model, ptype_data):
36+
super().__init__(model, ptype_data)
37+
38+
def vetiver_create_meta(
39+
user: list = None,
40+
version: str = None,
41+
url: str = None,
42+
required_pkgs: list = []):
43+
\"""
44+
Create metadata for model. This method should include the required
45+
packages necessary to create a prediction.
46+
\"""
47+
required_pkgs = required_pkgs + ["name_of_modeling_package"]
48+
meta = vetiver_meta(user, version, url, required_pkgs)
49+
50+
return meta
51+
52+
def handler_predict(self, input_data, check_ptype):
53+
\"""
54+
handler_predict should define how to make predictions from your model
55+
\"""
56+
...
57+
58+
@vetiver_create_ptype.register
59+
def _(model: {_model_type}, ptype_data):
60+
return CustomTemplateHandler(model, ptype_data)
61+
62+
If your datatype is a common type, please consider submitting \
63+
a pull request.
64+
"""
65+
66+
@singledispatch
67+
def create_handler(model, ptype_data):
1168
"""check for model type to handle prediction
1269
1370
Parameters
1471
----------
15-
model
72+
model: object
1673
Description of parameter `x`.
74+
ptype_data : object
75+
An object with information (data) whose layout is to be determined.
1776
1877
Returns
1978
-------
20-
pytorch_vt.TorchHandler or sklearn_vt.SKLearnHandler
79+
handler
2180
Handler class for specified model type
81+
82+
Examples
83+
--------
84+
>>> import vetiver
85+
>>> X, y = vetiver.mock.get_mock_data()
86+
>>> model = vetiver.mock.get_mock_model()
87+
>>> handler = vetiver.create_handler(model, X)
88+
>>> handler.create_description()
89+
Scikit-learn <class 'sklearn.dummy.DummyRegressor'> model
2290
"""
23-
if torch_exists:
24-
if isinstance(model, torch.nn.Module):
25-
return pytorch_vt.TorchHandler(model, ptype_data)
91+
raise InvalidModelError(message=CREATE_PTYPE_TPL.format(_model_type=type(model)))
2692

27-
if isinstance(model, sklearn.base.BaseEstimator):
28-
return sklearn_vt.SKLearnHandler(model, ptype_data)
93+
@create_handler.register
94+
def _(model: sklearn.base.BaseEstimator, ptype_data: Any):
95+
return scikitlearn.SKLearnHandler(model, ptype_data)
2996

30-
else:
31-
raise NotImplementedError
97+
if torch_exists:
98+
@create_handler.register
99+
def _(model: torch.nn.Module, ptype_data: Any):
100+
return pytorch.TorchHandler(model, ptype_data)

vetiver/handlers/base.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from ..ptype import vetiver_create_ptype
2+
from ..meta import vetiver_meta
3+
4+
class VetiverHandler:
5+
"""Base handler class for creating VetiverModel of different type.
6+
7+
Parameters
8+
----------
9+
model :
10+
a trained model
11+
ptype_data:
12+
An object with information (data) whose layout is to be determined.
13+
"""
14+
15+
def __init__(self, model, ptype_data):
16+
self.model = model
17+
self.ptype_data = ptype_data
18+
19+
def create_description(self):
20+
"""Create description for model"""
21+
desc = f"{self.model.__class__} model"
22+
return desc
23+
24+
def vetiver_create_meta(
25+
user: list = None,
26+
version: str = None,
27+
url: str = None,
28+
required_pkgs: list = [],
29+
):
30+
"""Create metadata for sklearn model"""
31+
meta = vetiver_meta(user, version, url, required_pkgs)
32+
33+
return meta
34+
35+
def ptype(self):
36+
"""Create data prototype for torch model
37+
38+
Parameters
39+
----------
40+
ptype_data : pd.DataFrame, np.ndarray, or None
41+
Training data to create ptype
42+
43+
Returns
44+
-------
45+
ptype : pd.DataFrame or None
46+
Zero-row DataFrame for storing data types
47+
"""
48+
ptype = vetiver_create_ptype(self.ptype_data)
49+
return ptype
50+
51+
def handler_startup():
52+
"""Include required packages for prediction
53+
54+
The `handler_startup` function executes when the API starts. Use this
55+
function for tasks like loading packages.
56+
"""
57+
...
58+
59+
60+
def handler_predict(self, input_data, check_ptype):
61+
"""Generates method for /predict endpoint in VetiverAPI
62+
63+
The `handler_predict` function executes at each API call. Use this
64+
function for calling `predict()` and any other tasks that must be executed
65+
at each API call.
66+
67+
Parameters
68+
----------
69+
input_data:
70+
Data used to generate prediction
71+
check_ptype:
72+
If type should be checked against `ptype` or not
73+
74+
Returns
75+
-------
76+
prediction
77+
Prediction from model
78+
"""
79+
...
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from vetiver.handlers.base import VetiverHandler
12
from ..meta import vetiver_meta
23
from ..ptype import vetiver_create_ptype
34
import numpy as np
@@ -9,7 +10,7 @@
910
torch_exists = False
1011

1112

12-
class TorchHandler:
13+
class TorchHandler(VetiverHandler):
1314
"""Handler class for creating VetiverModels with torch.
1415
1516
Parameters
@@ -18,8 +19,7 @@ class TorchHandler:
1819
a trained torch model
1920
"""
2021
def __init__(self, model, ptype_data):
21-
self.model = model
22-
self.ptype_data = ptype_data
22+
super().__init__(model, ptype_data)
2323

2424
def create_description(self):
2525
"""Create description for torch model
Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from vetiver.handlers.base import VetiverHandler
12
from ..ptype import vetiver_create_ptype
2-
import pandas as pd
33
from ..meta import vetiver_meta
4+
5+
import pandas as pd
46
import sklearn
57

6-
class SKLearnHandler:
8+
class SKLearnHandler(VetiverHandler):
79
"""Handler class for creating VetiverModels with sklearn.
810
911
Parameters
@@ -13,8 +15,7 @@ class SKLearnHandler:
1315
"""
1416

1517
def __init__(self, model, ptype_data):
16-
self.model = model
17-
self.ptype_data = ptype_data
18+
super().__init__(model, ptype_data)
1819

1920
def create_description(self):
2021
"""Create description for sklearn model

vetiver/tests/test_no_handler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import pytest
22
from vetiver import VetiverModel, mock
3+
from vetiver import InvalidModelError
34

45

56
def test_not_implemented_error():
67
X, y = mock.get_mock_data()
78

8-
with pytest.raises(NotImplementedError):
9+
with pytest.raises(InvalidModelError):
910
VetiverModel(
1011
model=y,
1112
ptype_data=X,

vetiver/vetiver_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from vetiver.handlers._interface import create_translator
1+
from vetiver.handlers._interface import create_handler
22

33

44
class NoModelAvailableError(Exception):
@@ -44,7 +44,7 @@ def __init__(
4444
metadata: dict = None,
4545
**kwargs
4646
):
47-
translator = create_translator(model, ptype_data)
47+
translator = create_handler(model, ptype_data)
4848

4949
self.model = model
5050
self.ptype = translator.ptype()

0 commit comments

Comments
 (0)