Skip to content

Commit 9c74931

Browse files
authored
Merge pull request #637 from openml/fix636
Fix636
2 parents 0235c51 + b98987e commit 9c74931

2 files changed

Lines changed: 57 additions & 12 deletions

File tree

openml/flows/sklearn_converter.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,13 @@
3535
)
3636

3737

38+
SIMPLE_NUMPY_TYPES = [nptype for type_cat, nptypes in np.sctypes.items()
39+
for nptype in nptypes if type_cat != 'others']
40+
SIMPLE_TYPES = tuple([bool, int, float, str] + SIMPLE_NUMPY_TYPES)
41+
42+
3843
def sklearn_to_flow(o, parent_model=None):
3944
# TODO: assert that only on first recursion lvl `parent_model` can be None
40-
simple_numpy_types = [nptype for type_cat, nptypes in np.sctypes.items()
41-
for nptype in nptypes
42-
if type_cat != 'others']
43-
simple_types = tuple([bool, int, float, str] + simple_numpy_types)
4445
if _is_estimator(o):
4546
# is the main model or a submodel
4647
rval = _serialize_model(o)
@@ -49,8 +50,8 @@ def sklearn_to_flow(o, parent_model=None):
4950
rval = [sklearn_to_flow(element, parent_model) for element in o]
5051
if isinstance(o, tuple):
5152
rval = tuple(rval)
52-
elif isinstance(o, simple_types) or o is None:
53-
if isinstance(o, tuple(simple_numpy_types)):
53+
elif isinstance(o, SIMPLE_TYPES) or o is None:
54+
if isinstance(o, tuple(SIMPLE_NUMPY_TYPES)):
5455
o = o.item()
5556
# base parameter values
5657
rval = o
@@ -510,14 +511,36 @@ def _extract_information_from_model(model):
510511
for k, v in sorted(model_parameters.items(), key=lambda t: t[0]):
511512
rval = sklearn_to_flow(v, model)
512513

513-
if (isinstance(rval, (list, tuple))
514+
def flatten_all(list_):
515+
""" Flattens arbitrary depth lists of lists (e.g. [[1,2],[3,[1]]] -> [1,2,3,1]). """
516+
for el in list_:
517+
if isinstance(el, (list, tuple)):
518+
yield from flatten_all(el)
519+
else:
520+
yield el
521+
522+
# In case rval is a list of lists (or tuples), we need to identify two situations:
523+
# - sklearn pipeline steps, feature union or base classifiers in voting classifier.
524+
# They look like e.g. [("imputer", Imputer()), ("classifier", SVC())]
525+
# - a list of lists with simple types (e.g. int or str), such as for an OrdinalEncoder
526+
# where all possible values for each feature are described: [[0,1,2], [1,2,5]]
527+
is_non_empty_list_of_lists_with_same_type = (
528+
isinstance(rval, (list, tuple))
514529
and len(rval) > 0
515530
and isinstance(rval[0], (list, tuple))
516-
and all([isinstance(rval[i], type(rval[0]))
517-
for i in range(len(rval))])):
531+
and all([isinstance(rval_i, type(rval[0])) for rval_i in rval])
532+
)
518533

519-
# Steps in a pipeline or feature union, or base classifiers in
520-
# voting classifier
534+
# Check that all list elements are of simple types.
535+
nested_list_of_simple_types = (
536+
is_non_empty_list_of_lists_with_same_type
537+
and all([isinstance(el, SIMPLE_TYPES) for el in flatten_all(rval)])
538+
)
539+
540+
if is_non_empty_list_of_lists_with_same_type and not nested_list_of_simple_types:
541+
# If a list of lists is identified that include 'non-simple' types (e.g. objects),
542+
# we assume they are steps in a pipeline, feature union, or base classifiers in
543+
# a voting classifier.
521544
parameter_value = list()
522545
reserved_keywords = set(model.get_params(deep=False).keys())
523546

@@ -597,7 +620,6 @@ def _extract_information_from_model(model):
597620
parameters[k] = json.dumps(component_reference)
598621

599622
else:
600-
601623
# a regular hyperparameter
602624
if not (hasattr(rval, '__len__') and len(rval) == 0):
603625
rval = json.dumps(rval)

tests/test_flows/test_flow_functions.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
from collections import OrderedDict
22
import copy
3+
import unittest
4+
5+
from distutils.version import LooseVersion
6+
import sklearn
37

48
import openml
59
from openml.testing import TestBase
@@ -221,3 +225,22 @@ def test_are_flows_equal_ignore_if_older(self):
221225
self.assertRaises(ValueError, assert_flows_equal, flow, new_flow,
222226
ignore_parameter_values_on_older_children=flow_upload_date)
223227
assert_flows_equal(flow, flow, ignore_parameter_values_on_older_children=None)
228+
229+
@unittest.skipIf(LooseVersion(sklearn.__version__) < "0.20",
230+
reason="OrdinalEncoder introduced in 0.20. "
231+
"No known models with list of lists parameters in older versions.")
232+
def test_sklearn_to_flow_list_of_lists(self):
233+
from sklearn.preprocessing import OrdinalEncoder
234+
ordinal_encoder = OrdinalEncoder(categories=[[0, 1], [0, 1]])
235+
236+
# Test serialization works
237+
flow = openml.flows.sklearn_to_flow(ordinal_encoder)
238+
239+
# Test flow is accepted by server
240+
self._add_sentinel_to_flow_name(flow)
241+
flow.publish()
242+
243+
# Test deserialization works
244+
server_flow = openml.flows.get_flow(flow.flow_id, reinstantiate=True)
245+
self.assertEqual(server_flow.parameters['categories'], '[[0, 1], [0, 1]]')
246+
self.assertEqual(server_flow.model.categories, flow.model.categories)

0 commit comments

Comments
 (0)