77import warnings
88
99import numpy as np
10- import sklearn
10+ import sklearn . pipeline
1111import six
1212import xmltodict
1313
14+ import openml
1415from ..exceptions import PyOpenMLError
1516from .. import config
16- from ..flows import sklearn_to_flow , get_flow , flow_exists , _check_n_jobs
17+ from ..flows import sklearn_to_flow , get_flow , flow_exists , _check_n_jobs , \
18+ _copy_server_fields
1719from ..setups import setup_exists , initialize_model
1820from ..exceptions import OpenMLCacheException , OpenMLServerException
19- from .._api_calls import _perform_api_call , _file_id_to_url
21+ from .._api_calls import _perform_api_call
2022from .run import OpenMLRun , _get_version_information
2123from .trace import OpenMLRunTrace , OpenMLTraceIteration
2224
2527# circular imports
2628
2729
28- def run_task (task , model , avoid_duplicate_runs = True , flow_tags = None , seed = None ):
29- """Performs a CV run on the dataset of the given task, using the split.
30+ def run_model_on_task (task , model , avoid_duplicate_runs = True , flow_tags = None ,
31+ seed = None ):
32+ """See ``run_flow_on_task`` for documentation."""
33+
34+ flow = sklearn_to_flow (model )
35+
36+ return run_flow_on_task (task = task , flow = flow ,
37+ avoid_duplicate_runs = avoid_duplicate_runs ,
38+ flow_tags = flow_tags , seed = seed )
39+
40+
41+ def run_flow_on_task (task , flow , avoid_duplicate_runs = True , flow_tags = None ,
42+ seed = None ):
43+ """Run the model provided by the flow on the dataset defined by task.
44+
45+ Takes the flow and repeat information into account. In case a flow is not
46+ yet published, it is published after executing the run (requires
47+ internet connection).
3048
3149 Parameters
3250 ----------
3351 task : OpenMLTask
3452 Task to perform.
3553 model : sklearn model
36- a model which has a function fit(X,Y) and predict(X),
54+ A model which has a function fit(X,Y) and predict(X),
3755 all supervised estimators of scikit learn follow this definition of a model [1]
3856 [1](http://scikit-learn.org/stable/tutorial/statistical_inference/supervised_learning.html)
3957 avoid_duplicate_runs : bool
40- if this flag is set to True, the run will throw an error if the
41- setup/task combination is already present on the server.
58+ If this flag is set to True, the run will throw an error if the
59+ setup/task combination is already present on the server. Works only
60+ if the flow is already published on the server. This feature requires an
61+ internet connection.
4262 flow_tags : list(str)
43- a list of tags that the flow should have at creation
63+ A list of tags that the flow should have at creation.
4464 seed: int
45- the models that are not seeded will get this seed
65+ Models that are not seeded will get this seed.
4666
4767 Returns
4868 -------
@@ -51,23 +71,19 @@ def run_task(task, model, avoid_duplicate_runs=True, flow_tags=None, seed=None):
5171 """
5272 if flow_tags is not None and not isinstance (flow_tags , list ):
5373 raise ValueError ("flow_tags should be list" )
54- # TODO move this into its onwn module. While it somehow belongs here, it
55- # adds quite a lot of functionality which is better suited in other places!
56- # TODO why doesn't this accept a flow as input? - this would make this more flexible!
57- model = _get_seeded_model (model , seed )
58- flow = sklearn_to_flow (model )
5974
60- # returns flow id if the flow exists on the server, False otherwise
61- flow_id = flow_exists (flow .name , flow .external_version )
75+ flow .model = _get_seeded_model (flow .model , seed = seed )
6276
6377 # skips the run if it already exists and the user opts for this in the config file.
6478 # also, if the flow is not present on the server, the check is not needed.
79+ flow_id = flow_exists (flow .name , flow .external_version )
6580 if avoid_duplicate_runs and flow_id :
66- flow = get_flow (flow_id )
67- setup_id = setup_exists (flow , model )
81+ flow_from_server = get_flow (flow_id )
82+ setup_id = setup_exists (flow_from_server )
6883 ids = _run_exists (task .task_id , setup_id )
6984 if ids :
7085 raise PyOpenMLError ("Run already exists in server. Run id(s): %s" % str (ids ))
86+ _copy_server_fields (flow_from_server , flow )
7187
7288 dataset = task .get_dataset ()
7389
@@ -78,25 +94,43 @@ def run_task(task, model, avoid_duplicate_runs=True, flow_tags=None, seed=None):
7894
7995 run_environment = _get_version_information ()
8096 tags = ['openml-python' , run_environment [1 ]]
97+
8198 # execute the run
82- run = OpenMLRun (task_id = task .task_id , flow_id = None , dataset_id = dataset .dataset_id , model = model , tags = tags )
83- res = _run_task_get_arffcontent (model , task , class_labels )
84- run .data_content , run .trace_content , run .trace_attributes , run .detailed_evaluations = res
99+ res = _run_task_get_arffcontent (flow .model , task , class_labels )
85100
86- if flow_id == False :
87- # means the flow did not exists. As we could run it, publish it now
88- flow = flow .publish ()
89- else :
90- # flow already existed, download it from server
91- # TODO (neccessary? is this a post condition of this function)
92- flow = get_flow (flow_id )
101+ if flow .flow_id is None :
102+ _publish_flow_if_necessary (flow )
103+
104+ run = OpenMLRun (task_id = task .task_id , flow_id = flow .flow_id ,
105+ dataset_id = dataset .dataset_id , model = flow .model , tags = tags )
106+ run .parameter_settings = OpenMLRun ._parse_parameters (flow )
107+
108+ run .data_content , run .trace_content , run .trace_attributes , run .detailed_evaluations = res
93109
94- run .flow_id = flow .flow_id
95110 config .logger .info ('Executed Task %d with Flow id: %d' % (task .task_id , run .flow_id ))
96111
97112 return run
98113
99114
115+ def _publish_flow_if_necessary (flow ):
116+ # Try to publish the flow, assuming it does not exist on the server yet.
117+ # Publishing may fail because the flow already exists; in that case the
118+ # server-side fields are copied onto the local flow instead of reusing it.
119+
120+ try :
121+ flow .publish ()
122+ except OpenMLServerException as e :
123+ if e .message == "flow already exists" :
124+ flow_id = openml .flows .flow_exists (flow .name ,
125+ flow .external_version )
126+ server_flow = get_flow (flow_id )
127+ openml .flows .flow ._copy_server_fields (server_flow , flow )
128+ openml .flows .assert_flows_equal (flow , server_flow ,
129+ ignore_parameters = True )
130+ else :
131+ raise e
132+
133+
100134def get_run_trace (run_id ):
101135 """Get the optimization trace object for a given run id.
102136
0 commit comments