1616from .utils import _jupyter_nb
1717from .vetiver_model import VetiverModel
1818from .meta import VetiverMeta
19- from .helpers import api_data_to_frame
19+ from .helpers import api_data_to_frame , response_to_frame
2020
2121
2222class VetiverAPI :
@@ -178,7 +178,11 @@ def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
178178 async def custom_endpoint (input_data : List [self .model .prototype ]):
179179 _to_frame = api_data_to_frame (input_data )
180180 predictions = endpoint_fx (_to_frame , ** kw )
181- return {endpoint_name : predictions }
181+
182+ if isinstance (predictions , List ):
183+ return {endpoint_name : predictions }
184+ else :
185+ return predictions
182186
183187 else :
184188
@@ -187,7 +191,10 @@ async def custom_endpoint(input_data: Request):
187191 served_data = await input_data .json ()
188192 predictions = endpoint_fx (served_data , ** kw )
189193
190- return {endpoint_name : predictions }
194+ if isinstance (predictions , List ):
195+ return {endpoint_name : predictions }
196+ else :
197+ return predictions
191198
192199 def run (self , port : int = 8000 , host : str = "127.0.0.1" , ** kw ):
193200 """
@@ -262,7 +269,9 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw) -> pd.Da
262269 # TO DO: dispatch
263270
264271 if isinstance (data , pd .DataFrame ):
265- response = requester .post (endpoint , data = data .to_json (orient = "records" ), ** kw )
272+ response = requester .post (
273+ endpoint , content = data .to_json (orient = "records" ), ** kw
274+ )
266275 elif isinstance (data , pd .Series ):
267276 response = requester .post (endpoint , json = [data .to_dict ()], ** kw )
268277 elif isinstance (data , dict ):
@@ -279,9 +288,9 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw) -> pd.Da
279288 f"Could not obtain data from endpoint with error: { e } "
280289 )
281290
282- response_df = pd . DataFrame . from_dict (response . json () )
291+ response_frame = response_to_frame (response )
283292
284- return response_df
293+ return response_frame
285294
286295
287296def vetiver_endpoint (url : str = "http://127.0.0.1:8000/predict" ) -> str :
0 commit comments