Skip to content

Commit 8f686e5

Browse files
committed
remove tolist from API
1 parent 0dbcf40 commit 8f686e5

5 files changed

Lines changed: 8 additions & 8 deletions

File tree

vetiver/handlers/sklearn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,4 @@ def handler_predict(self, input_data, check_prototype):
4242
else:
4343
prediction = self.model.predict([input_data])
4444

45-
return prediction
45+
return prediction.tolist()

vetiver/handlers/statsmodels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,4 @@ def handler_predict(self, input_data, check_prototype):
5151
else:
5252
prediction = self.model.predict([input_data])
5353

54-
return prediction
54+
return prediction.tolist()

vetiver/handlers/torch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,4 @@ def handler_predict(self, input_data, check_prototype):
5353
input_data = torch.tensor(input_data)
5454
prediction = self.model(input_data)
5555

56-
return prediction
56+
return prediction.tolist()

vetiver/handlers/xgboost.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,4 @@ def handler_predict(self, input_data, check_prototype):
5656

5757
prediction = self.model.predict(input_data)
5858

59-
return prediction
59+
return prediction.tolist()

vetiver/server.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,17 +175,17 @@ def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw):
175175

176176
@self.app.post(urljoin("/", endpoint_name), name=endpoint_name)
177177
async def custom_endpoint(input_data: List[self.model.prototype]):
178-
new = endpoint_fx(input_data, **kw)
179-
return {endpoint_name: new.tolist()}
178+
predictions = endpoint_fx(input_data, **kw)
179+
return {endpoint_name: predictions}
180180

181181
else:
182182

183183
@self.app.post(urljoin("/", endpoint_name))
184184
async def custom_endpoint(input_data: Request):
185185
served_data = await input_data.json()
186-
new = endpoint_fx(served_data, **kw)
186+
predictions = endpoint_fx(served_data, **kw)
187187

188-
return {endpoint_name: new.tolist()}
188+
return {endpoint_name: predictions}
189189

190190
def run(self, port: int = 8000, host: str = "127.0.0.1", **kw):
191191
"""

0 commit comments

Comments
 (0)