@@ -696,10 +696,14 @@ def _serialize_model(self, model: Any) -> OpenMLFlow:
696696 # will be part of the name (in brackets)
697697 sub_components_names = ""
698698 for key in subcomponents :
699+ if isinstance (subcomponents [key ], OpenMLFlow ):
700+ name = subcomponents [key ].name
701+ elif isinstance (subcomponents [key ], str ): # 'drop', 'passthrough' can be passed
702+ name = subcomponents [key ]
699703 if key in subcomponents_explicit :
700- sub_components_names += "," + key + "=" + subcomponents [ key ]. name
704+ sub_components_names += "," + key + "=" + name
701705 else :
702- sub_components_names += "," + subcomponents [ key ]. name
706+ sub_components_names += "," + name
703707
704708 if sub_components_names :
705709 # slice operation on string in order to get rid of leading comma
@@ -771,6 +775,9 @@ def _get_external_version_string(
771775 external_versions .add (openml_version )
772776 external_versions .add (sklearn_version )
773777 for visitee in sub_components .values ():
778+ # 'drop', 'passthrough', None can be passed as estimators
779+ if isinstance (visitee , str ):
780+ continue
774781 for external_version in visitee .external_version .split (',' ):
775782 external_versions .add (external_version )
776783 return ',' .join (list (sorted (external_versions )))
@@ -783,9 +790,12 @@ def _check_multiple_occurence_of_component_in_flow(
783790 to_visit_stack = [] # type: List[OpenMLFlow]
784791 to_visit_stack .extend (sub_components .values ())
785792 known_sub_components = set () # type: Set[str]
793+
786794 while len (to_visit_stack ) > 0 :
787795 visitee = to_visit_stack .pop ()
788- if visitee .name in known_sub_components :
796+ if isinstance (visitee , str ): # 'drop', 'passthrough' can be passed as estimators
797+ known_sub_components .add (visitee )
798+ elif visitee .name in known_sub_components :
789799 raise ValueError ('Found a second occurence of component %s when '
790800 'trying to serialize %s.' % (visitee .name , model ))
791801 else :
@@ -822,7 +832,7 @@ def _extract_information_from_model(
822832 def flatten_all (list_ ):
823833 """ Flattens arbitrary depth lists of lists (e.g. [[1,2],[3,[1]]] -> [1,2,3,1]). """
824834 for el in list_ :
825- if isinstance (el , (list , tuple )):
835+ if isinstance (el , (list , tuple )) and len ( el ) > 0 :
826836 yield from flatten_all (el )
827837 else :
828838 yield el
@@ -852,17 +862,31 @@ def flatten_all(list_):
852862 parameter_value = list () # type: List
853863 reserved_keywords = set (model .get_params (deep = False ).keys ())
854864
855- for sub_component_tuple in rval :
865+ for i , sub_component_tuple in enumerate ( rval ) :
856866 identifier = sub_component_tuple [0 ]
857867 sub_component = sub_component_tuple [1 ]
858- sub_component_type = type (sub_component_tuple )
868+ # sub_component_type = type(sub_component_tuple)
859869 if not 2 <= len (sub_component_tuple ) <= 3 :
860870 # length 2 is for {VotingClassifier.estimators,
861871 # Pipeline.steps, FeatureUnion.transformer_list}
862872 # length 3 is for ColumnTransformer
863873 msg = 'Length of tuple does not match assumptions'
864874 raise ValueError (msg )
865- if not isinstance (sub_component , (OpenMLFlow , type (None ))):
875+
876+ if isinstance (sub_component , str ):
877+ if sub_component != 'drop' and sub_component != 'passthrough' :
878+ msg = 'Second item of tuple does not match assumptions. ' \
879+ 'If string, can be only \' drop\' or \' passthrough\' but' \
880+ 'got %s' % sub_component
881+ raise ValueError (msg )
882+ else :
883+ pass
884+ elif isinstance (sub_component , type (None )):
885+ msg = 'Cannot serialize objects of None type. Please use a valid ' \
886+ 'placeholder for None. Note that empty sklearn estimators can be ' \
887+ 'replaced with \' drop\' or \' passthrough\' .'
888+ raise ValueError (msg )
889+ elif not isinstance (sub_component , OpenMLFlow ):
866890 msg = 'Second item of tuple does not match assumptions. ' \
867891 'Expected OpenMLFlow, got %s' % type (sub_component )
868892 raise TypeError (msg )
@@ -875,31 +899,18 @@ def flatten_all(list_):
875899 identifier )
876900 raise PyOpenMLError (msg )
877901
878- if sub_component is None :
879- # In a FeatureUnion it is legal to have a None step
880-
881- pv = [identifier , None ]
882- if sub_component_type is tuple :
883- parameter_value .append (tuple (pv ))
884- else :
885- parameter_value .append (pv )
886-
887- else :
888- # Add the component to the list of components, add a
889- # component reference as a placeholder to the list of
890- # parameters, which will be replaced by the real component
891- # when deserializing the parameter
892- sub_components_explicit .add (identifier )
893- sub_components [identifier ] = sub_component
894- component_reference = OrderedDict () # type: Dict[str, Union[str, Dict]]
895- component_reference ['oml-python:serialized_object' ] = 'component_reference'
896- cr_value = OrderedDict () # type: Dict[str, Any]
897- cr_value ['key' ] = identifier
898- cr_value ['step_name' ] = identifier
899- if len (sub_component_tuple ) == 3 :
900- cr_value ['argument_1' ] = sub_component_tuple [2 ]
901- component_reference ['value' ] = cr_value
902- parameter_value .append (component_reference )
902+ # when deserializing the parameter
903+ sub_components_explicit .add (identifier )
904+ sub_components [identifier ] = sub_component
905+ component_reference = OrderedDict () # type: Dict[str, Union[str, Dict]]
906+ component_reference ['oml-python:serialized_object' ] = 'component_reference'
907+ cr_value = OrderedDict () # type: Dict[str, Any]
908+ cr_value ['key' ] = identifier
909+ cr_value ['step_name' ] = identifier
910+ if len (sub_component_tuple ) == 3 :
911+ cr_value ['argument_1' ] = sub_component_tuple [2 ]
912+ component_reference ['value' ] = cr_value
913+ parameter_value .append (component_reference )
903914
904915 # Here (and in the elif and else branch below) are the only
905916 # places where we encode a value as json to make sure that all
0 commit comments