Skip to content

Commit 512c07a

Browse files
authored
Merge pull request #236 from openml/add/#145
WIP Add a check that flows are correctly stored on the server
2 parents 9a0d9a8 + 523b603 commit 512c07a

9 files changed

Lines changed: 237 additions & 80 deletions

File tree

openml/flows/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .flow import OpenMLFlow
2+
23
from .sklearn_converter import sklearn_to_flow, flow_to_sklearn, _check_n_jobs
3-
from .functions import get_flow, list_flows, flow_exists
4+
from .functions import get_flow, list_flows, flow_exists, assert_flows_equal
45

56
__all__ = ['OpenMLFlow', 'create_flow_from_model', 'get_flow', 'list_flows',
67
'sklearn_to_flow', 'flow_to_sklearn', 'flow_exists']

openml/flows/flow.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,9 +338,26 @@ def publish(self):
338338
file_elements = {'description': xml_description}
339339
return_value = _perform_api_call("flow/", file_elements=file_elements)
340340
self.flow_id = int(xmltodict.parse(return_value)['oml:upload_flow']['oml:id'])
341+
try:
342+
_check_flow(self)
343+
except ValueError as e:
344+
message = e.args[0]
345+
raise ValueError("Flow was not stored correctly on the server. "
346+
"New flow ID is %d. Please check manually and "
347+
"remove the flow if necessary! Error is:\n'%s'" %
348+
(self.flow_id, message))
341349
return self
342350

343351

344352
def _add_if_nonempty(dic, key, value):
345353
if value is not None:
346354
dic[key] = value
355+
356+
357+
def _check_flow(flow):
    """Fetch the flow stored under ``flow.flow_id`` and compare it to ``flow``.

    The comparison is delegated to ``assert_flows_equal``; whatever it (or
    the download) raises is propagated to the caller unchanged.
    """
    # The import has to live inside the function: importing at module level
    # would fail with an ImportError because of a circular import.
    import openml.flows.functions

    flow_functions = openml.flows.functions
    server_copy = flow_functions.get_flow(flow.flow_id)
    flow_functions.assert_flows_equal(flow, server_copy)

openml/flows/functions.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,4 +127,45 @@ def _list_flows(api_call):
127127
'uploader': flow_['oml:uploader']}
128128
flows[fid] = flow
129129

130-
return flows
130+
return flows
131+
132+
133+
def assert_flows_equal(flow1, flow2):
    """Check equality of two flows.

    Two flows are equal if all of their attributes which are neither set by
    the server nor ignored by the Python API are equal, and if all of their
    components are equal as well (components are compared recursively).

    Nothing is returned; a difference is reported by raising.

    Raises
    ------
    TypeError
        If one of the arguments is not an ``OpenMLFlow``.
    ValueError
        If a component exists in only one of the flows, or if the values of
        a compared attribute differ.
    """
    if not isinstance(flow1, OpenMLFlow):
        raise TypeError('Argument 1 must be of type OpenMLFlow, but is %s' %
                        type(flow1))

    if not isinstance(flow2, OpenMLFlow):
        raise TypeError('Argument 2 must be of type OpenMLFlow, but is %s' %
                        type(flow2))

    generated_by_the_server = ['flow_id', 'uploader', 'version',
                               'upload_date', ]
    ignored_by_python_API = ['binary_url', 'binary_format', 'binary_md5',
                             'model']
    # Build the skip set once instead of concatenating the lists for every
    # attribute.
    skipped_keys = set(generated_by_the_server + ignored_by_python_API)

    # Sorted so that the first reported difference is deterministic.
    all_keys = set(flow1.__dict__.keys()).union(flow2.__dict__.keys())
    for key in sorted(all_keys):
        if key in skipped_keys:
            continue
        attr1 = getattr(flow1, key, None)
        attr2 = getattr(flow2, key, None)
        if key == 'components':
            for name in sorted(set(attr1.keys()).union(attr2.keys())):
                if name not in attr1:
                    raise ValueError('Component %s only available in '
                                     'argument2, but not in argument1.' % name)
                if name not in attr2:
                    # The component is missing from the second flow, so the
                    # message has to name argument1 as the side that has it.
                    raise ValueError('Component %s only available in '
                                     'argument1, but not in argument2.' % name)
                assert_flows_equal(attr1[name], attr2[name])
        else:
            if attr1 != attr2:
                raise ValueError("Flow %s: values for attribute '%s' differ: "
                                 "'%s' vs '%s'." %
                                 (str(flow1.name), str(key), str(attr1),
                                  str(attr2)))

