Skip to content

Commit 604d01e

Browse files
committed
FIX serialize list of integers which are argument to sklearn model
1 parent 7e6a545 commit 604d01e

2 files changed

Lines changed: 7 additions & 4 deletions

File tree

openml/flows/sklearn_converter.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def serialize_object(self, o):
7373
elif self._is_cross_validator(o):
7474
rval = self.serialize_cross_validator(o)
7575
else:
76-
raise TypeError(o)
76+
raise TypeError(o, type(o))
7777

7878
assert o is None or rval is not None
7979

@@ -201,7 +201,9 @@ def _serialize_model(self, model):
201201
if k not in model_parameters.parameters:
202202
continue
203203

204-
if isinstance(rval, (list, tuple)):
204+
if isinstance(rval, (list, tuple)) and \
205+
isinstance(rval[0], (list, tuple)) and \
206+
[type(rval[0]) == type(rval[i]) for i in range(len(rval))]:
205207

206208
# Steps in a pipeline or feature union
207209
parameter_value = list()

tests/flows/test_sklearn.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,12 +198,12 @@ def test_serialize_feature_union(self):
198198
new_model.fit(self.X, self.y)
199199

200200
def test_serialize_complex_flow(self):
201+
ohe = sklearn.preprocessing.OneHotEncoder(categorical_features=[0])
201202
scaler = sklearn.preprocessing.StandardScaler(with_mean=False)
202-
203203
boosting = sklearn.ensemble.AdaBoostClassifier(
204204
base_estimator=sklearn.tree.DecisionTreeClassifier())
205205
model = sklearn.pipeline.Pipeline(steps=(
206-
('scaler', scaler), ('boosting', boosting)))
206+
('ohe', ohe), ('scaler', scaler), ('boosting', boosting)))
207207
parameter_grid = {'n_estimators': [1, 5, 10, 100],
208208
'learning_rate': scipy.stats.uniform(0.01, 0.99),
209209
'base_estimator__max_depth': scipy.stats.randint(1,
@@ -216,6 +216,7 @@ def test_serialize_complex_flow(self):
216216
fixture_name = 'sklearn.model_selection._search.RandomizedSearchCV(' \
217217
'sklearn.model_selection._split.StratifiedKFold,' \
218218
'sklearn.pipeline.Pipeline(' \
219+
'sklearn.preprocessing.data.OneHotEncoder,' \
219220
'sklearn.preprocessing.data.StandardScaler,' \
220221
'sklearn.ensemble.weight_boosting.AdaBoostClassifier(' \
221222
'sklearn.tree.tree.DecisionTreeClassifier)))'

0 commit comments

Comments
 (0)