Skip to content

Commit d219be2

Browse files
authored
Merge pull request #22 from isabelizimm/predict-creds
add **kw to predict
2 parents 04eb320 + 6b31e8b commit d219be2

1 file changed

Lines changed: 26 additions & 19 deletions

File tree

vetiver/server.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,18 @@ async def rapidoc():
8181
if self.check_ptype == True:
8282

8383
@app.post("/predict/")
84-
async def prediction(input_data: Union[self.model.ptype, List[self.model.ptype]]):
85-
84+
async def prediction(
85+
input_data: Union[self.model.ptype, List[self.model.ptype]]
86+
):
87+
8688
if isinstance(input_data, List):
8789
served_data = _batch_data(input_data)
8890
else:
8991
served_data = _prepare_data(input_data)
9092

91-
y = self.model.handler_predict(served_data, check_ptype=self.check_ptype)
93+
y = self.model.handler_predict(
94+
served_data, check_ptype=self.check_ptype
95+
)
9296

9397
return {"prediction": y.tolist()}
9498

@@ -107,7 +111,7 @@ def vetiver_post(
107111
self, endpoint_fx: Callable, endpoint_name: str = "custom_endpoint"
108112
):
109113
"""Create new POST endpoint
110-
114+
111115
Parameters
112116
----------
113117
endpoint_fx : typing.Callable
@@ -138,12 +142,12 @@ async def custom_endpoint(input_data: Request):
138142
return {endpoint_name: new.tolist()}
139143

140144
def run(self):
141-
"""Start API
142-
"""
145+
"""Start API"""
143146
_jupyter_nb()
144147
uvicorn.run(self.app, port=self.port, host=self.host)
145148

146-
def predict(endpoint, data: dict):
149+
150+
def predict(endpoint, data: dict, **kw):
147151
"""Make a prediction from model endpoint
148152
149153
Parameters
@@ -158,7 +162,11 @@ def predict(endpoint, data: dict):
158162
dict
159163
Key: endpoint_name Value: Output of endpoint_fx, in list format
160164
"""
161-
response = requests.post(endpoint, json=data)
165+
if isinstance(data, pd.DataFrame):
166+
data = data.to_json(orient="records")
167+
response = requests.post(endpoint, data=data, **kw)
168+
else:
169+
response = requests.post(endpoint, json=data, **kw)
162170

163171
return response.json()
164172

@@ -169,28 +177,27 @@ def _prepare_data(pred_data):
169177
served_data.append(value)
170178
return served_data
171179

180+
172181
def _batch_data(pred_data):
173182
columns = pred_data[0].dict().keys()
174183

175184
data = [line.dict() for line in pred_data]
176-
print(data)
177185

178186
served_data = pd.DataFrame(data, columns=columns)
179187
return served_data
180188

181189

182-
183190
def vetiver_endpoint(url="http://127.0.0.1:8000/predict"):
184191
"""Wrap url
185192
186-
Parameters
187-
----------
188-
url : str
189-
URI path to endpoint
193+
Parameters
194+
----------
195+
url : str
196+
URI path to endpoint
190197
191-
Returns
192-
-------
193-
url : str
194-
URI path to endpoint
195-
"""
198+
Returns
199+
-------
200+
url : str
201+
URI path to endpoint
202+
"""
196203
return url

0 commit comments

Comments
 (0)