1313
1414from ..exceptions import PyOpenMLError
1515from .. import config
16- from ..flows import sklearn_to_flow , get_flow , flow_exists , _check_n_jobs
16+ from ..flows import sklearn_to_flow , get_flow , flow_exists , _check_n_jobs , \
17+ _copy_server_fields
1718from ..setups import setup_exists , initialize_model
1819from ..exceptions import OpenMLCacheException , OpenMLServerException
19- from .._api_calls import _perform_api_call , _file_id_to_url
20+ from .._api_calls import _perform_api_call
2021from .run import OpenMLRun , _get_version_information
2122from .trace import OpenMLRunTrace , OpenMLTraceIteration
2223
2526# circular imports
2627
2728
28- def run_task (task , model , avoid_duplicate_runs = True , flow_tags = None , seed = None ):
29+ def run_model_on_task (task , model , avoid_duplicate_runs = True , flow_tags = None ,
30+ seed = None ):
31+ flow = sklearn_to_flow (model )
32+
33+ # returns flow id if the flow exists on the server, False otherwise
34+ flow_id = flow_exists (flow .name , flow .external_version )
35+
36+ if flow_id == False :
37+ # TODO this is potential race condition! someone could upload the
38+ # same flow in the meantime!
39+ # means the flow did not exists. As we could run it, publish it now
40+ flow = flow .publish ()
41+ else :
42+ # flow already existed, download it from server
43+ # TODO (neccessary? is this a post condition of this function)
44+ flow_from_server = get_flow (flow_id )
45+ _copy_server_fields (flow_from_server , flow )
46+
47+ return run_flow_on_task (task = task , flow = flow ,
48+ avoid_duplicate_runs = avoid_duplicate_runs ,
49+ flow_tags = flow_tags , seed = seed )
50+
51+
52+ def run_flow_on_task (task , flow , avoid_duplicate_runs = True , flow_tags = None ,
53+ seed = None ):
2954 """Performs a CV run on the dataset of the given task, using the split.
3055
3156 Parameters
@@ -51,23 +76,18 @@ def run_task(task, model, avoid_duplicate_runs=True, flow_tags=None, seed=None):
5176 """
5277 if flow_tags is not None and not isinstance (flow_tags , list ):
5378 raise ValueError ("flow_tags should be list" )
54- # TODO move this into its onwn module. While it somehow belongs here, it
55- # adds quite a lot of functionality which is better suited in other places!
56- # TODO why doesn't this accept a flow as input? - this would make this more flexible!
57- model = _get_seeded_model (model , seed )
58- flow = sklearn_to_flow (model )
5979
60- # returns flow id if the flow exists on the server, False otherwise
61- flow_id = flow_exists (flow .name , flow .external_version )
80+ flow .model = _get_seeded_model (flow .model , seed = seed )
6281
6382 # skips the run if it already exists and the user opts for this in the config file.
6483 # also, if the flow is not present on the server, the check is not needed.
65- if avoid_duplicate_runs and flow_id :
66- flow = get_flow (flow_id )
67- setup_id = setup_exists (flow , model )
84+ if avoid_duplicate_runs :
85+ flow_from_server = get_flow (flow . flow_id )
86+ setup_id = setup_exists (flow_from_server )
6887 ids = _run_exists (task .task_id , setup_id )
6988 if ids :
7089 raise PyOpenMLError ("Run already exists in server. Run id(s): %s" % str (ids ))
90+ _copy_server_fields (flow_from_server , flow )
7191
7292 dataset = task .get_dataset ()
7393
@@ -79,18 +99,12 @@ def run_task(task, model, avoid_duplicate_runs=True, flow_tags=None, seed=None):
7999 run_environment = _get_version_information ()
80100 tags = ['openml-python' , run_environment [1 ]]
81101 # execute the run
82- run = OpenMLRun (task_id = task .task_id , flow_id = None , dataset_id = dataset .dataset_id , model = model , tags = tags )
83- res = _run_task_get_arffcontent (model , task , class_labels )
102+ run = OpenMLRun (task_id = task .task_id , flow_id = None , dataset_id = dataset .dataset_id ,
103+ model = flow .model , tags = tags )
104+ run .parameter_settings = OpenMLRun ._parse_parameters (flow )
105+ res = _run_task_get_arffcontent (flow .model , task , class_labels )
84106 run .data_content , run .trace_content , run .trace_attributes , run .detailed_evaluations = res
85107
86- if flow_id == False :
87- # means the flow did not exists. As we could run it, publish it now
88- flow = flow .publish ()
89- else :
90- # flow already existed, download it from server
91- # TODO (neccessary? is this a post condition of this function)
92- flow = get_flow (flow_id )
93-
94108 run .flow_id = flow .flow_id
95109 config .logger .info ('Executed Task %d with Flow id: %d' % (task .task_id , run .flow_id ))
96110
0 commit comments