1+ import pytest
2+ import sklearn
3+ import pydantic
4+ import pandas as pd
5+
6+ from vetiver import mock , VetiverModel , VetiverHandler
7+
8+
9+ class CustomHandler (VetiverHandler ):
10+ def __init__ (self , model , ptype_data ):
11+ super ().__init__ (model , ptype_data )
12+
13+ def handler_predict (self , input_data , check_ptype ):
14+ if check_ptype == True :
15+ if isinstance (input_data , pd .DataFrame ):
16+ prediction = self .model .predict (input_data )
17+ else :
18+ prediction = self .model .predict ([input_data ])
19+ else :
20+ if not isinstance (input_data , list ):
21+ input_data = [input_data .split ("," )] # user delimiter ?
22+ prediction = self .model .predict (input_data )
23+
24+ return prediction
25+
26+
27+ def test_custom_vetiver_model ():
28+ X , y = mock .get_mock_data ()
29+ model = mock .get_mock_model ().fit (X , y )
30+ custom_handler = CustomHandler (model , X )
31+
32+ v = VetiverModel (
33+ model = custom_handler ,
34+ ptype_data = X ,
35+ model_name = "my_model" ,
36+ versioned = None ,
37+ description = "A regression model for testing purposes" ,
38+ )
39+
40+ assert v .description == "A regression model for testing purposes"
41+ assert isinstance (v .model , sklearn .dummy .DummyRegressor )
42+ assert isinstance (v .ptype .construct (), pydantic .BaseModel )
43+
44+
45+ def test_custom_vetiver_model ():
46+ X , y = mock .get_mock_data ()
47+ model = mock .get_mock_model ().fit (X , y )
48+ custom_handler = CustomHandler (model , None )
49+
50+ v = VetiverModel (
51+ model = custom_handler ,
52+ ptype_data = X ,
53+ model_name = "my_model" ,
54+ versioned = None ,
55+ description = "A regression model for testing purposes" ,
56+ )
57+
58+ assert v .description == "A regression model for testing purposes"
59+ assert isinstance (v .model , sklearn .dummy .DummyRegressor )
60+ assert isinstance (v .ptype .construct (), pydantic .BaseModel )
0 commit comments