openml/flows/sklearn_converter.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ def flow_to_sklearn(o, **kwargs):
131131
raise ValueError('Cannot flow_to_sklearn %s' % serialized_type)
132132

133133
else:
134-
# Regular dictionary
135134
rval = OrderedDict((flow_to_sklearn(key, **kwargs),
136135
flow_to_sklearn(value, **kwargs))
137136
for key, value in o.items())
@@ -303,8 +302,10 @@ def _extract_information_from_model(model):
303302
component_reference = OrderedDict()
304303
component_reference[
305304
'oml-python:serialized_object'] = 'component_reference'
306-
component_reference['value'] = OrderedDict(
307-
key=identifier, step_name=identifier)
305+
cr_value = OrderedDict()
306+
cr_value['key'] = identifier
307+
cr_value['step_name'] = identifier
308+
component_reference['value'] = cr_value
308309
parameter_value.append(component_reference)
309310

310311
if isinstance(rval, tuple):
@@ -326,7 +327,10 @@ def _extract_information_from_model(model):
326327
component_reference = OrderedDict()
327328
component_reference[
328329
'oml-python:serialized_object'] = 'component_reference'
329-
component_reference['value'] = OrderedDict(key=k, step_name=None)
330+
cr_value = OrderedDict()
331+
cr_value['key'] = k
332+
cr_value['step_name'] = None
333+
component_reference['value'] = cr_value
330334
component_reference = sklearn_to_flow(component_reference, model)
331335
parameters[k] = json.dumps(component_reference)
332336

@@ -387,6 +391,9 @@ def _deserialize_model(flow, **kwargs):
387391

388392

389393
def _check_dependencies(dependencies):
394+
if not dependencies:
395+
return
396+
390397
dependencies = dependencies.split('\n')
391398
for dependency_string in dependencies:
392399
match = DEPENDENCIES_PATTERN.match(dependency_string)
@@ -448,7 +455,8 @@ def serialize_rv_frozen(o):
448455
dist = o.dist.__class__.__module__ + '.' + o.dist.__class__.__name__
449456
ret = OrderedDict()
450457
ret['oml-python:serialized_object'] = 'rv_frozen'
451-
ret['value'] = OrderedDict(dist=dist, a=a, b=b, args=args, kwds=kwds)
458+
ret['value'] = OrderedDict((('dist', dist), ('a', a), ('b', b),
459+
('args', args), ('kwds', kwds)))
452460
return ret
453461

454462
def deserialize_rv_frozen(o, **kwargs):

openml/testing.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import hashlib
12
import inspect
23
import os
4+
import time
35
import shutil
46
import unittest
57
import openml
@@ -54,4 +56,25 @@ def tearDown(self):
5456
shutil.rmtree(self.workdir)
5557
openml.config.server = self.production_server
5658

59+
def _add_sentinel_to_flow_name(self, flow, sentinel=None):
60+
if sentinel is None:
61+
# Create a unique prefix for the flow. Necessary because the flow is
62+
# identified by its name and external version online. Having a unique
63+
# name allows us to publish the same flow in each test run
64+
md5 = hashlib.md5()
65+
md5.update(str(time.time()).encode('utf-8'))
66+
sentinel = md5.hexdigest()[:10]
67+
sentinel = 'TEST%s' % sentinel
68+
69+
flows_to_visit = list()
70+
flows_to_visit.append(flow)
71+
while len(flows_to_visit) > 0:
72+
current_flow = flows_to_visit.pop()
73+
current_flow.name = '%s%s' % (sentinel, current_flow.name)
74+
for subflow in current_flow.components.values():
75+
flows_to_visit.append(subflow)
76+
77+
return flow, sentinel
78+
79+
5780
__all__ = ['TestBase']

tests/test_flows/test_flow.py

Lines changed: 62 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import collections
2+
import copy
23
import hashlib
34
import re
5+
import sys
46
import time
5-
import random
6-
import unittest
77

8-
import xmltodict
8+
if sys.version_info[0] >= 3:
9+
from unittest import mock
10+
else:
11+
import mock
912

1013
import scipy.stats
1114
import sklearn
@@ -19,58 +22,17 @@
1922
import sklearn.preprocessing
2023
import sklearn.naive_bayes
2124
import sklearn.tree
25+
import xmltodict
2226

2327
from openml.testing import TestBase
2428
from openml._api_calls import _perform_api_call
2529
import openml
2630
from openml.flows.sklearn_converter import _format_external_version
2731

