@@ -169,7 +169,7 @@ def _custom_openapi(self):
169169 self .app .openapi_schema = openapi_schema
170170 return self .app .openapi_schema
171171
172- def predict (endpoint , data : dict , ** kw ):
172+ def predict (endpoint , data : Union [ dict , pd . DataFrame , pd . Series ] , ** kw ):
173173 """Make a prediction from model endpoint
174174
175175 Parameters
@@ -184,13 +184,17 @@ def predict(endpoint, data: dict, **kw):
184184 dict
185185 Key: endpoint_name Value: Output of endpoint_fx, in list format
186186 """
187- if isinstance (data , pd .DataFrame ):
188- data = data .to_json (orient = "records" )
189- response = requests .post (endpoint , data = data , ** kw )
190- else :
187+ if isinstance (data , ( pd .DataFrame , pd . Series ) ):
188+ data_json = data .to_json (orient = "records" )
189+ response = requests .post (endpoint , data = data_json , ** kw )
190+ elif isinstance ( data , dict ) :
191191 response = requests .post (endpoint , json = data , ** kw )
192+ else :
193+ raise TypeError (f"Accepted data types are dictionary or DataFrame, given type is { type (data )} \n { data } " )
194+
195+ response_df = pd .DataFrame .from_dict (response .json ())
192196
193- return response . json ()
197+ return response_df
194198
195199
196200def _prepare_data (pred_data ):
0 commit comments