|
1 | | -from typing import Any, Callable, Dict, List, Union |
| 1 | +from typing import Callable, List, Union |
2 | 2 | from urllib.parse import urljoin |
3 | 3 |
|
4 | 4 | import httpx |
@@ -167,15 +167,7 @@ def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw): |
167 | 167 | if self.check_prototype is True: |
168 | 168 |
|
169 | 169 | @self.app.post(urljoin("/", endpoint_name), name=endpoint_name) |
170 | | - async def custom_endpoint( |
171 | | - input_data: Union[self.model.prototype, List[self.model.prototype]] |
172 | | - ): |
173 | | - |
174 | | - # if isinstance(input_data, List): |
175 | | - # served_data = _batch_data(input_data) |
176 | | - # else: |
177 | | - # served_data = _prepare_data(input_data) |
178 | | - |
| 170 | + async def custom_endpoint(input_data: List[self.model.prototype]): |
179 | 171 | new = endpoint_fx(input_data, **kw) |
180 | 172 | return {endpoint_name: new.tolist()} |
181 | 173 |
|
@@ -261,13 +253,13 @@ def predict(endpoint, data: Union[dict, pd.DataFrame, pd.Series], **kw) -> pd.Da |
261 | 253 | # TO DO: dispatch |
262 | 254 |
|
263 | 255 | if isinstance(data, pd.DataFrame): |
264 | | - data_json = data.to_json(orient="records") |
265 | | - response = requester.post(endpoint, data=data_json, **kw) |
| 256 | + response = requester.post(endpoint, data=data.to_json(orient="records"), **kw) |
266 | 257 | elif isinstance(data, pd.Series): |
267 | | - data_dict = data.to_json() |
268 | | - response = requester.post(endpoint, data=data_dict, **kw) |
269 | | - elif isinstance(data, dict): |
| 258 | + response = requester.post(endpoint, data=data.to_json(), **kw) |
| 259 | + elif isinstance(data, list): |
270 | 260 | response = requester.post(endpoint, json=data, **kw) |
| 261 | + elif isinstance(data, dict): |
| 262 | + response = requester.post(endpoint, json=[data], **kw) |
271 | 263 | else: |
272 | 264 | response = requester.post(endpoint, json=data, **kw) |
273 | 265 |
|
|
0 commit comments