Skip to content

Commit 61fd0ca

Browse files
committed
FIX #145 add check that flows are correctly stored on the server
1 parent a28cb51 commit 61fd0ca

9 files changed

Lines changed: 155 additions & 65 deletions

File tree

openml/flows/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .flow import OpenMLFlow
22
from .sklearn_converter import sklearn_to_flow, flow_to_sklearn
3-
from .functions import get_flow, list_flows, flow_exists
3+
from .functions import get_flow, list_flows, flow_exists, check_flows_equal
44

55
# Names exposed by ``from openml.flows import *``. ``check_flows_equal`` is
# imported from ``.functions`` above and is therefore listed here as well.
# NOTE(review): 'create_flow_from_model' is not imported in the visible part
# of this module -- confirm it is defined, otherwise a star import fails.
__all__ = ['OpenMLFlow', 'create_flow_from_model', 'get_flow', 'list_flows',
           'sklearn_to_flow', 'flow_to_sklearn', 'flow_exists',
           'check_flows_equal']

openml/flows/flow.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,9 +338,24 @@ def publish(self):
338338
file_elements = {'description': xml_description}
339339
return_value = _perform_api_call("flow/", file_elements=file_elements)
340340
self.flow_id = int(xmltodict.parse(return_value)['oml:upload_flow']['oml:id'])
341+
try:
342+
_check_flow(self)
343+
except ValueError as e:
344+
message = e.args[0]
345+
raise ValueError("Flow was not stored correctly on the server. "
346+
"New flow ID is %d. Please check manually and "
347+
"remove the flow if necessary! Error is:\n'%s'" %
348+
(self.flow_id, message))
341349
return self
342350

343351

344352
def _add_if_nonempty(dic, key, value):
345353
if value is not None:
346354
dic[key] = value
355+
356+
357+
def _check_flow(flow):
    """Compare a local flow with the copy stored on the server.

    Fetches the flow registered under ``flow.flow_id`` and hands both
    objects to ``check_flows_equal``, which raises if they differ.
    Returns nothing.
    """
    # Imported at call time instead of at module level.
    # NOTE(review): presumably to avoid a circular import between
    # flow.py and functions.py -- confirm.
    import openml.flows.functions

    flow_functions = openml.flows.functions
    server_side_flow = flow_functions.get_flow(flow.flow_id)
    flow_functions.check_flows_equal(flow, server_side_flow)

openml/flows/functions.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,19 @@ def _list_flows(api_call):
130130
return flows
131131

132132

133-
def are_flows_equal(flow1, flow2):
133+
def check_flows_equal(flow1, flow2):
134134
"""Check equality of two flows.
135135
136136
Two flows are equal if all their keys which are not set by the server
137137
are equal, as well as all their parameters and components.
138138
"""
139-
if not isinstance(flow2, flow1.__class__):
140-
return False
139+
if not isinstance(flow1, OpenMLFlow):
140+
raise TypeError('Argument 1 must be of type OpenMLFlow, but is %s' %
141+
type(flow1))
142+
143+
if not isinstance(flow2, OpenMLFlow):
144+
raise TypeError('Argument 2 must be of type OpenMLFlow, but is %s' %
145+
type(flow2))
141146

142147
# Name is actually not generated by the server, but it will be
143148
# tested further down with a getter (allows mocking in the tests)
@@ -153,11 +158,16 @@ def are_flows_equal(flow1, flow2):
153158
attr2 = getattr(flow2, key, None)
154159
if key == 'components':
155160
for name in set(attr1.keys()).union(attr2.keys()):
156-
if not (name in attr1 and name in attr2):
157-
return False
158-
if not are_flows_equal(attr1[name], attr2[name]):
159-
return False
161+
# A component that exists on only one side is a mismatch. Each message
# names the side that actually holds the component, so the error points
# at the right flow.
if name not in attr1:
    raise ValueError('Component %s only available in '
                     'argument2, but not in argument1.' % name)
if name not in attr2:
    raise ValueError('Component %s only available in '
                     'argument1, but not in argument2.' % name)
