Skip to content

Commit 815c259

Browse files
committed
MAINT improve code based on Andreas' suggestions
1 parent c23c00b commit 815c259

10 files changed

Lines changed: 124 additions & 91 deletions

File tree

openml/flows/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .flow import OpenMLFlow
2-
from .functions import get_flow
32
from .sklearn_converter import sklearn_to_flow, flow_to_sklearn
3+
from .functions import get_flow
44

55
# Public API of the flows package. Every name listed here must be bound by
# the imports above; otherwise ``from openml.flows import *`` raises
# AttributeError. 'create_flow_from_model' was listed but never imported in
# this module, so it is dropped.
__all__ = ['OpenMLFlow', 'get_flow', 'sklearn_to_flow', 'flow_to_sklearn']

openml/flows/flow.py

Lines changed: 5 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -145,17 +145,17 @@ def _to_xml(self):
145145
return flow_xml
146146

147147
def _to_dict(self):
148-
""" Helper function used by _to_xml and _to_dict.
148+
""" Helper function used by _to_xml and itself.
149149
150150
Creates a dictionary representation of self which can be serialized
151151
to xml by the function _to_xml. Since a flow can contain subflows
152152
(components) this helper function calls itself recursively to also
153153
serialize these flows to dictionaries.
154154
155-
Uses OrderedDict everywhere to make sure that the order of data stays
156-
at it is added here. The return value (OrderedDict) will be used to
157-
create the upload xml file. The xml file must have the tags in exactly
158-
the order given in the xsd schema of a flow (see class docstring).
155+
Uses OrderedDict to ensure consistent ordering when converting to xml.
156+
The return value (OrderedDict) will be used to create the upload xml
157+
file. The xml file must have the tags in exactly the order given in the
158+
xsd schema of a flow (see class docstring).
159159
160160
Returns
161161
-------
@@ -324,30 +324,6 @@ def _from_dict(cls, xml_dict):
324324
arguments['model'] = None
325325
return cls(**arguments)
326326

