66import warnings
77import sklearn
88import time
9- from sklearn .model_selection ._search import BaseSearchCV
109
1110from ..exceptions import PyOpenMLError
1211from .. import config
12+
1313from ..flows import sklearn_to_flow , get_flow , flow_exists , _check_n_jobs
14- from ..setups import setup_exists
14+ from ..setups import setup_exists , initialize_model
15+
1516from ..exceptions import OpenMLCacheException , OpenMLServerException
1617from ..util import URLError , version_complies
17- from ..tasks .functions import _create_task_from_xml
1818from .._api_calls import _perform_api_call
1919from .run import OpenMLRun , _get_version_information
2020
2424
2525
2626
27- def run_task (task , model , avoid_duplicate_runs = True , flow_tags = None ):
27+ def run_task (task , model , avoid_duplicate_runs = True , flow_tags = None , seed = None ):
2828 """Performs a CV run on the dataset of the given task, using the split.
2929
3030 Parameters
@@ -35,8 +35,13 @@ def run_task(task, model, avoid_duplicate_runs=True, flow_tags=None):
3535 a model which has a function fit(X,Y) and predict(X),
3636 all supervised estimators of scikit learn follow this definition of a model [1]
3737 [1](http://scikit-learn.org/stable/tutorial/statistical_inference/supervised_learning.html)
38+ avoid_duplicate_runs : bool
39+ if this flag is set to True, the run will throw an error if the
40+ setup/task combination is already present on the server.
3841 flow_tags : list(str)
3942 a list of tags that the flow should have at creation
43+ seed: int
44+ the models that are not seeded will get this seed
4045
4146 Returns
4247 -------
@@ -48,6 +53,7 @@ def run_task(task, model, avoid_duplicate_runs=True, flow_tags=None):
4853 # TODO move this into its onwn module. While it somehow belongs here, it
4954 # adds quite a lot of functionality which is better suited in other places!
5055 # TODO why doesn't this accept a flow as input? - this would make this more flexible!
56+ model = _get_seeded_model (model , seed )
5157 flow = sklearn_to_flow (model )
5258
5359 # returns flow id if the flow exists on the server, False otherwise
@@ -88,6 +94,24 @@ def run_task(task, model, avoid_duplicate_runs=True, flow_tags=None):
8894
8995 return run
9096
97+ def initialize_model_from_run (run_id ):
98+ '''
99+ Initialized a model based on a run_id (i.e., using the exact
100+ same parameter settings)
101+
102+ Parameters
103+ ----------
104+ run_id : int
105+ The Openml run_id
106+
107+ Returns
108+ -------
109+ model : sklearn model
110+ the scikitlearn model with all parameters initailized
111+ '''
112+ run = get_run (run_id )
113+ return initialize_model (run .setup_id )
114+
91115def _run_exists (task_id , setup_id ):
92116 '''
93117 Checks whether a task/setup combination is already present on the server.
@@ -111,6 +135,49 @@ def _run_exists(task_id, setup_id):
111135 assert (exception .code == 512 )
112136 return False
113137
138+ def _get_seeded_model (model , seed = None ):
139+ '''Sets all the non-seeded components of a model with a seed.
140+ Models that are already seeded will maintain the seed. In
141+ this case, only integer seeds are allowed (An exception
142+ is thrown when a RandomState was used as seed)
143+
144+ Parameters
145+ ----------
146+ model : sklearn model
147+ The model to be seeded
148+ seed : int
149+ The seed to initialize the RandomState with. Unseeded subcomponents
150+ will be seeded with a random number from the RandomState.
151+
152+ Returns
153+ -------
154+ model : sklearn model
155+ a version of the model where all (sub)components have
156+ a seed
157+ '''
158+
159+ rs = np .random .RandomState (seed )
160+ model_params = model .get_params ()
161+ random_states = {}
162+ for param_name in sorted (model_params ):
163+ if 'random_state' in param_name :
164+ currentValue = model_params [param_name ]
165+ # important to draw the value at this point (and not in the if statement)
166+ # this way we guarantee that if a different set of subflows is seeded,
167+ # the same number of the random generator is used
168+ newValue = rs .randint (0 , 2 ** 16 )
169+ if currentValue is None :
170+ random_states [param_name ] = newValue
171+ elif isinstance (currentValue , int ):
172+ # acceptable behaviour
173+ pass
174+ elif isinstance (currentValue , np .random .RandomState ):
175+ raise ValueError ('Models initialized with a RandomState object are not supported. Please seed with an integer. ' )
176+ else :
177+ raise ValueError ('Models should be seeded with int or None (this should never happen). ' )
178+ model .set_params (** random_states )
179+ return model
180+
114181
115182
116183def _prediction_to_row (rep_no , fold_no , row_id , correct_label , predicted_label ,
0 commit comments