167+
check_flows_equal(attr1[name], attr2[name])
168+
160169
else:
161170
if attr1 != attr2:
162-
return False
163-
return True
171+
raise ValueError("Flow %s: values for attribute '%s' differ: "
172+
"'%s' vs '%s'." %
173+
(str(flow1.name), str(key), str(attr1), str(attr2)))

openml/flows/sklearn_converter.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ def flow_to_sklearn(o, **kwargs):
131131
raise ValueError('Cannot flow_to_sklearn %s' % serialized_type)
132132

133133
else:
134-
# Regular dictionary
135134
rval = OrderedDict((flow_to_sklearn(key, **kwargs),
136135
flow_to_sklearn(value, **kwargs))
137136
for key, value in o.items())
@@ -303,8 +302,10 @@ def _extract_information_from_model(model):
303302
component_reference = OrderedDict()
304303
component_reference[
305304
'oml-python:serialized_object'] = 'component_reference'
306-
component_reference['value'] = OrderedDict(
307-
key=identifier, step_name=identifier)
305+
cr_value = OrderedDict()
306+
cr_value['key'] = identifier
307+
cr_value['step_name'] = identifier
308+
component_reference['value'] = cr_value
308309
parameter_value.append(component_reference)
309310

310311
if isinstance(rval, tuple):
@@ -326,7 +327,10 @@ def _extract_information_from_model(model):
326327
component_reference = OrderedDict()
327328
component_reference[
328329
'oml-python:serialized_object'] = 'component_reference'
329-
component_reference['value'] = OrderedDict(key=k, step_name=None)
330+
cr_value = OrderedDict()
331+
cr_value['key'] = k
332+
cr_value['step_name'] = None
333+
component_reference['value'] = cr_value
330334
component_reference = sklearn_to_flow(component_reference, model)
331335
parameters[k] = json.dumps(component_reference)
332336

@@ -387,6 +391,9 @@ def _deserialize_model(flow, **kwargs):
387391

388392

389393
def _check_dependencies(dependencies):
394+
if not dependencies:
395+
return
396+
390397
dependencies = dependencies.split('\n')
391398
for dependency_string in dependencies:
392399
match = DEPENDENCIES_PATTERN.match(dependency_string)
@@ -448,7 +455,8 @@ def serialize_rv_frozen(o):
448455
dist = o.dist.__class__.__module__ + '.' + o.dist.__class__.__name__
449456
ret = OrderedDict()
450457
ret['oml-python:serialized_object'] = 'rv_frozen'
451-
ret['value'] = OrderedDict(dist=dist, a=a, b=b, args=args, kwds=kwds)
458+
ret['value'] = OrderedDict((('dist', dist), ('a', a), ('b', b),
459+
('args', args), ('kwds', kwds)))
452460
return ret
453461

454462
def deserialize_rv_frozen(o, **kwargs):

openml/testing.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import hashlib
12
import inspect
23
import os
4+
import time
35
import shutil
46
import unittest
57
import openml
@@ -54,4 +56,25 @@ def tearDown(self):
5456
shutil.rmtree(self.workdir)
5557
openml.config.server = self.production_server
5658

59+
def _add_sentinel_to_flow_name(self, flow, sentinel=None):
60+
if sentinel is None:
61+
# Create a unique prefix for the flow. Necessary because the flow is
62+
# identified by its name and external version online. Having a unique
63+
# name allows us to publish the same flow in each test run
64+
md5 = hashlib.md5()
65+
md5.update(str(time.time()).encode('utf-8'))
66+
sentinel = md5.hexdigest()[:10]
67+
sentinel = 'TEST%s' % sentinel
68+
69+
flows_to_visit = list()
70+
flows_to_visit.append(flow)
71+
while len(flows_to_visit) > 0:
72+
current_flow = flows_to_visit.pop()
73+
current_flow.name = '%s%s' % (sentinel, current_flow.name)
74+
for subflow in current_flow.components.values():
75+
flows_to_visit.append(subflow)
76+
77+
return flow, sentinel
78+
79+
5780
__all__ = ['TestBase']

tests/test_flows/test_flow.py

Lines changed: 62 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
import collections
2+
import copy
23
import hashlib
34
import re
5+
import sys
46
import time
57

6-
import xmltodict
8+
if sys.version_info[0] >= 3:
9+
from unittest import mock
10+
else:
11+
import mock
712

