11import collections
2+ import copy
23import hashlib
34import re
5+ import sys
46import time
5- import random
6- import unittest
77
8- import xmltodict
8+ if sys .version_info [0 ] >= 3 :
9+ from unittest import mock
10+ else :
11+ import mock
912
1013import scipy .stats
1114import sklearn
1922import sklearn .preprocessing
2023import sklearn .naive_bayes
2124import sklearn .tree
25+ import xmltodict
2226
2327from openml .testing import TestBase
2428from openml ._api_calls import _perform_api_call
2529import openml
2630from openml .flows .sklearn_converter import _format_external_version
2731
2832
def are_flows_equal(flow1, flow2):
    """Check equality of two flows.

    Two flows are considered equal when every attribute that is not set by
    the server compares equal and all of their components are recursively
    equal as well.

    Parameters
    ----------
    flow1 : object
        Reference flow; its class is used for the type check.
    flow2 : object
        Flow to compare against ``flow1``.

    Returns
    -------
    bool
        ``True`` if the flows are equal, ``False`` otherwise. Flows whose
        component names differ are reported as unequal instead of raising
        ``KeyError``.
    """
    if not isinstance(flow2, flow1.__class__):
        return False

    # Attributes that are skipped during the comparison. 'name' is not
    # generated by the server, but it is skipped here as well.
    # NOTE(review): the previous comment said the name would be compared
    # separately via a getter; no such comparison exists in this function.
    generated_by_the_server = frozenset([
        'name', 'flow_id', 'uploader', 'version',
        'upload_date', 'source_url',
        'binary_url', 'source_format',
        'binary_format', 'source_md5',
        'binary_md5', 'model',
    ])

    for key in set(vars(flow1)) | set(vars(flow2)):
        if key in generated_by_the_server:
            continue
        attr1 = getattr(flow1, key, None)
        attr2 = getattr(flow2, key, None)

        if key != 'components':
            if attr1 != attr2:
                return False
            continue

        if attr1 is None or attr2 is None:
            # A flow without components only equals another flow that
            # has no components either.
            if attr1 is not None or attr2 is not None:
                return False
            continue

        # Differing component names mean the flows differ; checking this
        # first avoids indexing a name that one side does not have.
        if set(attr1.keys()) != set(attr2.keys()):
            return False
        for name in attr1.keys():
            if not are_flows_equal(attr1[name], attr2[name]):
                return False
    return True
59-
60-
def get_sentinel():
    """Return a unique ``'TEST'``-prefixed name prefix for a flow.

    Online, a flow is identified by its name and external version, so a
    unique prefix allows the same flow to be published in each test run.
    The prefix is the first ten hex digits of the MD5 digest of the current
    timestamp.
    """
    timestamp = str(time.time()).encode('utf-8')
    digest = hashlib.md5(timestamp).hexdigest()
    return 'TEST%s' % digest[:10]
70-
71-
7233class TestFlow (TestBase ):
7334
35+
7436 def test_get_flow (self ):
7537 # We need to use the production server here because 4024 is not on the test
7638 # server
@@ -134,13 +96,14 @@ def test_to_xml_from_xml(self):
13496 xml = flow ._to_xml ()
13597 xml_dict = xmltodict .parse (xml )
13698 new_flow = openml .flows .OpenMLFlow ._from_dict (xml_dict )
137- self .assertTrue (are_flows_equal (new_flow , flow ))
99+
100+ # Would raise an exception if the flows are not equal
101+ openml .flows .functions .assert_flows_equal (new_flow , flow )
138102 self .assertIsNot (new_flow , flow )
139103
140104 def test_publish_flow (self ):
141- sentinel = get_sentinel ()
142-
143- flow = openml .OpenMLFlow (name = 'Test' ,
105+ flow = openml .OpenMLFlow (name = 'sklearn.dummy.DummyClassifier' ,
106+ class_name = 'sklearn.dummy.DummyClassifier' ,
144107 description = "test description" ,
145108 model = sklearn .dummy .DummyClassifier (),
146109 components = collections .OrderedDict (),
@@ -150,8 +113,9 @@ def test_publish_flow(self):
150113 'sklearn' , sklearn .__version__ ),
151114 tags = [],
152115 language = 'English' ,
153- dependencies = '' )
154- flow .name = 'TEST%s%s' % (sentinel , flow .name )
116+ dependencies = None )
117+
118+ flow , _ = self ._add_sentinel_to_flow_name (flow , None )
155119
156120 flow .publish ()
157121 self .assertIsInstance (flow .flow_id , int )
@@ -160,14 +124,44 @@ def test_semi_legal_flow(self):
160124 # TODO: Test if parameters are set correctly!
161125 # should not throw error as it contains two differentiable forms of Bagging
162126 # i.e., Bagging(Bagging(J48)) and Bagging(J48)
163- sentinel = get_sentinel ()
164127 semi_legal = sklearn .ensemble .BaggingClassifier (
165128 base_estimator = sklearn .ensemble .BaggingClassifier (
166129 base_estimator = sklearn .tree .DecisionTreeClassifier ()))
167130 flow = openml .flows .sklearn_to_flow (semi_legal )
168- flow .name = 'TEST%s%s' % (sentinel , flow .name )
131+ flow , _ = self ._add_sentinel_to_flow_name (flow , None )
132+
133+ flow .publish ()
134+
@mock.patch('openml.flows.functions.get_flow')
@mock.patch('openml.flows.flow._perform_api_call')
def test_publish_error(self, api_call_mock, get_flow_mock):
    """Publishing raises ValueError when the re-fetched flow differs.

    Both the upload API call and ``get_flow`` are mocked. The first
    ``publish()`` gets the identical flow back and succeeds; the second
    gets a copy with a truncated name back and must raise ``ValueError``
    with the expected message.
    """
    flow = openml.flows.sklearn_to_flow(
        sklearn.ensemble.RandomForestClassifier())
    api_call_mock.return_value = (
        "<oml:upload_flow>\n "
        " <oml:id>1</oml:id>\n "
        "</oml:upload_flow>"
    )

    # First round: the mocked get_flow returns the very same flow.
    get_flow_mock.return_value = flow
    flow.publish()
    for patched in (api_call_mock, get_flow_mock):
        self.assertEqual(patched.call_count, 1)

    # Second round: the mocked get_flow returns a copy whose name lost
    # its last character, so the comparison must fail.
    corrupted = copy.deepcopy(flow)
    corrupted.name = corrupted.name[:-1]
    get_flow_mock.return_value = corrupted

    with self.assertRaises(ValueError) as raised:
        flow.publish()

    expected_message = (
        "Flow was not stored correctly on the server. "
        "New flow ID is 1. Please check manually and remove "
        "the flow if necessary! Error is:\n "
        "'Flow sklearn.ensemble.forest.RandomForestClassifier: values for attribute 'name' differ: "
        "'sklearn.ensemble.forest.RandomForestClassifier' vs 'sklearn.ensemble.forest.RandomForestClassifie'.'"
    )

    self.assertEqual(raised.exception.args[0], expected_message)
    for patched in (api_call_mock, get_flow_mock):
        self.assertEqual(patched.call_count, 2)
