Skip to content

Commit c46f5b7

Browse files
committed
only check parameters of component if it is not older than the parent flow
1 parent beaa046 commit c46f5b7

3 files changed

Lines changed: 72 additions & 4 deletions

File tree

openml/flows/flow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,15 +349,15 @@ def publish(self):
349349
return_value = _perform_api_call("flow/", file_elements=file_elements)
350350
flow_id = int(xmltodict.parse(return_value)['oml:upload_flow']['oml:id'])
351351
flow = openml.flows.functions.get_flow(flow_id)
352+
_copy_server_fields(flow, self)
352353
try:
353-
openml.flows.functions.assert_flows_equal(self, flow)
354+
openml.flows.functions.assert_flows_equal(self, flow, flow.upload_date)
354355
except ValueError as e:
355356
message = e.args[0]
356357
raise ValueError("Flow was not stored correctly on the server. "
357358
"New flow ID is %d. Please check manually and "
358359
"remove the flow if necessary! Error is:\n'%s'" %
359360
(flow_id, message))
360-
_copy_server_fields(flow, self)
361361
return self
362362

363363

openml/flows/functions.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import dateutil.parser
2+
13
import xmltodict
24
import six
35

@@ -143,11 +145,21 @@ def _check_flow_for_server_id(flow):
143145
stack.append(component)
144146

145147

146-
def assert_flows_equal(flow1, flow2):
148+
def assert_flows_equal(flow1, flow2, ignore_parameters_on_older_children=None):
147149
"""Check equality of two flows.
148150
149151
Two flows are equal if all of their keys which are not set by the server
150152
are equal, as well as all their parameters and components.
153+
154+
Parameters
155+
----------
156+
flow1 : OpenMLFlow
157+
158+
flow2 : OpenMLFlow
159+
160+
ignore_parameters_on_older_children : str, optional (default=None)
161+
If set to ``OpenMLFlow.upload_date``, ignores parameters in a child
162+
flow if its upload date predates the upload date of the parent flow.
151163
"""
152164
if not isinstance(flow1, OpenMLFlow):
153165
raise TypeError('Argument 1 must be of type OpenMLFlow, but is %s' %
@@ -157,6 +169,8 @@ def assert_flows_equal(flow1, flow2):
157169
raise TypeError('Argument 2 must be of type OpenMLFlow, but is %s' %
158170
type(flow2))
159171

172+
# TODO as the server-generated fields are now saved during publish, it might be good to
173+
# check for the equality of these as well.
160174
generated_by_the_server = ['flow_id', 'uploader', 'version', 'upload_date']
161175
ignored_by_python_API = ['binary_url', 'binary_format', 'binary_md5',
162176
'model']
@@ -174,9 +188,18 @@ def assert_flows_equal(flow1, flow2):
174188
if not name in attr2:
175189
raise ValueError('Component %s only available in '
176190
'argument2, but not in argument1.' % name)
177-
assert_flows_equal(attr1[name], attr2[name])
191+
assert_flows_equal(attr1[name], attr2[name], ignore_parameters_on_older_children)
178192

179193
else:
194+
if key == 'parameters':
195+
if ignore_parameters_on_older_children:
196+
upload_date_current_flow = dateutil.parser.parse(
197+
flow1.upload_date)
198+
upload_date_parent_flow = dateutil.parser.parse(
199+
ignore_parameters_on_older_children)
200+
if upload_date_current_flow < upload_date_parent_flow:
201+
continue
202+
180203
if attr1 != attr2:
181204
raise ValueError("Flow %s: values for attribute '%s' differ: "
182205
"'%s' vs '%s'." %

tests/test_flows/test_flow.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from openml._api_calls import _perform_api_call
2929
import openml
3030
from openml.flows.sklearn_converter import _format_external_version
31+
import openml.exceptions
3132

3233

3334
class TestFlow(TestBase):
@@ -120,6 +121,50 @@ def test_publish_flow(self):
120121
flow.publish()
121122
self.assertIsInstance(flow.flow_id, int)
122123

124+
def test_publish_existing_flow(self):
125+
clf = sklearn.tree.DecisionTreeClassifier(max_depth=2)
126+
flow = openml.flows.sklearn_to_flow(clf)
127+
flow, _ = self._add_sentinel_to_flow_name(flow, None)
128+
flow.publish()
129+
self.assertRaisesRegexp(openml.exceptions.OpenMLServerException,
130+
'flow already exists', flow.publish)
131+
132+
def test_publish_flow_with_similar_components(self):
133+
clf = sklearn.ensemble.VotingClassifier(
134+
[('lr', sklearn.linear_model.LogisticRegression())])
135+
flow = openml.flows.sklearn_to_flow(clf)
136+
flow, _ = self._add_sentinel_to_flow_name(flow, None)
137+
flow.publish()
138+
# For a flow where both components are published together, the upload
139+
# date should be equal
140+
self.assertEqual(flow.upload_date,
141+
flow.components['lr'].upload_date,
142+
(flow.name, flow.flow_id,
143+
flow.components['lr'].name, flow.components['lr'].flow_id))
144+
145+
clf1 = sklearn.tree.DecisionTreeClassifier(max_depth=2)
146+
flow1 = openml.flows.sklearn_to_flow(clf1)
147+
flow1, sentinel = self._add_sentinel_to_flow_name(flow1, None)
148+
flow1.publish()
149+
150+
clf2 = sklearn.ensemble.VotingClassifier(
151+
[('dt', sklearn.tree.DecisionTreeClassifier(max_depth=2))])
152+
flow2 = openml.flows.sklearn_to_flow(clf2)
153+
flow2, _ = self._add_sentinel_to_flow_name(flow2, sentinel)
154+
flow2.publish()
155+
# If one component was published before the other, the components in
156+
# the flow should have different upload dates
157+
self.assertNotEqual(flow2.upload_date,
158+
flow2.components['dt'].upload_date)
159+
160+
clf3 = sklearn.ensemble.AdaBoostClassifier(
161+
sklearn.tree.DecisionTreeClassifier(max_depth=3))
162+
flow3 = openml.flows.sklearn_to_flow(clf3)
163+
flow3, _ = self._add_sentinel_to_flow_name(flow3, sentinel)
164+
# Child flow has a different parameter value. The check for storing the flow
165+
# correctly on the server should thus not check the child's parameters!
166+
flow3.publish()
167+
123168
def test_semi_legal_flow(self):
124169
# TODO: Test if parameters are set correctly!
125170
# should not throw an error as it contains two distinguishable forms of Bagging

0 commit comments

Comments
 (0)