1111
1212
1313class StatsmodelsHandler (BaseHandler ):
14- """Handler class for creating VetiverModels with sklearn .
14+ """Handler class for creating VetiverModels with statsmodels .
1515
1616 Parameters
1717 ----------
1818 model : statsmodels
19- a trained sklearn model
19+ a trained and fit statsmodels model
2020 """
2121
2222 model_class = staticmethod (lambda : statsmodels .base .wrapper .ResultsWrapper )
@@ -25,7 +25,7 @@ def __init__(self, model, ptype_data):
2525 super ().__init__ (model , ptype_data )
2626
2727 def describe (self ):
28- """Create description for sklearn model"""
28+ """Create description for statsmodels model"""
2929 desc = f"Statsmodels { self .model .__class__ } model."
3030 return desc
3131
@@ -35,7 +35,7 @@ def construct_meta(
3535 url : str = None ,
3636 required_pkgs : list = [],
3737 ):
38- """Create metadata for sklearn model """
38+ """Create metadata for statsmodel """
3939 required_pkgs = required_pkgs + ["statsmodels" ]
4040 meta = _model_meta (user , version , url , required_pkgs )
4141
@@ -59,17 +59,9 @@ def handler_predict(self, input_data, check_ptype):
5959 Prediction from model
6060 """
6161
62- if check_ptype :
63- if isinstance (input_data , pd .DataFrame ):
64- prediction = self .model .predict (input_data )
65- else :
66- prediction = self .model .predict ([input_data ])
67-
68- # do not check ptype
69- else :
70- if not isinstance (input_data , list ):
71- input_data = [input_data .split ("," )] # user delimiter ?
72-
62+ if isinstance (input_data , pd .DataFrame ):
7363 prediction = self .model .predict (input_data )
64+ else :
65+ prediction = self .model .predict ([input_data ])
7466
7567 return prediction
0 commit comments