813
import scipy.stats
914
import sklearn
@@ -17,26 +22,17 @@
1722
import sklearn.preprocessing
1823
import sklearn.naive_bayes
1924
import sklearn.tree
25+
import xmltodict
2026

2127
from openml.testing import TestBase
2228
from openml._api_calls import _perform_api_call
2329
import openml
2430
from openml.flows.sklearn_converter import _format_external_version
2531

2632

27-
def get_sentinel():
28-
# Create a unique prefix for the flow. Necessary because the flow is
29-
# identified by its name and external version online. Having a unique
30-
# name allows us to publish the same flow in each test run
31-
md5 = hashlib.md5()
32-
md5.update(str(time.time()).encode('utf-8'))
33-
sentinel = md5.hexdigest()[:10]
34-
sentinel = 'TEST%s' % sentinel
35-
return sentinel
36-
37-
3833
class TestFlow(TestBase):
3934

35+
4036
def test_get_flow(self):
4137
# We need to use the production server here because 4024 is not on the test
4238
# server
@@ -100,13 +96,14 @@ def test_to_xml_from_xml(self):
10096
xml = flow._to_xml()
10197
xml_dict = xmltodict.parse(xml)
10298
new_flow = openml.flows.OpenMLFlow._from_dict(xml_dict)
103-
self.assertTrue(openml.flows.functions.are_flows_equal(new_flow, flow))
99+
100+
# Would raise an exception if they are not equal
101+
openml.flows.functions.check_flows_equal(new_flow, flow)
104102
self.assertIsNot(new_flow, flow)
105103

106104
def test_publish_flow(self):
107-
sentinel = get_sentinel()
108-
109-
flow = openml.OpenMLFlow(name='Test',
105+
flow = openml.OpenMLFlow(name='sklearn.dummy.DummyClassifier',
106+
class_name='sklearn.dummy.DummyClassifier',
110107
description="test description",
111108
model=sklearn.dummy.DummyClassifier(),
112109
components=collections.OrderedDict(),
@@ -116,8 +113,9 @@ def test_publish_flow(self):
116113
'sklearn', sklearn.__version__),
117114
tags=[],
118115
language='English',
119-
dependencies='')
120-
flow.name = 'TEST%s%s' % (sentinel, flow.name)
116+
dependencies=None)
117+
118+
flow, _ = self._add_sentinel_to_flow_name(flow, None)
121119

122120
flow.publish()
123121
self.assertIsInstance(flow.flow_id, int)
@@ -126,14 +124,44 @@ def test_semi_legal_flow(self):
126124
# TODO: Test if parameters are set correctly!
127125
# should not throw error as it contains two differentiable forms of Bagging
128126
# i.e., Bagging(Bagging(J48)) and Bagging(J48)
129-
sentinel = get_sentinel()
130127
semi_legal = sklearn.ensemble.BaggingClassifier(
131128
base_estimator=sklearn.ensemble.BaggingClassifier(
132129
base_estimator=sklearn.tree.DecisionTreeClassifier()))
133130
flow = openml.flows.sklearn_to_flow(semi_legal)
134-
flow.name = 'TEST%s%s' % (sentinel, flow.name)
131+
flow, _ = self._add_sentinel_to_flow_name(flow, None)
132+
133+
flow.publish()
134+
135+
@mock.patch('openml.flows.functions.get_flow')
@mock.patch('openml.flows.flow._perform_api_call')
def test_publish_error(self, api_call_mock, get_flow_mock):
    """publish() must raise if the flow read back from the server differs.

    Both the upload call and the download are mocked. First the mocked
    download returns the very same flow, so publishing succeeds. Then it
    returns a copy whose name lacks the last character, and publish() is
    expected to raise a ValueError carrying the full diagnostic message.
    """
    flow = openml.flows.sklearn_to_flow(
        sklearn.ensemble.RandomForestClassifier())
    api_call_mock.return_value = ("<oml:upload_flow>\n"
                                  " <oml:id>1</oml:id>\n"
                                  "</oml:upload_flow>")

    # Round trip yields an identical flow: no error expected.
    get_flow_mock.return_value = flow
    flow.publish()
    self.assertEqual(api_call_mock.call_count, 1)
    self.assertEqual(get_flow_mock.call_count, 1)

    # Round trip yields a flow with a truncated name: must be reported.
    truncated = copy.deepcopy(flow)
    truncated.name = truncated.name[:-1]
    get_flow_mock.return_value = truncated

    with self.assertRaises(ValueError) as raised:
        flow.publish()

    expected_message = (
        "Flow was not stored correctly on the server. "
        "New flow ID is 1. Please check manually and remove "
        "the flow if necessary! Error is:\n"
        "'Flow sklearn.ensemble.forest.RandomForestClassifier: values for attribute 'name' differ: "
        "'sklearn.ensemble.forest.RandomForestClassifier' vs 'sklearn.ensemble.forest.RandomForestClassifie'.'"
    )

    self.assertEqual(raised.exception.args[0], expected_message)
    self.assertEqual(api_call_mock.call_count, 2)
    self.assertEqual(get_flow_mock.call_count, 2)
