Skip to content

Commit c539ab8

Browse files
authored
Merge pull request #179 from rstudio/fastapi-docs
2 parents 77439fa + 93864f9 commit c539ab8

6 files changed

Lines changed: 27 additions & 9 deletions

File tree

vetiver/handlers/sklearn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ class SKLearnHandler(BaseHandler):
1717
pip_name = "scikit-learn"
1818

1919
def handler_predict(self, input_data, check_prototype):
20-
"""Generates method for /predict endpoint in VetiverAPI
20+
"""
21+
Generates method for /predict endpoint in VetiverAPI
2122
2223
The `handler_predict` function executes at each API call. Use this
2324
function for calling `predict()` and any other tasks that must be executed
@@ -30,7 +31,7 @@ def handler_predict(self, input_data, check_prototype):
3031
3132
Returns
3233
-------
33-
prediction
34+
prediction:
3435
Prediction from model
3536
"""
3637

vetiver/handlers/spacy.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def construct_prototype(self):
5454
return prototype
5555

5656
def handler_predict(self, input_data, check_prototype):
57-
"""Generates method for /predict endpoint in VetiverAPI
57+
"""
58+
Generates method for /predict endpoint in VetiverAPI
5859
5960
The `handler_predict` function executes at each API call. Use this
6061
function for calling `predict()` and any other tasks that must be executed
@@ -63,7 +64,9 @@ def handler_predict(self, input_data, check_prototype):
6364
Parameters
6465
----------
6566
input_data:
66-
Test data
67+
Test data. The SpacyHandler expects an input of a 1 column DataFrame with
68+
the same column names as the prototype data, or column name "text" if no
69+
prototype was given.
6770
6871
Returns
6972
-------

vetiver/handlers/statsmodels.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ class StatsmodelsHandler(BaseHandler):
2323
pip_name = "statsmodels"
2424

2525
def handler_predict(self, input_data, check_prototype):
26-
"""Generates method for /predict endpoint in VetiverAPI
26+
"""
27+
Generates method for /predict endpoint in VetiverAPI
2728
2829
The `handler_predict` function executes at each API call. Use this
2930
function for calling `predict()` and any other tasks that must be executed

vetiver/handlers/torch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ class TorchHandler(BaseHandler):
2323
pip_name = "torch"
2424

2525
def handler_predict(self, input_data, check_prototype):
26-
"""Generates method for /predict endpoint in VetiverAPI
26+
"""
27+
Generates method for /predict endpoint in VetiverAPI
2728
2829
The `handler_predict` function executes at each API call. Use this
2930
function for calling `predict()` and any other tasks that must be executed

vetiver/handlers/xgboost.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ class XGBoostHandler(BaseHandler):
2323
pip_name = "xgboost"
2424

2525
def handler_predict(self, input_data, check_prototype):
26-
"""Generates method for /predict endpoint in VetiverAPI
26+
"""
27+
Generates method for /predict endpoint in VetiverAPI
2728
2829
The `handler_predict` function executes at each API call. Use this
2930
function for calling `predict()` and any other tasks that must be executed

vetiver/server.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from fastapi.openapi.utils import get_openapi
1212
from fastapi.responses import HTMLResponse, RedirectResponse
1313
from fastapi.responses import PlainTextResponse
14+
from textwrap import dedent
1415
from warnings import warn
1516

1617
from .utils import _jupyter_nb
@@ -105,10 +106,12 @@ def pin_url():
105106

106107
@app.get("/ping", include_in_schema=True)
107108
async def ping():
109+
"""Ping endpoint for health check"""
108110
return {"ping": "pong"}
109111

110112
@app.get("/metadata")
111113
async def get_metadata():
114+
"""Get metadata from model"""
112115
return self.model.metadata.to_dict()
113116

114117
self.vetiver_post(
@@ -183,13 +186,21 @@ def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
183186
if not endpoint_name:
184187
endpoint_name = endpoint_fx.__name__
185188

189+
if endpoint_fx.__doc__ is not None:
190+
api_desc = dedent(endpoint_fx.__doc__)
191+
else:
192+
api_desc = None
193+
186194
if self.check_prototype is True:
187195

188-
@self.app.post(urljoin("/", endpoint_name), name=endpoint_name)
196+
@self.app.post(
197+
urljoin("/", endpoint_name),
198+
name=endpoint_name,
199+
description=api_desc,
200+
)
189201
async def custom_endpoint(input_data: List[self.model.prototype]):
190202
_to_frame = api_data_to_frame(input_data)
191203
predictions = endpoint_fx(_to_frame, **kw)
192-
193204
if isinstance(predictions, List):
194205
return {endpoint_name: predictions}
195206
else:

0 commit comments

Comments
 (0)