Skip to content

Commit 053289d

Browse files
committed
singledispatch _prepare_data
1 parent 4da0740 commit 053289d

1 file changed

Lines changed: 18 additions & 14 deletions

File tree

vetiver/handlers/base.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import List
21
from functools import singledispatch
32
from contextlib import suppress
43
import pandas as pd
4+
import pydantic
55

66
from ..prototype import vetiver_create_prototype
77
from ..meta import VetiverMeta
@@ -119,30 +119,34 @@ def handler_startup():
119119
"""
120120
...
121121

122+
@singledispatch
122123
def _prepare_data(self, pred_data) -> pd.DataFrame:
123124
"""Convert prototype to dataframe data
124125
125126
Parameters
126127
----------
127-
prototype_data : pd.DataFrame, np.ndarray, or None
128-
Training data to create prototype
128+
pred_data : pydantic.BaseModel
129+
User data from given to API endpoint
129130
130131
Returns
131132
-------
132-
prototype : pd.DataFrame or None
133-
Zero-row DataFrame for storing data types
134-
"""
135-
if isinstance(pred_data, List):
136-
columns = pred_data[0].dict().keys()
137-
data = [line.dict() for line in pred_data]
138-
served_data = pd.DataFrame(data, columns=columns)
139-
else:
140-
served_data = []
141-
for key, value in pred_data:
142-
served_data.append(value)
133+
pd.DataFrame
134+
BaseModel data translated into DataFrame
135+
"""
136+
served_data = []
137+
for key, value in pred_data:
138+
served_data.append(value)
143139

144140
return served_data
145141

142+
@_prepare_data.register
143+
def _basemodel(self, pred_data: pydantic.BaseModel):
144+
return pd.DataFrame([dict(s) for s in pred_data])
145+
146+
@_prepare_data.register
147+
def _list(self, pred_data: list):
148+
return pd.DataFrame([dict(s) for s in pred_data])
149+
146150
def handler_predict(self, input_data, check_prototype):
147151
"""Generates method for /predict endpoint in VetiverAPI
148152

0 commit comments

Comments
 (0)