137165

138166
def test_illegal_flow(self):
139167
# should throw error as it contains two imputers
@@ -143,6 +171,16 @@ def test_illegal_flow(self):
143171
self.assertRaises(ValueError, openml.flows.sklearn_to_flow, illegal)
144172

145173
def test_nonexisting_flow_exists(self):
174+
def get_sentinel():
175+
# Create a unique prefix for the flow. Necessary because the flow is
176+
# identified by its name and external version online. Having a unique
177+
# name allows us to publish the same flow in each test run
178+
md5 = hashlib.md5()
179+
md5.update(str(time.time()).encode('utf-8'))
180+
sentinel = md5.hexdigest()[:10]
181+
sentinel = 'TEST%s' % sentinel
182+
return sentinel
183+
146184
name = get_sentinel() + get_sentinel()
147185
version = get_sentinel()
148186

@@ -151,10 +189,9 @@ def test_nonexisting_flow_exists(self):
151189

152190
def test_existing_flow_exists(self):
153191
# create a flow
154-
sentinel = get_sentinel()
155192
nb = sklearn.naive_bayes.GaussianNB()
156193
flow = openml.flows.sklearn_to_flow(nb)
157-
flow.name = 'TEST%s%s' % (sentinel, flow.name)
194+
flow, _ = self._add_sentinel_to_flow_name(flow, None)
158195
#publish the flow
159196
flow = flow.publish()
160197
#redownload the flow
@@ -170,7 +207,6 @@ def test_sklearn_to_upload_to_flow(self):
170207
iris = sklearn.datasets.load_iris()
171208
X = iris.data
172209
y = iris.target
173-
sentinel = get_sentinel()
174210

175211
# Test a more complicated flow
176212
ohe = sklearn.preprocessing.OneHotEncoder(categorical_features=[1])
@@ -193,17 +229,7 @@ def test_sklearn_to_upload_to_flow(self):
193229
rs.fit(X, y)
194230
flow = openml.flows.sklearn_to_flow(rs)
195231
flow.tags.extend(['openml-python', 'unittest'])
196-
197-
# Add the sentinel to all name strings in all subflows. Adds it to
198-
# name to make it easier in the web gui to see that the flow is only
199-
# a test flow
200-
to_visit = collections.deque()
201-
to_visit.appendleft(flow)
202-
while len(to_visit) > 0:
203-
current_flow = to_visit.pop()
204-
for sub_flow in current_flow.components.values():
205-
to_visit.appendleft(sub_flow)
206-
current_flow.name = sentinel + current_flow.name
232+
flow, sentinel = self._add_sentinel_to_flow_name(flow, None)
207233

208234
flow.publish()
209235
self.assertIsInstance(flow.flow_id, int)
@@ -233,7 +259,8 @@ def test_sklearn_to_upload_to_flow(self):
233259

234260
self.assertEqual(server_xml, local_xml)
235261

236-
self.assertTrue(openml.flows.functions.are_flows_equal(new_flow, flow))
262+
# Would raise exception if they are not equal!
263+
openml.flows.functions.check_flows_equal(new_flow, flow)
237264
self.assertIsNot(new_flow, flow)
238265

239266
fixture_name = '%ssklearn.model_selection._search.RandomizedSearchCV(' \

0 commit comments

Comments
 (0)