3535)
3636
3737
38+ SIMPLE_NUMPY_TYPES = [nptype for type_cat , nptypes in np .sctypes .items ()
39+ for nptype in nptypes if type_cat != 'others' ]
40+ SIMPLE_TYPES = tuple ([bool , int , float , str ] + SIMPLE_NUMPY_TYPES )
41+
42+
3843def sklearn_to_flow (o , parent_model = None ):
3944 # TODO: assert that `parent_model` can be None only on the first recursion level
40- simple_numpy_types = [nptype for type_cat , nptypes in np .sctypes .items ()
41- for nptype in nptypes
42- if type_cat != 'others' ]
43- simple_types = tuple ([bool , int , float , str ] + simple_numpy_types )
4445 if _is_estimator (o ):
4546 # is the main model or a submodel
4647 rval = _serialize_model (o )
@@ -49,8 +50,8 @@ def sklearn_to_flow(o, parent_model=None):
4950 rval = [sklearn_to_flow (element , parent_model ) for element in o ]
5051 if isinstance (o , tuple ):
5152 rval = tuple (rval )
52- elif isinstance (o , simple_types ) or o is None :
53- if isinstance (o , tuple (simple_numpy_types )):
53+ elif isinstance (o , SIMPLE_TYPES ) or o is None :
54+ if isinstance (o , tuple (SIMPLE_NUMPY_TYPES )):
5455 o = o .item ()
5556 # base parameter values
5657 rval = o
@@ -510,14 +511,36 @@ def _extract_information_from_model(model):
510511 for k , v in sorted (model_parameters .items (), key = lambda t : t [0 ]):
511512 rval = sklearn_to_flow (v , model )
512513
513- if (isinstance (rval , (list , tuple ))
514+ def flatten_all (list_ ):
515+ """ Flattens arbitrary depth lists of lists (e.g. [[1,2],[3,[1]]] -> [1,2,3,1]). """
516+ for el in list_ :
517+ if isinstance (el , (list , tuple )):
518+ yield from flatten_all (el )
519+ else :
520+ yield el
521+
522+ # In case rval is a list of lists (or tuples), we need to identify two situations:
523+ # - sklearn pipeline steps, feature union or base classifiers in voting classifier.
524+ # They look like e.g. [("imputer", Imputer()), ("classifier", SVC())]
525+ # - a list of lists with simple types (e.g. int or str), such as for an OrdinalEncoder
526+ # where all possible values for each feature are described: [[0,1,2], [1,2,5]]
527+ is_non_empty_list_of_lists_with_same_type = (
528+ isinstance (rval , (list , tuple ))
514529 and len (rval ) > 0
515530 and isinstance (rval [0 ], (list , tuple ))
516- and all ([isinstance (rval [ i ] , type (rval [0 ]))
517- for i in range ( len ( rval ))])):
531+ and all ([isinstance (rval_i , type (rval [0 ])) for rval_i in rval ] )
532+ )
518533
519- # Steps in a pipeline or feature union, or base classifiers in
520- # voting classifier
534+ # Check that all list elements are of simple types.
535+ nested_list_of_simple_types = (
536+ is_non_empty_list_of_lists_with_same_type
537+ and all ([isinstance (el , SIMPLE_TYPES ) for el in flatten_all (rval )])
538+ )
539+
540+ if is_non_empty_list_of_lists_with_same_type and not nested_list_of_simple_types :
541+ # If a list of lists is identified that includes 'non-simple' types (e.g. objects),
542+ # we assume they are steps in a pipeline, feature union, or base classifiers in
543+ # a voting classifier.
521544 parameter_value = list ()
522545 reserved_keywords = set (model .get_params (deep = False ).keys ())
523546
@@ -597,7 +620,6 @@ def _extract_information_from_model(model):
597620 parameters [k ] = json .dumps (component_reference )
598621
599622 else :
600-
601623 # a regular hyperparameter
602624 if not (hasattr (rval , '__len__' ) and len (rval ) == 0 ):
603625 rval = json .dumps (rval )
0 commit comments