2832

29-
def are_flows_equal(flow1, flow2):
30-
"""Check equality of two flows.
31-
32-
Two flows are equal if their all keys which are not set by the server
33-
are equal, as well as all their parameters and components.
34-
"""
35-
if not isinstance(flow2, flow1.__class__):
36-
return False
37-
38-
# Name is actually not generated by the server, but it will be
39-
# tested further down with a getter (allows mocking in the tests)
40-
generated_by_the_server = ['name', 'flow_id', 'uploader', 'version',
41-
'upload_date', 'source_url',
42-
'binary_url', 'source_format',
43-
'binary_format', 'source_md5',
44-
'binary_md5', 'model']
45-
46-
for key in set(flow1.__dict__.keys()).union(flow2.__dict__.keys()):
47-
if key in generated_by_the_server:
48-
continue
49-
attr1 = getattr(flow1, key, None)
50-
attr2 = getattr(flow2, key, None)
51-
if key == 'components':
52-
for name in set(attr1.keys()).union(attr2.keys()):
53-
if not are_flows_equal(attr1[name], attr2[name]):
54-
return False
55-
else:
56-
if attr1 != attr2:
57-
return False
58-
return True
59-
60-
61-
def get_sentinel():
62-
# Create a unique prefix for the flow. Necessary because the flow is
63-
# identified by its name and external version online. Having a unique
64-
# name allows us to publish the same flow in each test run
65-
md5 = hashlib.md5()
66-
md5.update(str(time.time()).encode('utf-8'))
67-
sentinel = md5.hexdigest()[:10]
68-
sentinel = 'TEST%s' % sentinel
69-
return sentinel
70-
71-
7233
class TestFlow(TestBase):
7334

35+
7436
def test_get_flow(self):
7537
# We need to use the production server here because 4024 is not on the test
7638
# server
@@ -134,13 +96,14 @@ def test_to_xml_from_xml(self):
13496
xml = flow._to_xml()
13597
xml_dict = xmltodict.parse(xml)
13698
new_flow = openml.flows.OpenMLFlow._from_dict(xml_dict)
137-
self.assertTrue(are_flows_equal(new_flow, flow))
99+
100+
# Would raise an exception if the flows are not equal
101+
openml.flows.functions.assert_flows_equal(new_flow, flow)
138102
self.assertIsNot(new_flow, flow)
139103

140104
def test_publish_flow(self):
141-
sentinel = get_sentinel()
142-
143-
flow = openml.OpenMLFlow(name='Test',
105+
flow = openml.OpenMLFlow(name='sklearn.dummy.DummyClassifier',
106+
class_name='sklearn.dummy.DummyClassifier',
144107
description="test description",
145108
model=sklearn.dummy.DummyClassifier(),
146109
components=collections.OrderedDict(),
@@ -150,8 +113,9 @@ def test_publish_flow(self):
150113
'sklearn', sklearn.__version__),
151114
tags=[],
152115
language='English',
153-
dependencies='')
154-
flow.name = 'TEST%s%s' % (sentinel, flow.name)
116+
dependencies=None)
117+
118+
flow, _ = self._add_sentinel_to_flow_name(flow, None)
155119

