1- from typing import Callable , List , Union , Any , Dict
1+ from typing import Any , Callable , Dict , List , Union
22from urllib .parse import urljoin
33
44import httpx
88from fastapi import FastAPI , Request , testclient
99from fastapi .openapi .utils import get_openapi
1010from fastapi .responses import HTMLResponse , RedirectResponse
11+ from warnings import warn
1112
1213from .utils import _jupyter_nb
1314from .vetiver_model import VetiverModel
@@ -20,8 +21,10 @@ class VetiverAPI:
2021 ----------
2122 model : VetiverModel
2223 Model to be deployed in API
23- check_ptype : bool
24+ check_prototype : bool
2425 Determine if data prototype should be enforced
26+ check_ptype : bool
27+ Deprecated in favor of check_prototype
2528 app_factory :
2629 Type of API to be deployed
2730
@@ -31,22 +34,32 @@ class VetiverAPI:
3134 >>> X, y = vt.get_mock_data()
3235 >>> model = vt.get_mock_model().fit(X, y)
3336 >>> v = vt.VetiverModel(model = model, model_name = "my_model", prototype_data = X)
34- >>> v_api = vt.VetiverAPI(model = v, check_ptype = True)
37+ >>> v_api = vt.VetiverAPI(model = v, check_prototype = True)
3538 """
3639
3740 app = None
3841
3942 def __init__ (
4043 self ,
4144 model : VetiverModel ,
42- check_ptype : bool = True ,
45+ check_prototype : bool = True ,
46+ check_ptype : bool = None ,
4347 app_factory = FastAPI ,
4448 ) -> None :
4549 self .model = model
46- self .check_ptype = check_ptype
4750 self .app_factory = app_factory
4851 self .app = app_factory ()
4952
53+ if check_ptype is not None :
54+ check_prototype = check_ptype
55+ warn (
56+ "argument for checking input data prototype has changed to "
57+ "check_prototype, from check_ptype" ,
58+ DeprecationWarning ,
59+ stacklevel = 2 ,
60+ )
61+ self .check_prototype = check_prototype
62+
5063 self ._init_app ()
5164
5265 def _init_app (self ):
@@ -71,7 +84,7 @@ async def ping():
7184 return {"ping" : "pong" }
7285
7386 self .vetiver_post (
74- self .model .handler_predict , "predict" , check_ptype = self .check_ptype
87+ self .model .handler_predict , "predict" , check_prototype = self .check_prototype
7588 )
7689
7790 @app .get ("/__docs__" , response_class = HTMLResponse , include_in_schema = False )
@@ -130,15 +143,15 @@ def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
130143 >>> X, y = vt.get_mock_data()
131144 >>> model = vt.get_mock_model().fit(X, y)
132145 >>> v = vt.VetiverModel(model = model, model_name = "model", prototype_data = X)
133- >>> v_api = vt.VetiverAPI(model = v, check_ptype = True)
146+ >>> v_api = vt.VetiverAPI(model = v, check_prototype = True)
134147 >>> def sum_values(x):
135148 ... return x.sum()
136149 >>> v_api.vetiver_post(sum_values, "sums")
137150 """
138151 if not endpoint_name :
139152 endpoint_name = endpoint_fx .__name__
140153
141- if self .check_ptype is True :
154+ if self .check_prototype is True :
142155
143156 @self .app .post (urljoin ("/" , endpoint_name ), name = endpoint_name )
144157 async def custom_endpoint (
@@ -179,7 +192,7 @@ def run(self, port: int = 8000, host: str = "127.0.0.1", **kw):
179192 >>> X, y = vt.get_mock_data()
180193 >>> model = vt.get_mock_model().fit(X, y)
181194 >>> v = vt.VetiverModel(model = model, model_name = "model", prototype_data = X)
182- >>> v_api = vt.VetiverAPI(model = v, check_ptype = True)
195+ >>> v_api = vt.VetiverAPI(model = v, check_prototype = True)
183196 >>> v_api.run() # doctest: +SKIP
184197 """
185198 _jupyter_nb ()
0 commit comments