171165
172166 def test_illegal_flow (self ):
173167 # should throw error as it contains two imputers
@@ -177,6 +171,16 @@ def test_illegal_flow(self):
177171 self .assertRaises (ValueError , openml .flows .sklearn_to_flow , illegal )
178172
179173 def test_nonexisting_flow_exists (self ):
174+ def get_sentinel ():
175+ # Create a unique prefix for the flow. Necessary because the flow is
176+ # identified by its name and external version online. Having a unique
177+ # name allows us to publish the same flow in each test run
178+ md5 = hashlib .md5 ()
179+ md5 .update (str (time .time ()).encode ('utf-8' ))
180+ sentinel = md5 .hexdigest ()[:10 ]
181+ sentinel = 'TEST%s' % sentinel
182+ return sentinel
183+
180184 name = get_sentinel () + get_sentinel ()
181185 version = get_sentinel ()
182186
@@ -185,10 +189,9 @@ def test_nonexisting_flow_exists(self):
185189
186190 def test_existing_flow_exists (self ):
187191 # create a flow
188- sentinel = get_sentinel ()
189192 nb = sklearn .naive_bayes .GaussianNB ()
190193 flow = openml .flows .sklearn_to_flow (nb )
191- flow . name = 'TEST%s%s' % ( sentinel , flow . name )
194+ flow , _ = self . _add_sentinel_to_flow_name ( flow , None )
192195 #publish the flow
193196 flow = flow .publish ()
194197 #redownload the flow
@@ -204,7 +207,6 @@ def test_sklearn_to_upload_to_flow(self):
204207 iris = sklearn .datasets .load_iris ()
205208 X = iris .data
206209 y = iris .target
207- sentinel = get_sentinel ()
208210
209211 # Test a more complicated flow
210212 ohe = sklearn .preprocessing .OneHotEncoder (categorical_features = [1 ])
@@ -227,17 +229,7 @@ def test_sklearn_to_upload_to_flow(self):
227229 rs .fit (X , y )
228230 flow = openml .flows .sklearn_to_flow (rs )
229231 flow .tags .extend (['openml-python' , 'unittest' ])
230-
231- # Add the sentinel to all name strings in all subflows. Adds it to
232- # name to make it easier in the web gui to see that the flow is only
233- # a test flow
234- to_visit = collections .deque ()
235- to_visit .appendleft (flow )
236- while len (to_visit ) > 0 :
237- current_flow = to_visit .pop ()
238- for sub_flow in current_flow .components .values ():
239- to_visit .appendleft (sub_flow )
240- current_flow .name = sentinel + current_flow .name
232+ flow , sentinel = self ._add_sentinel_to_flow_name (flow , None )
241233
242234 flow .publish ()
243235 self .assertIsInstance (flow .flow_id , int )
@@ -267,7 +259,8 @@ def test_sklearn_to_upload_to_flow(self):
267259
268260 self .assertEqual (server_xml , local_xml )
269261
270- self .assertTrue (are_flows_equal (new_flow , flow ))
262+ # Would raise exception if they are not equal!
263+ openml .flows .functions .assert_flows_equal (new_flow , flow )
271264 self .assertIsNot (new_flow , flow )
272265
273266 fixture_name = '%ssklearn.model_selection._search.RandomizedSearchCV(' \
0 commit comments