|
1 | | -from typing import Any, Callable, Dict, List, Union |
| 1 | +from typing import Callable, List, Union |
2 | 2 | from urllib.parse import urljoin |
3 | 3 |
|
| 4 | +import re |
4 | 5 | import httpx |
5 | 6 | import pandas as pd |
6 | 7 | import requests |
7 | 8 | import uvicorn |
8 | 9 | from fastapi import FastAPI, Request, testclient |
| 10 | +from fastapi.exceptions import RequestValidationError |
9 | 11 | from fastapi.openapi.utils import get_openapi |
10 | 12 | from fastapi.responses import HTMLResponse, RedirectResponse |
| 13 | +from fastapi.responses import PlainTextResponse |
11 | 14 | from warnings import warn |
12 | 15 |
|
13 | 16 | from .utils import _jupyter_nb |
14 | 17 | from .vetiver_model import VetiverModel |
15 | 18 | from .meta import VetiverMeta |
| 19 | +from .helpers import api_data_to_frame, response_to_frame |
16 | 20 |
|
17 | 21 |
|
18 | 22 | class VetiverAPI: |
@@ -138,6 +142,10 @@ async def rapidoc(): |
138 | 142 | </html> |
139 | 143 | """ |
140 | 144 |
|
| 145 | + @app.exception_handler(RequestValidationError) |
| 146 | + async def validation_exception_handler(request, exc): |
| 147 | + return PlainTextResponse(str(exc), status_code=422) |
| 148 | + |
141 | 149 | return app |
142 | 150 |
|
143 | 151 | def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw): |
@@ -167,26 +175,26 @@ def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw): |
167 | 175 | if self.check_prototype is True: |
168 | 176 |
|
169 | 177 | @self.app.post(urljoin("/", endpoint_name), name=endpoint_name) |
170 | | - async def custom_endpoint( |
171 | | - input_data: Union[self.model.prototype, List[self.model.prototype]] |
172 | | - ): |
| 178 | + async def custom_endpoint(input_data: List[self.model.prototype]): |
| 179 | + _to_frame = api_data_to_frame(input_data) |
| 180 | + predictions = endpoint_fx(_to_frame, **kw) |
173 | 181 |
|
174 | | - if isinstance(input_data, List): |
175 | | - served_data = _batch_data(input_data) |
| 182 | + if isinstance(predictions, List): |
| 183 | + return {endpoint_name: predictions} |
176 | 184 | else: |
177 | | - served_data = _prepare_data(input_data) |
178 | | - |
179 | | - new = endpoint_fx(served_data, **kw) |
180 | | - return {endpoint_name: new.tolist()} |
| 185 | + return predictions |
181 | 186 |
|
182 | 187 | else: |
183 | 188 |
|
184 | 189 | @self.app.post(urljoin("/", endpoint_name)) |
185 | 190 | async def custom_endpoint(input_data: Request): |
186 | 191 | served_data = await input_data.json() |
187 | | - new = endpoint_fx(served_data, **kw) |
| 192 | + predictions = endpoint_fx(served_data, **kw) |
188 | 193 |
|
189 | | - return {endpoint_name: new.tolist()} |
| 194 | + if isinstance(predictions, List): |
| 195 | + return {endpoint_name: predictions} |
| 196 | + else: |
| 197 | + return predictions |
190 | 198 |
|
191 | 199 | def run(self, port: int = 8000, host: str = "127.0.0.1", **kw): |
192 | 200 | """ |
@@ -261,46 +269,28 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw) -> pd.Da |
261 | 269 | # TO DO: dispatch |
262 | 270 |
|
263 | 271 | if isinstance(data, pd.DataFrame): |
264 | | - data_json = data.to_json(orient="records") |
265 | | - response = requester.post(endpoint, data=data_json, **kw) |
| 272 | + response = requester.post( |
| 273 | + endpoint, data=data.to_json(orient="records"), **kw |
| 274 | + ) # TO DO: httpx deprecating data in favor of content for TestClient |
266 | 275 | elif isinstance(data, pd.Series): |
267 | | - data_dict = data.to_json() |
268 | | - response = requester.post(endpoint, data=data_dict, **kw) |
| 276 | + response = requester.post(endpoint, json=[data.to_dict()], **kw) |
269 | 277 | elif isinstance(data, dict): |
270 | | - response = requester.post(endpoint, json=data, **kw) |
| 278 | + response = requester.post(endpoint, json=[data], **kw) |
271 | 279 | else: |
272 | 280 | response = requester.post(endpoint, json=data, **kw) |
273 | 281 |
|
274 | 282 | try: |
275 | 283 | response.raise_for_status() |
276 | 284 | except (requests.exceptions.HTTPError, httpx.HTTPStatusError) as e: |
277 | 285 | if response.status_code == 422: |
278 | | - raise TypeError( |
279 | | - f"Predict expects DataFrame, Series, or dict. Given type is {type(data)}" |
280 | | - ) |
| 286 | + raise TypeError(re.sub(r"\n", ": ", response.text)) |
281 | 287 | raise requests.exceptions.HTTPError( |
282 | 288 | f"Could not obtain data from endpoint with error: {e}" |
283 | 289 | ) |
284 | 290 |
|
285 | | - response_df = pd.DataFrame.from_dict(response.json()) |
286 | | - |
287 | | - return response_df |
288 | | - |
289 | | - |
290 | | -def _prepare_data(pred_data: Dict[str, Any]) -> List[Any]: |
291 | | - served_data = [] |
292 | | - for key, value in pred_data: |
293 | | - served_data.append(value) |
294 | | - return served_data |
295 | | - |
296 | | - |
297 | | -def _batch_data(pred_data: List[Any]) -> pd.DataFrame: |
298 | | - columns = pred_data[0].dict().keys() |
299 | | - |
300 | | - data = [line.dict() for line in pred_data] |
| 291 | + response_frame = response_to_frame(response) |
301 | 292 |
|
302 | | - served_data = pd.DataFrame(data, columns=columns) |
303 | | - return served_data |
| 293 | + return response_frame |
304 | 294 |
|
305 | 295 |
|
306 | 296 | def vetiver_endpoint(url: str = "http://127.0.0.1:8000/predict") -> str: |
|
0 commit comments