|
5 | 5 | import pandas as pd |
6 | 6 | import requests |
7 | 7 | import uvicorn |
8 | | -from fastapi import FastAPI, Request, testclient |
| 8 | +from fastapi import FastAPI, Request, testclient, status |
| 9 | +from fastapi.encoders import jsonable_encoder |
| 10 | +from fastapi.exceptions import RequestValidationError |
9 | 11 | from fastapi.openapi.utils import get_openapi |
10 | | -from fastapi.responses import HTMLResponse, RedirectResponse |
| 12 | +from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse |
| 13 | +from fastapi.responses import PlainTextResponse |
| 14 | +from starlette.exceptions import HTTPException as StarletteHTTPException |
11 | 15 | from warnings import warn |
12 | 16 |
|
13 | 17 | from .utils import _jupyter_nb |
@@ -138,6 +142,19 @@ async def rapidoc(): |
138 | 142 | </html> |
139 | 143 | """ |
140 | 144 |
|
| 145 | + @app.exception_handler(RequestValidationError) |
| 146 | + async def validation_exception_handler( |
| 147 | + request: Request, exc: RequestValidationError |
| 148 | + ): |
| 149 | + return JSONResponse( |
| 150 | + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, |
| 151 | + content=jsonable_encoder({"detail": exc.errors(), "body": exc.body}), |
| 152 | + ) |
| 153 | + |
| 154 | + @app.exception_handler(StarletteHTTPException) |
| 155 | + async def http_exception_handler(request, exc): |
| 156 | + return PlainTextResponse(str(exc.detail), status_code=exc.status_code) |
| 157 | + |
141 | 158 | return app |
142 | 159 |
|
143 | 160 | def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw): |
@@ -266,7 +283,7 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw) -> pd.Da |
266 | 283 | except (requests.exceptions.HTTPError, httpx.HTTPStatusError) as e: |
267 | 284 | if response.status_code == 422: |
268 | 285 | raise TypeError( |
269 | | - f"Predict expects DataFrame, Series, or dict. Given type is {type(data)}" |
| 286 | + PlainTextResponse(str(response), status_code=response.status_code) |
270 | 287 | ) |
271 | 288 | raise requests.exceptions.HTTPError( |
272 | 289 | f"Could not obtain data from endpoint with error: {e}" |
|
0 commit comments