Skip to content

Commit b2ab951

Browse files
committed
better testing
1 parent e4e1596 commit b2ab951

5 files changed

Lines changed: 74 additions & 6 deletions

File tree

vetiver/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@
1515
from .write_docker import *
1616
from .write_fastapi import *
1717
from .handlers._interface import create_handler, InvalidModelError
18+
from .handlers.base import VetiverHandler
1819
from .handlers.sklearn import SKLearnHandler
1920
from .handlers.torch import TorchHandler

vetiver/handlers/_interface.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from vetiver.handlers import torch, sklearn, base
22
from functools import singledispatch
33

4+
from vetiver.ptype import vetiver_create_ptype
5+
46
class InvalidModelError(Exception):
57
"""
68
Throw an error if `model` is not
@@ -79,10 +81,16 @@ def create_handler(model, ptype_data):
7981
>>> handler.create_description()
8082
Scikit-learn <class 'sklearn.dummy.DummyRegressor'> model
8183
"""
84+
8285
raise InvalidModelError(message=CREATE_PTYPE_TPL.format(_model_type=type(model)))
8386

8487
create_handler.register(sklearn.SKLearnHandler.base_class, sklearn.SKLearnHandler)
8588

8689
create_handler.register(torch.TorchHandler.base_class, torch.TorchHandler)
8790

88-
#create_handler.register(base.VetiverHandler, lambda model: model)
91+
@create_handler.register
92+
def _(model:base.VetiverHandler, ptype_data):
93+
if model.ptype_data is None and ptype_data is not None:
94+
model.ptype_data = ptype_data
95+
96+
return model

vetiver/handlers/base.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
from ..ptype import vetiver_create_ptype
22
from ..meta import vetiver_meta
3+
from abc import ABCMeta
34

4-
class VetiverHandler:
5+
class VetiverHandler(metaclass=ABCMeta):
56
"""Base handler class for creating VetiverModel of different type.
67
78
Parameters
89
----------
910
model :
1011
a trained model
11-
ptype_data:
12+
ptype_data :
1213
An object with information (data) whose layout is to be determined.
1314
"""
1415

1516
def __init__(self, model, ptype_data):
1617
self.model = model
1718
self.ptype_data = ptype_data
1819

19-
def __class__(self):
20-
...
2120

2221
def create_description(self):
2322
"""Create description for model"""
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import pytest
2+
import sklearn
3+
import pydantic
4+
import pandas as pd
5+
6+
from vetiver import mock, VetiverModel, VetiverHandler
7+
8+
9+
class CustomHandler(VetiverHandler):
10+
def __init__(self, model, ptype_data):
11+
super().__init__(model, ptype_data)
12+
13+
def handler_predict(self, input_data, check_ptype):
14+
if check_ptype == True:
15+
if isinstance(input_data, pd.DataFrame):
16+
prediction = self.model.predict(input_data)
17+
else:
18+
prediction = self.model.predict([input_data])
19+
else:
20+
if not isinstance(input_data, list):
21+
input_data = [input_data.split(",")] # user delimiter ?
22+
prediction = self.model.predict(input_data)
23+
24+
return prediction
25+
26+
27+
def test_custom_vetiver_model():
28+
X, y = mock.get_mock_data()
29+
model = mock.get_mock_model().fit(X, y)
30+
custom_handler = CustomHandler(model, X)
31+
32+
v = VetiverModel(
33+
model=custom_handler,
34+
ptype_data=X,
35+
model_name="my_model",
36+
versioned=None,
37+
description="A regression model for testing purposes",
38+
)
39+
40+
assert v.description == "A regression model for testing purposes"
41+
assert isinstance(v.model, sklearn.dummy.DummyRegressor)
42+
assert isinstance(v.ptype.construct(), pydantic.BaseModel)
43+
44+
45+
def test_custom_vetiver_model():
46+
X, y = mock.get_mock_data()
47+
model = mock.get_mock_model().fit(X, y)
48+
custom_handler = CustomHandler(model, None)
49+
50+
v = VetiverModel(
51+
model=custom_handler,
52+
ptype_data=X,
53+
model_name="my_model",
54+
versioned=None,
55+
description="A regression model for testing purposes",
56+
)
57+
58+
assert v.description == "A regression model for testing purposes"
59+
assert isinstance(v.model, sklearn.dummy.DummyRegressor)
60+
assert isinstance(v.ptype.construct(), pydantic.BaseModel)

vetiver/vetiver_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(
4646
):
4747
translator = create_handler(model, ptype_data)
4848

49-
self.model = model
49+
self.model = translator.model
5050
self.ptype = translator.ptype()
5151
self.model_name = model_name
5252
self.description = description if description else translator.create_description()

0 commit comments

Comments
 (0)