|
15 | 15 | from fastapi.exceptions import RequestValidationError |
16 | 16 | from fastapi.openapi.utils import get_openapi |
17 | 17 | from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse |
18 | | - |
19 | 18 | from .helpers import api_data_to_frame, response_to_frame |
| 19 | +from .handlers.sklearn import SKLearnHandler |
20 | 20 | from .meta import VetiverMeta |
21 | 21 | from .utils import _jupyter_nb, get_workbench_path |
22 | 22 | from .vetiver_model import VetiverModel |
| 23 | +from .types import SklearnPredictionTypes |
23 | 24 |
|
24 | 25 |
|
25 | 26 | class VetiverAPI: |
@@ -111,7 +112,6 @@ async def startup_event(): |
111 | 112 |
|
112 | 113 | @app.get("/", include_in_schema=False) |
113 | 114 | def docs_redirect(): |
114 | | - |
115 | 115 | redirect = "__docs__" |
116 | 116 |
|
117 | 117 | return RedirectResponse(redirect) |
@@ -200,65 +200,94 @@ async def validation_exception_handler(request, exc): |
200 | 200 |
|
201 | 201 | return app |
202 | 202 |
|
203 | | - def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw): |
204 | | - """Create new POST endpoint that is aware of model input data |
| 203 | + def vetiver_post( |
| 204 | + self, |
| 205 | + endpoint_fx: Union[Callable, SklearnPredictionTypes], |
| 206 | + endpoint_name: str = None, |
| 207 | + **kw, |
| 208 | + ): |
| 209 | + """Define a new POST endpoint that utilizes the model's input data. |
205 | 210 |
|
206 | 211 | Parameters |
207 | 212 | ---------- |
208 | | - endpoint_fx : typing.Callable |
209 | | - Custom function to be run at endpoint |
| 213 | + endpoint_fx |
| 214 | + : Union[typing.Callable, Literal["predict", "predict_proba", "predict_log_proba"]] |
| 215 | + A callable function that specifies the custom logic to execute when the |
| 216 | + endpoint is called. This function should take input data (e.g., a DataFrame |
| 217 | + or dictionary) and return the desired output(e.g., predictions or transformed |
| 218 | + data). For scikit-learn models, endpoint_fx can also be one of "predict", |
| 219 | + "predict_proba", or "predict_log_proba" if the model supports these methods. |
| 220 | +
|
210 | 221 | endpoint_name : str |
211 | | - Name of endpoint |
| 222 | + The name of the endpoint to be created. |
212 | 223 |
|
213 | 224 | Examples |
214 | 225 | ------- |
215 | | - ```{python} |
| 226 | + ```python |
216 | 227 | from vetiver import mock, VetiverModel, VetiverAPI |
217 | 228 | X, y = mock.get_mock_data() |
218 | 229 | model = mock.get_mock_model().fit(X, y) |
219 | 230 |
|
220 | | - v = VetiverModel(model = model, model_name = "model", prototype_data = X) |
221 | | - v_api = VetiverAPI(model = v, check_prototype = True) |
| 231 | + v = VetiverModel(model=model, model_name="model", prototype_data=X) |
| 232 | + v_api = VetiverAPI(model=v, check_prototype=True) |
222 | 233 |
|
223 | 234 | def sum_values(x): |
224 | 235 | return x.sum() |
| 236 | +
|
225 | 237 | v_api.vetiver_post(sum_values, "sums") |
226 | 238 | ``` |
227 | 239 | """ |
228 | | - if not endpoint_name: |
229 | | - endpoint_name = endpoint_fx.__name__ |
230 | 240 |
|
231 | | - if endpoint_fx.__doc__ is not None: |
232 | | - api_desc = dedent(endpoint_fx.__doc__) |
233 | | - else: |
234 | | - api_desc = None |
| 241 | + if not isinstance(endpoint_fx, Callable): |
| 242 | + if endpoint_fx not in ["predict", "predict_proba", "predict_log_proba"]: |
| 243 | + raise ValueError( |
| 244 | + f""" |
| 245 | + Prediction type {endpoint_fx} not available. |
| 246 | + Available prediction types: {SklearnPredictionTypes} |
| 247 | + """ |
| 248 | + ) |
| 249 | + if not isinstance(self.model.handler_predict.__self__, SKLearnHandler): |
| 250 | + raise ValueError( |
| 251 | + """ |
| 252 | + The 'endpoint_fx' parameter can only be a |
| 253 | + string when using scikit-learn models. |
| 254 | + """ |
| 255 | + ) |
| 256 | + self.vetiver_post( |
| 257 | + self.model.handler_predict, |
| 258 | + endpoint_fx, |
| 259 | + check_prototype=self.check_prototype, |
| 260 | + prediction_type=endpoint_fx, |
| 261 | + ) |
| 262 | + return |
235 | 263 |
|
236 | | - if self.check_prototype is True: |
| 264 | + endpoint_name = endpoint_name or endpoint_fx.__name__ |
| 265 | + endpoint_doc = dedent(endpoint_fx.__doc__) if endpoint_fx.__doc__ else None |
237 | 266 |
|
238 | | - @self.app.post( |
239 | | - urljoin("/", endpoint_name), |
240 | | - name=endpoint_name, |
241 | | - description=api_desc, |
242 | | - ) |
243 | | - async def custom_endpoint(input_data: List[self.model.prototype]): |
244 | | - _to_frame = api_data_to_frame(input_data) |
245 | | - predictions = endpoint_fx(_to_frame, **kw) |
246 | | - if isinstance(predictions, List): |
247 | | - return {endpoint_name: predictions} |
248 | | - else: |
249 | | - return predictions |
| 267 | + # this must be split up this way to preserve the correct type hints for |
| 268 | + # the input_data schema validation via Pydantic + FastAPI |
| 269 | + input_data_type = ( |
| 270 | + List[self.model.prototype] if self.check_prototype else Request |
| 271 | + ) |
250 | 272 |
|
251 | | - else: |
| 273 | + @self.app.post( |
| 274 | + urljoin("/", endpoint_name), |
| 275 | + name=endpoint_name, |
| 276 | + description=endpoint_doc, |
| 277 | + ) |
| 278 | + async def custom_endpoint(input_data: input_data_type): |
252 | 279 |
|
253 | | - @self.app.post(urljoin("/", endpoint_name)) |
254 | | - async def custom_endpoint(input_data: Request): |
255 | | - served_data = await input_data.json() |
256 | | - predictions = endpoint_fx(served_data, **kw) |
| 280 | + served_data = ( |
| 281 | + api_data_to_frame(input_data) |
| 282 | + if self.check_prototype |
| 283 | + else await input_data.json() |
| 284 | + ) |
| 285 | + predictions = endpoint_fx(served_data, **kw) |
257 | 286 |
|
258 | | - if isinstance(predictions, List): |
259 | | - return {endpoint_name: predictions} |
260 | | - else: |
261 | | - return predictions |
| 287 | + if isinstance(predictions, List): |
| 288 | + return {endpoint_name: predictions} |
| 289 | + else: |
| 290 | + return predictions |
262 | 291 |
|
263 | 292 | def run(self, port: int = 8000, host: str = "127.0.0.1", quiet_open=False, **kw): |
264 | 293 | """ |
|
0 commit comments