11import collections
2+ import copy
23import hashlib
34import re
5+ import sys
46import time
57
6- import xmltodict
8+ if sys .version_info [0 ] >= 3 :
9+ from unittest import mock
10+ else :
11+ import mock
712
813import scipy .stats
914import sklearn
1722import sklearn .preprocessing
1823import sklearn .naive_bayes
1924import sklearn .tree
25+ import xmltodict
2026
2127from openml .testing import TestBase
2228from openml ._api_calls import _perform_api_call
2329import openml
2430from openml .flows .sklearn_converter import _format_external_version
2531
2632
27- def get_sentinel ():
28- # Create a unique prefix for the flow. Necessary because the flow is
29- # identified by its name and external version online. Having a unique
30- # name allows us to publish the same flow in each test run
31- md5 = hashlib .md5 ()
32- md5 .update (str (time .time ()).encode ('utf-8' ))
33- sentinel = md5 .hexdigest ()[:10 ]
34- sentinel = 'TEST%s' % sentinel
35- return sentinel
36-
37-
3833class TestFlow (TestBase ):
3934
35+
4036 def test_get_flow (self ):
4137 # We need to use the production server here because 4024 is not on the
4238 # test server
@@ -100,13 +96,14 @@ def test_to_xml_from_xml(self):
10096 xml = flow ._to_xml ()
10197 xml_dict = xmltodict .parse (xml )
10298 new_flow = openml .flows .OpenMLFlow ._from_dict (xml_dict )
103- self .assertTrue (openml .flows .functions .are_flows_equal (new_flow , flow ))
99+
100+ # Would raise exception if they are not equal
101+ openml .flows .functions .check_flows_equal (new_flow , flow )
104102 self .assertIsNot (new_flow , flow )
105103
106104 def test_publish_flow (self ):
107- sentinel = get_sentinel ()
108-
109- flow = openml .OpenMLFlow (name = 'Test' ,
105+ flow = openml .OpenMLFlow (name = 'sklearn.dummy.DummyClassifier' ,
106+ class_name = 'sklearn.dummy.DummyClassifier' ,
110107 description = "test description" ,
111108 model = sklearn .dummy .DummyClassifier (),
112109 components = collections .OrderedDict (),
@@ -116,8 +113,9 @@ def test_publish_flow(self):
116113 'sklearn' , sklearn .__version__ ),
117114 tags = [],
118115 language = 'English' ,
119- dependencies = '' )
120- flow .name = 'TEST%s%s' % (sentinel , flow .name )
116+ dependencies = None )
117+
118+ flow , _ = self ._add_sentinel_to_flow_name (flow , None )
121119
122120 flow .publish ()
123121 self .assertIsInstance (flow .flow_id , int )
@@ -126,14 +124,44 @@ def test_semi_legal_flow(self):
126124 # TODO: Test if parameters are set correctly!
127125 # should not throw error as it contains two distinguishable forms of Bagging
128126 # i.e., Bagging(Bagging(J48)) and Bagging(J48)
129- sentinel = get_sentinel ()
130127 semi_legal = sklearn .ensemble .BaggingClassifier (
131128 base_estimator = sklearn .ensemble .BaggingClassifier (
132129 base_estimator = sklearn .tree .DecisionTreeClassifier ()))
133130 flow = openml .flows .sklearn_to_flow (semi_legal )
134- flow .name = 'TEST%s%s' % (sentinel , flow .name )
131+ flow , _ = self ._add_sentinel_to_flow_name (flow , None )
132+
133+ flow .publish ()
134+
135+ @mock .patch ('openml.flows.functions.get_flow' )
136+ @mock .patch ('openml.flows.flow._perform_api_call' )
137+ def test_publish_error (self , api_call_mock , get_flow_mock ):
138+ model = sklearn .ensemble .RandomForestClassifier ()
139+ flow = openml .flows .sklearn_to_flow (model )
140+ api_call_mock .return_value = "<oml:upload_flow>\n " \
141+ " <oml:id>1</oml:id>\n " \
142+ "</oml:upload_flow>"
143+ get_flow_mock .return_value = flow
135144
136145 flow .publish ()
146+ self .assertEqual (api_call_mock .call_count , 1 )
147+ self .assertEqual (get_flow_mock .call_count , 1 )
148+
149+ flow_copy = copy .deepcopy (flow )
150+ flow_copy .name = flow_copy .name [:- 1 ]
151+ get_flow_mock .return_value = flow_copy
152+
153+ with self .assertRaises (ValueError ) as context_manager :
154+ flow .publish ()
155+
156+ fixture = "Flow was not stored correctly on the server. " \
157+ "New flow ID is 1. Please check manually and remove " \
158+ "the flow if necessary! Error is:\n " \
159+ "'Flow sklearn.ensemble.forest.RandomForestClassifier: values for attribute 'name' differ: " \
160+ "'sklearn.ensemble.forest.RandomForestClassifier' vs 'sklearn.ensemble.forest.RandomForestClassifie'.'"
161+
162+ self .assertEqual (context_manager .exception .args [0 ], fixture )
163+ self .assertEqual (api_call_mock .call_count , 2 )
164+ self .assertEqual (get_flow_mock .call_count , 2 )
137165
138166 def test_illegal_flow (self ):
139167 # should throw error as it contains two imputers
@@ -143,6 +171,16 @@ def test_illegal_flow(self):
143171 self .assertRaises (ValueError , openml .flows .sklearn_to_flow , illegal )
144172
145173 def test_nonexisting_flow_exists (self ):
174+ def get_sentinel ():
175+ # Create a unique prefix for the flow. Necessary because the flow is
176+ # identified by its name and external version online. Having a unique
177+ # name allows us to publish the same flow in each test run
178+ md5 = hashlib .md5 ()
179+ md5 .update (str (time .time ()).encode ('utf-8' ))
180+ sentinel = md5 .hexdigest ()[:10 ]
181+ sentinel = 'TEST%s' % sentinel
182+ return sentinel
183+
146184 name = get_sentinel () + get_sentinel ()
147185 version = get_sentinel ()
148186
@@ -151,10 +189,9 @@ def test_nonexisting_flow_exists(self):
151189
152190 def test_existing_flow_exists (self ):
153191 # create a flow
154- sentinel = get_sentinel ()
155192 nb = sklearn .naive_bayes .GaussianNB ()
156193 flow = openml .flows .sklearn_to_flow (nb )
157- flow . name = 'TEST%s%s' % ( sentinel , flow . name )
194+ flow , _ = self . _add_sentinel_to_flow_name ( flow , None )
158195 #publish the flow
159196 flow = flow .publish ()
160197 #redownload the flow
@@ -170,7 +207,6 @@ def test_sklearn_to_upload_to_flow(self):
170207 iris = sklearn .datasets .load_iris ()
171208 X = iris .data
172209 y = iris .target
173- sentinel = get_sentinel ()
174210
175211 # Test a more complicated flow
176212 ohe = sklearn .preprocessing .OneHotEncoder (categorical_features = [1 ])
@@ -193,17 +229,7 @@ def test_sklearn_to_upload_to_flow(self):
193229 rs .fit (X , y )
194230 flow = openml .flows .sklearn_to_flow (rs )
195231 flow .tags .extend (['openml-python' , 'unittest' ])
196-
197- # Add the sentinel to all name strings in all subflows. Adds it to
198- # name to make it easier in the web gui to see that the flow is only
199- # a test flow
200- to_visit = collections .deque ()
201- to_visit .appendleft (flow )
202- while len (to_visit ) > 0 :
203- current_flow = to_visit .pop ()
204- for sub_flow in current_flow .components .values ():
205- to_visit .appendleft (sub_flow )
206- current_flow .name = sentinel + current_flow .name
232+ flow , sentinel = self ._add_sentinel_to_flow_name (flow , None )
207233
208234 flow .publish ()
209235 self .assertIsInstance (flow .flow_id , int )
@@ -233,7 +259,8 @@ def test_sklearn_to_upload_to_flow(self):
233259
234260 self .assertEqual (server_xml , local_xml )
235261
236- self .assertTrue (openml .flows .functions .are_flows_equal (new_flow , flow ))
262+ # Would raise exception if they are not equal!
263+ openml .flows .functions .check_flows_equal (new_flow , flow )
237264 self .assertIsNot (new_flow , flow )
238265
239266 fixture_name = '%ssklearn.model_selection._search.RandomizedSearchCV(' \
0 commit comments