Skip to content

Commit 3ae5f8e

Browse files
committed
requests from @mfeurer
1 parent 363b381 commit 3ae5f8e

3 files changed

Lines changed: 9 additions & 1 deletion

File tree

openml/runs/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .run import OpenMLRun
2-
from .functions import (run_task, get_run, list_runs, get_runs)
2+
from .functions import (run_task, get_run, list_runs, get_runs, initialize_model_from_run)
33

44
__all__ = ['OpenMLRun', 'run_task', 'get_run', 'list_runs', 'get_runs']

openml/runs/functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ def _get_seeded_model(model, seed=None):
160160
if 'random_state' in param_name:
161161
currentValue = model_params[param_name]
162162
# important to draw the value at this point (and not in the if statement)
163+
# this way we guarantee that if a different set of subflows is seeded,
164+
# the same number of the random generator is used
163165
newValue = rs.randint(0, 2**16)
164166
if currentValue is None:
165167
random_states[param_name] = newValue

tests/test_runs/test_run_functions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def _perform_run(self, task_id, num_instances, clf, check_setup=True):
3737
self.assertEqual(len(run.data_content), num_instances)
3838

3939
if check_setup:
40+
# test the initialize setup function
4041
run_id = run_.run_id
4142
run_server = openml.runs.get_run(run_id)
4243
clf_server = openml.setups.initialize_model(run_server.setup_id)
@@ -46,6 +47,11 @@ def _perform_run(self, task_id, num_instances, clf, check_setup=True):
4647

4748
openml.flows.assert_flows_equal(flow_local, flow_server)
4849

50+
# and test the initialize setup from run function
51+
clf_server2 = openml.runs.initialize_model_from_run(run_server.run_id)
52+
flow_server2 = openml.flows.sklearn_to_flow(clf_server2)
53+
openml.flows.assert_flows_equal(flow_local, flow_server2)
54+
4955
#self.assertEquals(clf.get_params(), clf_prime.get_params())
5056
# self.assertEquals(clf, clf_prime)
5157

0 commit comments

Comments
 (0)