Skip to content

Commit b2493de

Browse files
committed
MAINT remove step__ from component identifiers in pipeline
1 parent 60372ea commit b2493de

3 files changed

Lines changed: 46 additions & 40 deletions

File tree

openml/flows/sklearn_converter.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Convert scikit-learn estimators into an OpenMLFlows and vice versa."""
22

33
from collections import OrderedDict
4+
import copy
45
from distutils.version import LooseVersion
56
import importlib
67
import inspect
@@ -107,6 +108,10 @@ def flow_to_sklearn(o, **kwargs):
107108
step_name = value['step_name']
108109
key = value['key']
109110
component = flow_to_sklearn(kwargs['components'][key])
111+
# The component is now added to where it should be used
112+
# later. It should not be passed to the constructor of the
113+
# main flow object.
114+
del kwargs['components'][key]
110115
if step_name is None:
111116
rval = component
112117
else:
@@ -276,14 +281,13 @@ def _extract_information_from_model(model):
276281
# component reference as a placeholder to the list of
277282
# parameters, which will be replaced by the real component
278283
# when deserializing the parameter
279-
sub_component_identifier = k + '__' + identifier
280-
sub_components_explicit.add(sub_component_identifier)
281-
sub_components[sub_component_identifier] = sub_component
284+
sub_components_explicit.add(identifier)
285+
sub_components[identifier] = sub_component
282286
component_reference = OrderedDict()
283287
component_reference[
284288
'oml-python:serialized_object'] = 'component_reference'
285289
component_reference['value'] = OrderedDict(
286-
key=sub_component_identifier, step_name=identifier)
290+
key=identifier, step_name=identifier)
287291
parameter_value.append(component_reference)
288292

289293
if isinstance(rval, tuple):
@@ -331,25 +335,27 @@ def _deserialize_model(flow, **kwargs):
331335

332336
parameters = flow.parameters
333337
components = flow.components
334-
component_dict = OrderedDict()
335338
parameter_dict = OrderedDict()
336339

337-
for name in components:
338-
if '__' in name:
339-
parameter_name, step = name.split('__')
340-
value = components[name]
341-
rval = flow_to_sklearn(value)
342-
if parameter_name not in component_dict:
343-
component_dict[parameter_name] = OrderedDict()
344-
component_dict[parameter_name][step] = rval
345-
else:
346-
value = components[name]
347-
rval = flow_to_sklearn(value)
348-
parameter_dict[name] = rval
340+
# Do a shallow copy of the components dictionary so we can remove the
341+
# components from this copy once we added them into the pipeline. This
342+
# allows us to not consider them any more when looping over the
343+
# components, but keeping the dictionary of components untouched in the
344+
# original components dictionary.
345+
components_ = copy.copy(components)
349346

350347
for name in parameters:
351348
value = parameters.get(name)
352-
rval = flow_to_sklearn(value, components=components)
349+
rval = flow_to_sklearn(value, components=components_)
350+
parameter_dict[name] = rval
351+
352+
for name in components:
353+
if name in parameter_dict:
354+
continue
355+
if name not in components_:
356+
continue
357+
value = components[name]
358+
rval = flow_to_sklearn(value)
353359
parameter_dict[name] = rval
354360

355361
module_name = model_name.rsplit('.', 1)

tests/test_flows/test_flow.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -243,12 +243,12 @@ def test_sklearn_to_upload_to_flow(self):
243243

244244
fixture_name = '%ssklearn.model_selection._search.RandomizedSearchCV(' \
245245
'estimator=sklearn.pipeline.Pipeline(' \
246-
'steps__ohe=sklearn.preprocessing.data.OneHotEncoder,' \
247-
'steps__scaler=sklearn.preprocessing.data.StandardScaler,' \
248-
'steps__fu=sklearn.pipeline.FeatureUnion(' \
249-
'transformer_list__pca=sklearn.decomposition.truncated_svd.TruncatedSVD,' \
250-
'transformer_list__fs=sklearn.feature_selection.univariate_selection.SelectPercentile),' \
251-
'steps__boosting=sklearn.ensemble.weight_boosting.AdaBoostClassifier(' \
246+
'ohe=sklearn.preprocessing.data.OneHotEncoder,' \
247+
'scaler=sklearn.preprocessing.data.StandardScaler,' \
248+
'fu=sklearn.pipeline.FeatureUnion(' \
249+
'pca=sklearn.decomposition.truncated_svd.TruncatedSVD,' \
250+
'fs=sklearn.feature_selection.univariate_selection.SelectPercentile),' \
251+
'boosting=sklearn.ensemble.weight_boosting.AdaBoostClassifier(' \
252252
'base_estimator=sklearn.tree.tree.DecisionTreeClassifier)))' \
253253
% sentinel
254254

tests/test_flows/test_sklearn.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ def test_serialize_pipeline(self):
145145
('scaler', scaler), ('dummy', dummy)))
146146

147147
fixture_name = 'sklearn.pipeline.Pipeline(' \
148-
'steps__scaler=sklearn.preprocessing.data.StandardScaler,' \
149-
'steps__dummy=sklearn.dummy.DummyClassifier)'
148+
'scaler=sklearn.preprocessing.data.StandardScaler,' \
149+
'dummy=sklearn.dummy.DummyClassifier)'
150150
fixture_description = 'Automatically created sub-component.'
151151

152152
serialization = sklearn_to_flow(model)
@@ -162,15 +162,15 @@ def test_serialize_pipeline(self):
162162
# different sorting. Making a json makes it easier
163163
self.assertEqual(json.loads(serialization.parameters['steps']),
164164
[{'oml-python:serialized_object':
165-
'component_reference', 'value': {'key': 'steps__scaler', 'step_name': 'scaler'}},
165+
'component_reference', 'value': {'key': 'scaler', 'step_name': 'scaler'}},
166166
{'oml-python:serialized_object':
167-
'component_reference', 'value': {'key': 'steps__dummy', 'step_name': 'dummy'}}])
167+
'component_reference', 'value': {'key': 'dummy', 'step_name': 'dummy'}}])
168168

169169
# Checking the sub-component
170170
self.assertEqual(len(serialization.components), 2)
171-
self.assertIsInstance(serialization.components['steps__scaler'],
171+
self.assertIsInstance(serialization.components['scaler'],
172172
OpenMLFlow)
173-
self.assertIsInstance(serialization.components['steps__dummy'],
173+
self.assertIsInstance(serialization.components['dummy'],
174174
OpenMLFlow)
175175

176176
#del serialization.model
@@ -204,8 +204,8 @@ def test_serialize_feature_union(self):
204204
serialization = sklearn_to_flow(fu)
205205
self.assertEqual(serialization.name,
206206
'sklearn.pipeline.FeatureUnion('
207-
'transformer_list__ohe=sklearn.preprocessing.data.OneHotEncoder,'
208-
'transformer_list__scaler=sklearn.preprocessing.data.StandardScaler)')
207+
'ohe=sklearn.preprocessing.data.OneHotEncoder,'
208+
'scaler=sklearn.preprocessing.data.StandardScaler)')
209209
new_model = flow_to_sklearn(serialization)
210210

211211
self.assertEqual(type(new_model), type(fu))
@@ -240,7 +240,7 @@ def test_serialize_feature_union(self):
240240
serialization = sklearn_to_flow(fu)
241241
self.assertEqual(serialization.name,
242242
'sklearn.pipeline.FeatureUnion('
243-
'transformer_list__ohe=sklearn.preprocessing.data.OneHotEncoder)')
243+
'ohe=sklearn.preprocessing.data.OneHotEncoder)')
244244
new_model = flow_to_sklearn(serialization)
245245
self.assertEqual(type(new_model), type(fu))
246246
self.assertIsNot(new_model, fu)
@@ -256,13 +256,13 @@ def test_serialize_feature_union_switched_names(self):
256256
self.assertEqual(
257257
fu1_serialization.name,
258258
"sklearn.pipeline.FeatureUnion("
259-
"transformer_list__ohe=sklearn.preprocessing.data.OneHotEncoder,"
260-
"transformer_list__scaler=sklearn.preprocessing.data.StandardScaler)")
259+
"ohe=sklearn.preprocessing.data.OneHotEncoder,"
260+
"scaler=sklearn.preprocessing.data.StandardScaler)")
261261
self.assertEqual(
262262
fu2_serialization.name,
263263
"sklearn.pipeline.FeatureUnion("
264-
"transformer_list__scaler=sklearn.preprocessing.data.OneHotEncoder,"
265-
"transformer_list__ohe=sklearn.preprocessing.data.StandardScaler)")
264+
"scaler=sklearn.preprocessing.data.OneHotEncoder,"
265+
"ohe=sklearn.preprocessing.data.StandardScaler)")
266266

267267
def test_serialize_complex_flow(self):
268268
ohe = sklearn.preprocessing.OneHotEncoder(categorical_features=[0])
@@ -282,9 +282,9 @@ def test_serialize_complex_flow(self):
282282

283283
fixture_name = 'sklearn.model_selection._search.RandomizedSearchCV(' \
284284
'estimator=sklearn.pipeline.Pipeline(' \
285-
'steps__ohe=sklearn.preprocessing.data.OneHotEncoder,' \
286-
'steps__scaler=sklearn.preprocessing.data.StandardScaler,' \
287-
'steps__boosting=sklearn.ensemble.weight_boosting.AdaBoostClassifier(' \
285+
'ohe=sklearn.preprocessing.data.OneHotEncoder,' \
286+
'scaler=sklearn.preprocessing.data.StandardScaler,' \
287+
'boosting=sklearn.ensemble.weight_boosting.AdaBoostClassifier(' \
288288
'base_estimator=sklearn.tree.tree.DecisionTreeClassifier)))'
289289
self.assertEqual(serialized.name, fixture_name)
290290

0 commit comments

Comments (0)