Skip to content

Commit 68635e7

Browse files
authored
Merge pull request #185 from rstudio/ptype-compat
2 parents c539ab8 + f675488 commit 68635e7

5 files changed

Lines changed: 36 additions & 12 deletions

File tree

.github/workflows/tests.yml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@ jobs:
3232
- name: Install dependencies
3333
run: |
3434
python -m pip install --upgrade pip
35-
python -m pip install -e ".[dev,all_models]"
35+
python -m pip install -e ".[dev]"
36+
python -m pip install xgboost
37+
python -m pip install spacy
38+
python -m pip install torch
39+
python -m pip install statsmodels
40+
python -m pip install typing_extensions==4.7.1
3641
- name: Run Tests
3742
run: |
3843
pytest -m 'not rsc_test and not docker' --cov --cov-report xml

vetiver/prototype.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def _(data: pd.DataFrame):
117117
>>> another_prototype.schema() == prototype.schema()
118118
True
119119
"""
120-
dict_data = data.iloc[0, :].to_dict()
120+
dict_data = _to_field(data.iloc[0, :].to_dict())
121121
prototype = create_prototype(**dict_data)
122122
return prototype
123123

@@ -156,7 +156,11 @@ def _item(value):
156156

157157
dict_data = dict(enumerate(data[0], 0))
158158
# pydantic requires strings as indicies
159-
dict_data = {f"{key}": _item(value) for key, value in dict_data.items()}
159+
# if its a numpy type, we have to take the Python type due to Pydantic
160+
161+
dict_data = {
162+
f"{key}": (type(value.item()), _item(value)) for key, value in dict_data.items()
163+
}
160164
prototype = create_prototype(**dict_data)
161165
return prototype
162166

@@ -171,7 +175,7 @@ def _(data: dict):
171175
data : dict
172176
Dictionary
173177
"""
174-
return create_prototype(**data)
178+
return create_prototype(**_to_field(data))
175179

176180

177181
@vetiver_create_prototype.register
@@ -198,3 +202,10 @@ def _(data: NoneType):
198202
None
199203
"""
200204
return None
205+
206+
207+
def _to_field(data):
208+
basemodel_input = dict()
209+
for key, value in data.items():
210+
basemodel_input[key] = (type(value), value)
211+
return basemodel_input

vetiver/tests/test_build_vetiver_model.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
import pandas as pd
99
import pydantic
1010
import pins
11+
import numpy as np
12+
13+
np.random.seed(50)
1114

1215
# Load data, model
1316
X_df, y = get_mock_data()
@@ -33,8 +36,9 @@ def test_vetiver_model_array_prototype():
3336

3437
assert vt1.model == model
3538
assert issubclass(vt1.prototype, vt.Prototype)
39+
# change to model_construct for pydantic v3
3640
assert isinstance(vt1.prototype.construct(), pydantic.BaseModel)
37-
assert list(vt1.prototype.__fields__.values())[0].type_ == int
41+
assert vt1.prototype.construct().__dict__ == {"0": 96, "1": 11, "2": 33}
3842

3943

4044
def test_vetiver_model_df_prototype():
@@ -48,8 +52,9 @@ def test_vetiver_model_df_prototype():
4852
)
4953

5054
assert vt2.model == model
55+
# change to model_construct for pydantic v3
5156
assert isinstance(vt2.prototype.construct(), pydantic.BaseModel)
52-
assert list(vt2.prototype.__fields__.values())[0].type_ == int
57+
assert vt2.prototype.construct().B == 96
5358

5459

5560
def test_vetiver_model_dict_prototype():
@@ -64,8 +69,9 @@ def test_vetiver_model_dict_prototype():
6469
)
6570

6671
assert vt3.model == model
72+
# change to model_construct for pydantic v3
6773
assert isinstance(vt3.prototype.construct(), pydantic.BaseModel)
68-
assert list(vt3.prototype.__fields__.values())[0].type_ == int
74+
assert vt3.prototype.construct().B == 0
6975

7076

7177
def test_vetiver_model_basemodel_prototype():
@@ -135,6 +141,7 @@ def test_vetiver_model_from_pin():
135141

136142
assert isinstance(v2, vt.VetiverModel)
137143
assert isinstance(v2.model, sklearn.base.BaseEstimator)
144+
# change to model_construct for pydantic v3
138145
assert isinstance(v2.prototype.construct(), pydantic.BaseModel)
139146
assert v2.metadata.user == {"test": 123}
140147
assert v2.metadata.version is not None
@@ -170,6 +177,7 @@ def test_vetiver_model_from_pin_user_metadata():
170177

171178
assert isinstance(v2, vt.VetiverModel)
172179
assert isinstance(v2.model, sklearn.base.BaseEstimator)
180+
# change to model_construct for pydantic v3
173181
assert isinstance(v2.prototype.construct(), pydantic.BaseModel)
174182
assert v2.metadata.user == custom_meta
175183
assert v2.metadata.version is not None

vetiver/tests/test_custom_handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def test_custom_vetiver_model():
4040
assert v.description == "A DummyRegressor model"
4141
assert not v.metadata.required_pkgs
4242
assert isinstance(v.model, sklearn.dummy.DummyRegressor)
43+
# change to model_construct for pydantic v3
4344
assert isinstance(v.prototype.construct(), pydantic.BaseModel)
4445

4546

@@ -58,4 +59,5 @@ def test_custom_vetiver_model_no_ptype():
5859

5960
assert v.description == "A regression model for testing purposes"
6061
assert isinstance(v.model, sklearn.dummy.DummyRegressor)
62+
# change to model_construct for pydantic v3
6163
assert isinstance(v.prototype.construct(), pydantic.BaseModel)

vetiver/tests/test_sklearn.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,10 @@ def test_predict_sklearn_no_ptype(data, length, vetiver_client_check_ptype_false
7373

7474
@pytest.mark.parametrize("data", [(0), 0, 0.0, "0"])
7575
def test_predict_sklearn_type_error(data, vetiver_client):
76-
import re
7776

78-
msg = re.sub(
79-
r"\n",
80-
": ",
81-
"1 validation error for Request\nbody\n value is not a valid list \(type=type_error.list\)", # noqa
77+
msg = str(
78+
"[{'type': 'list_type', 'loc': ('body',), 'msg': 'Input should be a valid list', \
79+
'input': '0', 'url': 'https://errors.pydantic.dev/2.0.3/v/list_type'}]"
8280
)
8381

8482
with pytest.raises(TypeError, match=msg):

0 commit comments

Comments
 (0)