327-
def __eq__(self, other):
328-
"""Check equality.
329-
330-
Two flows are equal if their all keys which are not set by the server
331-
are equal, as well as all their parameters and components.
332-
"""
333-
if not isinstance(other, self.__class__):
334-
return NotImplemented
335-
336-
# Name is actually not generated by the server, but it will be
337-
# tested further down with a getter (allows mocking in the tests)
338-
generated_by_the_server = ['name', 'flow_id', 'uploader', 'version',
339-
'upload_date', 'source_url',
340-
'binary_url', 'source_format',
341-
'binary_format', 'source_md5',
342-
'binary_md5', 'model']
343-
344-
for key in set(self.__dict__.keys()).union(other.__dict__.keys()):
345-
if key in generated_by_the_server:
346-
continue
347-
if getattr(self, key, None) != getattr(other, key, None):
348-
return False
349-
return True
350-
351327
def publish(self):
352328
"""Publish flow to OpenML server.
353329

openml/flows/functions.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import xmltodict
22

33
from openml._api_calls import _perform_api_call
4-
# Absolute imports, to avoid circular dependencies
5-
from openml.flows.sklearn_converter import flow_to_sklearn
6-
from . import OpenMLFlow
4+
from . import OpenMLFlow, flow_to_sklearn
75

86

97
def get_flow(flow_id):

openml/flows/sklearn_converter.py

Lines changed: 42 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from collections import OrderedDict, defaultdict
1+
"""Convert scikit-learn estimators into OpenMLFlows and vice versa."""
2+
3+
from collections import OrderedDict
24
import importlib
35
import inspect
46
import json
@@ -23,9 +25,6 @@
2325
JSONDecodeError = ValueError
2426

2527

26-
"""Convert scikit-learn estimators into an OpenMLFlows and vice versa."""
27-
28-
2928
def sklearn_to_flow(o):
3029

3130
if _is_estimator(o):
@@ -52,7 +51,7 @@ def sklearn_to_flow(o):
5251
elif isinstance(o, scipy.stats.distributions.rv_frozen):
5352
rval = serialize_rv_frozen(o)
5453
# This only works for user-defined functions (and not even partial).
55-
# I think this is exactly we want here as there shouldn't be any
54+
# I think this is exactly what we want here as there shouldn't be any
5655
# built-ins or functools.partial objects in a pipeline
5756
elif inspect.isfunction(o):
5857
rval = serialize_function(o)
@@ -126,7 +125,6 @@ def flow_to_sklearn(o, **kwargs):
126125
rval = _deserialize_model(o, **kwargs)
127126
else:
128127
raise TypeError(o)
129-
assert o is None or rval is not None
130128

131129
return rval
132130

@@ -153,17 +151,7 @@ def _serialize_model(model):
153151

154152
# Check that a component does not occur multiple times in a flow as this
155153
# is not supported by OpenML
156-
to_visit_stack = []
157-
to_visit_stack.extend(sub_components.values())
158-
known_sub_components = set()
159-
while len(to_visit_stack) > 0:
160-
visitee = to_visit_stack.pop()
161-
if visitee.name in known_sub_components:
162-
raise ValueError('Found a second occurence of component %s when '
163-
'trying to serialize %s.' % (visitee.name, model))
164-
else:
165-
known_sub_components.add(visitee.name)
166-
to_visit_stack.extend(visitee.components.values())
154+
_check_multiple_occurence_of_component_in_flow(model, sub_components)
167155

168156
# Create a flow name, which contains all components in brackets, for
169157
# example RandomizedSearchCV(Pipeline(StandardScaler,AdaBoostClassifier(DecisionTreeClassifier)),StandardScaler,AdaBoostClassifier(DecisionTreeClassifier))
@@ -184,22 +172,7 @@ def _serialize_model(model):
184172
name = class_name
185173

186174
# Get the external versions of all sub-components
187-
model_package_name = model.__module__.split('.')[0]
188-
module = importlib.import_module(model_package_name)
189-
model_package_version_number = module.__version__
190-
external_version = _format_external_version(model_package_name, model_package_version_number)
191-
192-
external_versions = set()
193-
external_versions.add(external_version)
194-
to_visit_stack = []
195-
to_visit_stack.extend(sub_components.values())
196-
while len(to_visit_stack) > 0:
197-
visitee = to_visit_stack.pop()
198-
for external_version in visitee.external_version.split(','):
199-
external_versions.add(external_version)
200-
to_visit_stack.extend(visitee.components.values())
201-
external_versions = list(sorted(external_versions))
202-
external_version = ','.join(external_versions)
175+
external_version = _get_external_version_string(model, sub_components)
203176

204177
flow = OpenMLFlow(name=name,
205178
class_name=class_name,
@@ -217,6 +190,41 @@ def _serialize_model(model):
217190
return flow
218191

219192

193+
def _get_external_version_string(model, sub_components):
    """Create the external version string for a flow.

    The string combines the version of the top-level package that defines
    ``model`` with the external versions of all direct sub-components.
    Sub-component version strings are expected to be comma-separated and
    are presumed to already include the requirements of their own
    sub-components, so only one level is visited here.

    Parameters
    ----------
    model : object
        The model being serialized. The first segment of its ``__module__``
        is imported to read that package's ``__version__``.
    sub_components : dict
        Mapping of identifier to already-serialized sub-flow. Each value
        must expose a comma-separated ``external_version`` string.

    Returns
    -------
    str
        Sorted, de-duplicated, comma-joined external version entries.
    """
    model_package_name = model.__module__.split('.')[0]
    module = importlib.import_module(model_package_name)
    external_versions = {
        _format_external_version(model_package_name, module.__version__)
    }
    # A set removes duplicates when several sub-components depend on the
    # same package and version.
    for sub_flow in sub_components.values():
        external_versions.update(sub_flow.external_version.split(','))
    return ','.join(sorted(external_versions))
212+
213+
214+
def _check_multiple_occurence_of_component_in_flow(model, sub_components):
215+
to_visit_stack = []
216+
to_visit_stack.extend(sub_components.values())
217+
known_sub_components = set()
218+
while len(to_visit_stack) > 0:
219+
visitee = to_visit_stack.pop()
220+
if visitee.name in known_sub_components:
221+
raise ValueError('Found a second occurence of component %s when '
222+
'trying to serialize %s.' % (visitee.name, model))
223+
else:
224+
known_sub_components.add(visitee.name)
225+
to_visit_stack.extend(visitee.components.values())
226+
227+
220228
def _extract_information_from_model(model):
221229
# This function contains four "global" states and is quite long and
222230
# complicated. If it gets too complicated to ensure its correctness,
@@ -257,7 +265,7 @@ def _extract_information_from_model(model):
257265
# Add the component to the list of components, add a
258266
# component reference as a placeholder to the list of
259267
# parameters, which will be replaced by the real component
260-
# when deserealizing the parameter
268+
# when deserializing the parameter
261269
sub_component_identifier = k + '__' + identifier
262270
sub_components_explicit.add(sub_component_identifier)
263271
sub_components[sub_component_identifier] = sub_component

openml/testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def setUp(self):
4242
self.cached = True
4343
# amueller's read/write key that he will throw away later
4444
openml.config.apikey = "610344db6388d9ba34f6db45a3cf71de"
45-
self.production_server = "https://www.openml.org/api/v1/xml"
45+
self.production_server = openml.config.server
4646
self.test_server = "https://test.openml.org/api/v1/xml"
4747
openml.config.server = self.test_server
4848

openml/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
else:
66
from urllib.error import URLError
77

8-
import six
98

109
def is_string(obj):
    """Return True if ``obj`` is a string.

    Checks against ``basestring`` where that name exists and falls back to
    ``str`` where looking it up raises NameError.
    """
    try:
        string_types = basestring
    except NameError:
        string_types = str
    return isinstance(obj, string_types)
1514

15+
1616
__all__ = ['URLError', 'is_string']
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +0,0 @@
1-
__version__ = 1.0

tests/flows/dummy_learn/dummy_forest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ def get_params(self, deep=False):
99
return {}
1010

1111
def set_params(self, params):
12-
return None
12+
return self

tests/flows/test_flow.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,38 @@
2424
from openml.flows.sklearn_converter import _format_external_version
2525

2626

27+
def are_flows_equal(flow1, flow2):
    """Check equality of two flows.

    Two flows are equal if all their attributes which are not set by the
    server are equal, and their components are recursively equal.

    Parameters
    ----------
    flow1, flow2 : object
        The flows to compare. Attributes are read from each object's
        ``__dict__``; ``components`` is expected to be a dict mapping
        identifiers to sub-flows. A missing or ``None`` ``components``
        is treated as having no components.

    Returns
    -------
    bool
        False if ``flow2`` is not an instance of ``flow1``'s class, if any
        compared attribute differs, or if the component sets differ;
        True otherwise.
    """
    if not isinstance(flow2, flow1.__class__):
        return False

    # These attributes are skipped by the comparison. 'name' is skipped
    # here as well even though it is not generated by the server.
    generated_by_the_server = ['name', 'flow_id', 'uploader', 'version',
                               'upload_date', 'source_url',
                               'binary_url', 'source_format',
                               'binary_format', 'source_md5',
                               'binary_md5', 'model']

    for key in set(flow1.__dict__).union(flow2.__dict__):
        if key in generated_by_the_server:
            continue
        attr1 = getattr(flow1, key, None)
        attr2 = getattr(flow2, key, None)
        if key == 'components':
            components1 = attr1 or {}
            components2 = attr2 or {}
            # A component present in only one flow means the flows differ;
            # indexing both dicts over the union of keys would raise
            # KeyError instead of reporting inequality.
            if set(components1) != set(components2):
                return False
            for name in components1:
                if not are_flows_equal(components1[name],
                                       components2[name]):
                    return False
        elif attr1 != attr2:
            return False
    return True
57+
58+
2759
def get_sentinel():
2860
# Create a unique prefix for the flow. Necessary because the flow is
2961
# identified by its name and external version online. Having a unique
@@ -110,7 +142,7 @@ def test_to_xml_from_xml(self):
110142
xml = flow._to_xml()
111143
xml_dict = xmltodict.parse(xml)
112144
new_flow = openml.flows.OpenMLFlow._from_dict(xml_dict)
113-
self.assertEqual(new_flow, flow)
145+
self.assertTrue(are_flows_equal(new_flow, flow))
114146
self.assertIsNot(new_flow, flow)
115147

116148
def test_publish_flow(self):
@@ -213,7 +245,7 @@ def test_sklearn_to_upload_to_flow(self):
213245

214246
self.assertEqual(server_xml, local_xml)
215247

216-
self.assertEqual(new_flow, flow)
248+
self.assertTrue(are_flows_equal(new_flow, flow))
217249
self.assertIsNot(new_flow, flow)
218250

219251
fixture_name = 'sklearn.model_selection._search.RandomizedSearchCV(' \

0 commit comments

Comments
 (0)