Skip to content

Commit 363b381

Browse files
committed
fix unit tests for setup
1 parent b8ced46 commit 363b381

2 files changed

Lines changed: 18 additions & 13 deletions

File tree

openml/runs/run.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,13 +185,19 @@ def get_flow_dict(_flow):
185185
flow_map.update(get_flow_dict(_flow.components[subflow]))
186186
return flow_map
187187

188-
def extract_parameters(_flow, _param_dict):
188+
def extract_parameters(_flow, _param_dict, _main_call=False, main_id=None):
189+
# _flow is an OpenML flow object; _param_dict maps from flow name to flow id
190+
# for the main call, the component id is taken from main_id instead of _param_dict (useful for unit tests / sentinels)
191+
# this way, for flows without subflows we do not have to rely on _param_dict
189192
_params = []
190193
for _param_name in _flow.parameters:
191194
_current = OrderedDict()
192195
_current['oml:name'] = _param_name
193196
_current['oml:value'] = _flow.parameters[_param_name]
194-
_current['oml:component'] = _param_dict[_flow.name]
197+
if _main_call:
198+
_current['oml:component'] = main_id
199+
else:
200+
_current['oml:component'] = _param_dict[_flow.name]
195201
_params.append(_current)
196202
for _identifier in _flow.components:
197203
_params.extend(extract_parameters(_flow.components[_identifier], _param_dict))
@@ -200,7 +206,7 @@ def extract_parameters(_flow, _param_dict):
200206
flow_dict = get_flow_dict(server_flow)
201207
local_flow = openml.flows.sklearn_to_flow(model)
202208

203-
parameters = extract_parameters(local_flow, flow_dict)
209+
parameters = extract_parameters(local_flow, flow_dict, True, server_flow.flow_id)
204210
return parameters
205211

206212
################################################################################

tests/test_setups/test_setup_functions.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class TestRun(TestBase):
3232
def test_nonexisting_setup_exists(self):
3333
# first publish a non-existing flow
3434
sentinel = get_sentinel()
35+
# because of the sentinel, we cannot use flows that contain subflows
3536
dectree = DecisionTreeClassifier()
3637
flow = openml.flows.sklearn_to_flow(dectree)
3738
flow.name = 'TEST%s%s' % (sentinel, flow.name)
@@ -45,34 +46,32 @@ def test_nonexisting_setup_exists(self):
4546

4647
def test_existing_setup_exists(self):
4748
# first publish a non-existing flow
48-
bagging = BaggingClassifier(DecisionTreeClassifier(max_depth=5,
49-
min_samples_split=3),
50-
n_estimators=3,
51-
max_samples=0.5)
52-
flow = openml.flows.sklearn_to_flow(bagging)
49+
50+
# because of the sentinel, we cannot use flows that contain subflows
51+
classif = DecisionTreeClassifier(max_depth=5,
52+
min_samples_split=3)
53+
flow = openml.flows.sklearn_to_flow(classif)
5354
flow.name = 'TEST%s%s' % (get_sentinel(), flow.name)
54-
flow.components['base_estimator'].name = 'TEST%s%s' % (
55-
get_sentinel(), flow.components['base_estimator'].name)
5655

5756
flow = flow.publish()
5857
flow = openml.flows.get_flow(flow.flow_id)
5958

6059
# although the flow exists, we can be sure there are no
6160
# setups (yet) as it hasn't been run
62-
setup_id = openml.setups.setup_exists(flow, bagging)
61+
setup_id = openml.setups.setup_exists(flow, classif)
6362
self.assertFalse(setup_id)
6463

6564
# now run the flow on an easy task:
6665
task = openml.tasks.get_task(115) #diabetes
67-
run = openml.runs.run_task(task, bagging)
66+
run = openml.runs.run_task(task, classif)
6867
# spoof flow id, otherwise the sentinel is ignored
6968
run.flow_id = flow.flow_id
7069
run = run.publish()
7170
# download the run, as it contains the right setup id
7271
run = openml.runs.get_run(run.run_id)
7372

7473
# execute the function we are interested in
75-
setup_id = openml.setups.setup_exists(flow, bagging)
74+
setup_id = openml.setups.setup_exists(flow, classif)
7675
self.assertEquals(setup_id, run.setup_id)
7776

7877
def test_setup_get(self):

0 commit comments

Comments
 (0)