Skip to content

Commit e0878f9

Browse files
has2k1isabelizimm
authored and committed
Use single_dispatch to check the data prototype
This should allow users to write functions for custom datatypes and any others that we may not handle.
1 parent 07741fe commit e0878f9

1 file changed

Lines changed: 50 additions & 34 deletions

File tree

vetiver/ptype.py

Lines changed: 50 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
from functools import singledispatch
2+
from types import NoneType
3+
14
import pandas as pd
25
import numpy as np
36
from pydantic import BaseModel, create_model
47

8+
59
class NoAvailablePTypeError(Exception):
610
"""
711
Throw an error if we cannot create
@@ -30,60 +34,72 @@ def __init__(
3034
super().__init__(self.message)
3135

3236

33-
def vetiver_create_ptype(ptype_data, save_ptype: bool):
37+
# Error-message template for data whose type has no ptype implementation.
# It is filled in with ``str.format(_data_type=...)``; the placeholder is
# used both in the explanation and in the example function signature.
CREATE_PTYPE_TPL = """\
Failed to create a data prototype (ptype) from data of \
type {_data_type}. If your datatype is not one of \
(pd.DataFrame, pydantic.BaseModel, np.ndarray, dict), \
you should write a function to create the ptype. Here is \
a template for such a function: \

    from pydantic import create_model
    from vetiver.ptype import vetiver_create_ptype

    @vetiver_create_ptype.register
    def _(data: {_data_type}):
        data_dict = ... # convert data to a dictionary
        ptype = create_model("ptype", **data_dict)
        return ptype

If your datatype is a common type, please consider submitting \
a pull request.
"""
56+
57+
@singledispatch
def vetiver_create_ptype(data, save_ptype):
    """Create zero row structure to save data types

    This is the generic fallback of a ``functools.singledispatch``
    function. It runs only when no implementation is registered for the
    type of ``data`` and always raises, with a message explaining how to
    register an implementation for that type.

    Parameters
    ----------
    data :
        Data from which the prototype is built. Dispatch is performed
        on the type of this argument.
    save_ptype :
        Not used by this fallback.

    Returns
    -------
    ptype
        Data prototype. Only registered implementations return one;
        this fallback never returns.

    Raises
    ------
    InvalidPTypeError
        Always, with instructions formatted from ``CREATE_PTYPE_TPL``.

    """
    # The formatted help text must reach the exception; it was previously
    # overwritten with an empty string right before raising.
    msg = CREATE_PTYPE_TPL.format(_data_type=type(data))
    raise InvalidPTypeError(message=msg)
6875

69-
def _df_to_ptype(train_data):
7076

71-
dict_data = train_data.to_dict()
77+
@vetiver_create_ptype.register
def _vetiver_create_ptype(data: pd.DataFrame, save_ptype):
    """Create a data prototype from a pandas DataFrame

    Parameters
    ----------
    data : pd.DataFrame
        Frame whose first row supplies the field names (columns) and
        example values for the prototype.
    save_ptype :
        Not used by this implementation.

    Returns
    -------
    ptype
        Model produced by ``create_model("ptype", ...)``

    """
    # Take the first row: position 1 is the *second* row and raised
    # IndexError for frames that contain a single row.
    dict_data = data.iloc[0, :].to_dict()
    ptype = create_model("ptype", **dict_data)
    return ptype
7582

7683

77-
def _array_to_ptype(train_data):
78-
dict_data = dict(enumerate(train_data, 0))
79-
84+
@vetiver_create_ptype.register
def _vetiver_create_ptype(data: np.ndarray, save_ptype):
    """Create a data prototype from a numpy array

    Parameters
    ----------
    data : np.ndarray
        Array whose first row supplies the example values. Fields are
        named after the positional index of each value ("0", "1", ...).
    save_ptype :
        Not used by this implementation.

    Returns
    -------
    ptype
        Model produced by ``create_model("ptype", ...)``

    """

    def _item(value):
        # .item() converts a numpy scalar to its Python equivalent.
        # Elements of object-dtype arrays may already be plain Python
        # objects without that method, so pass those through unchanged.
        try:
            return value.item()
        except AttributeError:
            return value

    # Take the first row: index 1 is the *second* row and raised
    # IndexError for arrays that contain a single row.
    # pydantic requires strings as indices
    dict_data = {str(key): _item(value) for key, value in enumerate(data[0])}
    ptype = create_model("ptype", **dict_data)
    return ptype
8591

8692

87-
def _dict_to_ptype(train_data):
93+
@vetiver_create_ptype.register
def _vetiver_create_ptype(data: dict, save_ptype):
    """Create a data prototype from a dictionary

    Parameters
    ----------
    data : dict
        Mapping passed to ``create_model`` as keyword arguments.
    save_ptype :
        Not used by this implementation.

    Returns
    -------
    ptype
        Model produced by ``create_model("ptype", ...)``

    """
    ptype = create_model("ptype", **data)
    return ptype
96+
97+
98+
@vetiver_create_ptype.register
def _vetiver_create_ptype(data: BaseModel, save_ptype):
    """Use a pydantic model as its own data prototype

    Parameters
    ----------
    data : BaseModel
        Returned as-is.
    save_ptype :
        Not used by this implementation.

    Returns
    -------
    ptype
        The ``data`` argument, unchanged

    """
    # NOTE(review): singledispatch dispatches on type(data), so this is
    # reached for BaseModel *instances*; a model class passed as an object
    # would presumably not dispatch here -- confirm against callers.
    ptype = data
    return ptype
101+
88102

89-
return create_model("ptype",**train_data)
103+
@vetiver_create_ptype.register
def _vetiver_create_ptype(data: NoneType, save_ptype):
    """Return no data prototype when no data is given

    Parameters
    ----------
    data : None
        Absence of prototype data.
    save_ptype :
        Not used by this implementation.

    Returns
    -------
    None

    """
    ptype = None
    return ptype

0 commit comments

Comments
 (0)