@@ -81,14 +81,18 @@ async def rapidoc():
8181 if self .check_ptype == True :
8282
8383 @app .post ("/predict/" )
84- async def prediction (input_data : Union [self .model .ptype , List [self .model .ptype ]]):
85-
84+ async def prediction (
85+ input_data : Union [self .model .ptype , List [self .model .ptype ]]
86+ ):
87+
8688 if isinstance (input_data , List ):
8789 served_data = _batch_data (input_data )
8890 else :
8991 served_data = _prepare_data (input_data )
9092
91- y = self .model .handler_predict (served_data , check_ptype = self .check_ptype )
93+ y = self .model .handler_predict (
94+ served_data , check_ptype = self .check_ptype
95+ )
9296
9397 return {"prediction" : y .tolist ()}
9498
@@ -107,7 +111,7 @@ def vetiver_post(
107111 self , endpoint_fx : Callable , endpoint_name : str = "custom_endpoint"
108112 ):
109113 """Create new POST endpoint
110-
114+
111115 Parameters
112116 ----------
113117 endpoint_fx : typing.Callable
@@ -138,12 +142,12 @@ async def custom_endpoint(input_data: Request):
138142 return {endpoint_name : new .tolist ()}
139143
140144 def run (self ):
141- """Start API
142- """
145+ """Start API"""
143146 _jupyter_nb ()
144147 uvicorn .run (self .app , port = self .port , host = self .host )
145148
146- def predict (endpoint , data : dict ):
149+
150+ def predict (endpoint , data : dict , ** kw ):
147151 """Make a prediction from model endpoint
148152
149153 Parameters
@@ -158,7 +162,11 @@ def predict(endpoint, data: dict):
158162 dict
159163 Key: endpoint_name Value: Output of endpoint_fx, in list format
160164 """
161- response = requests .post (endpoint , json = data )
165+ if isinstance (data , pd .DataFrame ):
166+ data = data .to_json (orient = "records" )
167+ response = requests .post (endpoint , data = data , ** kw )
168+ else :
169+ response = requests .post (endpoint , json = data , ** kw )
162170
163171 return response .json ()
164172
@@ -169,28 +177,27 @@ def _prepare_data(pred_data):
169177 served_data .append (value )
170178 return served_data
171179
180+
172181def _batch_data (pred_data ):
173182 columns = pred_data [0 ].dict ().keys ()
174183
175184 data = [line .dict () for line in pred_data ]
176- print (data )
177185
178186 served_data = pd .DataFrame (data , columns = columns )
179187 return served_data
180188
181189
182-
183190def vetiver_endpoint (url = "http://127.0.0.1:8000/predict" ):
184191 """Wrap url
185192
186- Parameters
187- ----------
188- url : str
189- URI path to endpoint
193+ Parameters
194+ ----------
195+ url : str
196+ URI path to endpoint
190197
191- Returns
192- -------
193- url : str
194- URI path to endpoint
195- """
198+ Returns
199+ -------
200+ url : str
201+ URI path to endpoint
202+ """
196203 return url
0 commit comments