1- from abc import ABCMeta
1+ from vetiver .handlers import base
2+ from functools import singledispatch
3+ from contextlib import suppress
24
35from ..ptype import vetiver_create_ptype
46from ..meta import _model_meta
57
68
7- class VetiverHandler (metaclass = ABCMeta ):
9+ class InvalidModelError (Exception ):
10+ """
11+ Throw an error if `model` is not registered.
12+ """
13+
14+ def __init__ (
15+ self ,
16+ message = "The `model` argument must be a scikit-learn or torch model." ,
17+ ):
18+ self .message = message
19+ super ().__init__ (self .message )
20+
21+
22+ @singledispatch
23+ def create_handler (model , ptype_data ):
24+ """check for model type to handle prediction
25+
26+ Parameters
27+ ----------
28+ model: object
29+ Description of parameter `x`.
30+ ptype_data : object
31+ An object with information (data) whose layout is to be determined.
32+
33+ Returns
34+ -------
35+ handler
36+ Handler class for specified model type
37+
38+
39+ Examples
40+ --------
41+ >>> import vetiver
42+ >>> X, y = vetiver.mock.get_mock_data()
43+ >>> model = vetiver.mock.get_mock_model()
44+ >>> handler = vetiver.create_handler(model, X)
45+ >>> handler.describe()
46+ "Scikit-learn <class 'sklearn.dummy.DummyRegressor'> model"
47+ """
48+
49+ raise InvalidModelError (
50+ "Model must be an sklearn or torch model, or a \
51+ custom handler must be used. See the docs for more info on custom handlers. \
52+ https://rstudio.github.io/vetiver-python/advancedusage/custom_handler.html"
53+ )
54+
55+
56+ class VetiverHandler :
857 """Base handler class for creating VetiverModel of different type.
958
1059 Parameters
@@ -15,6 +64,16 @@ class VetiverHandler(metaclass=ABCMeta):
1564 An object with information (data) whose layout is to be determined.
1665 """
1766
67+ @classmethod
68+ def __init_subclass__ (cls , ** kwargs ):
69+ super ().__init_subclass__ (** kwargs )
70+ with suppress (AttributeError , NameError ):
71+ create_handler .register (cls .base_class , cls )
72+
73+ def __new__ (cls , value = None ):
74+ implementation_cls = create_handler .registry [type (value )]
75+ return super ().__new__ (implementation_cls )
76+
1877 def __init__ (self , model , ptype_data ):
1978 self .model = model
2079 self .ptype_data = ptype_data
@@ -79,3 +138,11 @@ def handler_predict(self, input_data, check_ptype):
79138 Prediction from model
80139 """
81140 ...
141+
142+
143+ @create_handler .register
144+ def _ (model : base .VetiverHandler , ptype_data ):
145+ if model .ptype_data is None and ptype_data is not None :
146+ model .ptype_data = ptype_data
147+
148+ return model
0 commit comments