156120
flow.publish()
157121
self.assertIsInstance(flow.flow_id, int)
@@ -160,14 +124,44 @@ def test_semi_legal_flow(self):
160124
# TODO: Test if parameters are set correctly!
161125
# should not throw error as it contains two differentiable forms of Bagging
162126
# i.e., Bagging(Bagging(J48)) and Bagging(J48)
163-
sentinel = get_sentinel()
164127
semi_legal = sklearn.ensemble.BaggingClassifier(
165128
base_estimator=sklearn.ensemble.BaggingClassifier(
166129
base_estimator=sklearn.tree.DecisionTreeClassifier()))
167130
flow = openml.flows.sklearn_to_flow(semi_legal)
168-
flow.name = 'TEST%s%s' % (sentinel, flow.name)
131+
flow, _ = self._add_sentinel_to_flow_name(flow, None)
132+
133+
flow.publish()
134+
135+
@mock.patch('openml.flows.functions.get_flow')
@mock.patch('openml.flows.flow._perform_api_call')
def test_publish_error(self, api_call_mock, get_flow_mock):
    """publish() succeeds for an identical server copy and raises otherwise."""
    flow = openml.flows.sklearn_to_flow(
        sklearn.ensemble.RandomForestClassifier())
    api_call_mock.return_value = ("<oml:upload_flow>\n"
                                  " <oml:id>1</oml:id>\n"
                                  "</oml:upload_flow>")

    # First round: the mocked download hands back the very same flow, so
    # the post-publish check passes.
    get_flow_mock.return_value = flow
    flow.publish()
    self.assertEqual(api_call_mock.call_count, 1)
    self.assertEqual(get_flow_mock.call_count, 1)

    # Second round: the mocked download returns a copy whose name lost its
    # last character, so the post-publish check has to fail.
    corrupted_flow = copy.deepcopy(flow)
    corrupted_flow.name = corrupted_flow.name[:-1]
    get_flow_mock.return_value = corrupted_flow

    with self.assertRaises(ValueError) as raised:
        flow.publish()

    expected_message = (
        "Flow was not stored correctly on the server. "
        "New flow ID is 1. Please check manually and remove "
        "the flow if necessary! Error is:\n"
        "'Flow sklearn.ensemble.forest.RandomForestClassifier: values for attribute 'name' differ: "
        "'sklearn.ensemble.forest.RandomForestClassifier' vs 'sklearn.ensemble.forest.RandomForestClassifie'.'"
    )

    self.assertEqual(raised.exception.args[0], expected_message)
    self.assertEqual(api_call_mock.call_count, 2)
    self.assertEqual(get_flow_mock.call_count, 2)
171165

172166
def test_illegal_flow(self):
173167
# should throw error as it contains two imputers
@@ -177,6 +171,16 @@ def test_illegal_flow(self):
177171
self.assertRaises(ValueError, openml.flows.sklearn_to_flow, illegal)
178172

179173
def test_nonexisting_flow_exists(self):
174+
def get_sentinel():
    """Return 'TEST' followed by 10 hex digits derived from the current time."""
    # A flow is identified online by its name and external version, so a
    # unique name allows the same flow to be published in each test run.
    digest = hashlib.md5(str(time.time()).encode('utf-8')).hexdigest()
    return 'TEST%s' % digest[:10]
183+
180184
name = get_sentinel() + get_sentinel()
181185
version = get_sentinel()
182186

@@ -185,10 +189,9 @@ def test_nonexisting_flow_exists(self):
185189

186190
def test_existing_flow_exists(self):
187191
# create a flow
188-
sentinel = get_sentinel()
189192
nb = sklearn.naive_bayes.GaussianNB()
190193
flow = openml.flows.sklearn_to_flow(nb)
191-
flow.name = 'TEST%s%s' % (sentinel, flow.name)
194+
flow, _ = self._add_sentinel_to_flow_name(flow, None)
192195
#publish the flow
193196
flow = flow.publish()
194197
#redownload the flow
@@ -204,7 +207,6 @@ def test_sklearn_to_upload_to_flow(self):
204207
iris = sklearn.datasets.load_iris()
205208
X = iris.data
206209
y = iris.target
207-
sentinel = get_sentinel()
208210

209211
# Test a more complicated flow
210212
ohe = sklearn.preprocessing.OneHotEncoder(categorical_features=[1])
@@ -227,17 +229,7 @@ def test_sklearn_to_upload_to_flow(self):
227229
rs.fit(X, y)
228230
flow = openml.flows.sklearn_to_flow(rs)
229231
flow.tags.extend(['openml-python', 'unittest'])
230-
231-
# Add the sentinel to all name strings in all subflows. Adds it to
232-
# name to make it easier in the web gui to see that the flow is only
233-
# a test flow
234-
to_visit = collections.deque()
235-
to_visit.appendleft(flow)
236-
while len(to_visit) > 0:
237-
current_flow = to_visit.pop()
238-
for sub_flow in current_flow.components.values():
239-
to_visit.appendleft(sub_flow)
240-
current_flow.name = sentinel + current_flow.name
232+
flow, sentinel = self._add_sentinel_to_flow_name(flow, None)
241233

242234
flow.publish()
243235
self.assertIsInstance(flow.flow_id, int)
@@ -267,7 +259,8 @@ def test_sklearn_to_upload_to_flow(self):
267259

268260
self.assertEqual(server_xml, local_xml)
269261

270-
self.assertTrue(are_flows_equal(new_flow, flow))
262+
# Would raise exception if they are not equal!
263+
openml.flows.functions.assert_flows_equal(new_flow, flow)
271264
self.assertIsNot(new_flow, flow)
272265

273266
fixture_name = '%ssklearn.model_selection._search.RandomizedSearchCV(' \

0 commit comments

Comments
 (0)