Skip to content

Commit 5f28f61

Browse files
committed
add response_to_frame function
1 parent 2e99db2 commit 5f28f61

2 files changed

Lines changed: 20 additions & 22 deletions

File tree

vetiver/helpers.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,8 @@ def _dict(pred_data):
3333
return api_data_to_frame([pred_data])
3434

3535

36-
# possible other names
37-
# prototype_to_frame
38-
# prototype_to_data
39-
# prototype_to_dataframe
40-
# prototype_to_datatype
41-
# prototype_to_type
42-
# api_data_to_
43-
# json_to_
44-
# server_data_to_
45-
# transport_data_to_
46-
# request
47-
# query
48-
# transit
49-
# interchange/exchange
50-
# transfer
51-
# through
36+
def response_to_frame(response: dict) -> pd.DataFrame:
37+
38+
response_df = pd.DataFrame.from_dict(response.json())
39+
40+
return response_df

vetiver/server.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from .utils import _jupyter_nb
1717
from .vetiver_model import VetiverModel
1818
from .meta import VetiverMeta
19-
from .helpers import api_data_to_frame
19+
from .helpers import api_data_to_frame, response_to_frame
2020

2121

2222
class VetiverAPI:
@@ -178,7 +178,11 @@ def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
178178
async def custom_endpoint(input_data: List[self.model.prototype]):
179179
_to_frame = api_data_to_frame(input_data)
180180
predictions = endpoint_fx(_to_frame, **kw)
181-
return {endpoint_name: predictions}
181+
182+
if isinstance(predictions, List):
183+
return {endpoint_name: predictions}
184+
else:
185+
return predictions
182186

183187
else:
184188

@@ -187,7 +191,10 @@ async def custom_endpoint(input_data: Request):
187191
served_data = await input_data.json()
188192
predictions = endpoint_fx(served_data, **kw)
189193

190-
return {endpoint_name: predictions}
194+
if isinstance(predictions, List):
195+
return {endpoint_name: predictions}
196+
else:
197+
return predictions
191198

192199
def run(self, port: int = 8000, host: str = "127.0.0.1", **kw):
193200
"""
@@ -262,7 +269,9 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw) -> pd.Da
262269
# TO DO: dispatch
263270

264271
if isinstance(data, pd.DataFrame):
265-
response = requester.post(endpoint, data=data.to_json(orient="records"), **kw)
272+
response = requester.post(
273+
endpoint, content=data.to_json(orient="records"), **kw
274+
)
266275
elif isinstance(data, pd.Series):
267276
response = requester.post(endpoint, json=[data.to_dict()], **kw)
268277
elif isinstance(data, dict):
@@ -279,9 +288,9 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw) -> pd.Da
279288
f"Could not obtain data from endpoint with error: {e}"
280289
)
281290

282-
response_df = pd.DataFrame.from_dict(response.json())
291+
response_frame = response_to_frame(response)
283292

284-
return response_df
293+
return response_frame
285294

286295

287296
def vetiver_endpoint(url: str = "http://127.0.0.1:8000/predict") -> str:

0 commit comments

Comments
 (0)