44import xmltodict
55import numpy as np
66import warnings
7- import openml
8- from sklearn .model_selection ._search import BaseSearchCV
7+ import sklearn
98
109from ..exceptions import PyOpenMLError
1110from .. import config
@@ -59,7 +58,6 @@ def run_task(task, model, avoid_duplicate_runs=True):
5958 raise PyOpenMLError ("Run already exists in server. Run id(s): %s" % str (ids ))
6059
6160 dataset = task .get_dataset ()
62- X , Y = dataset .get_data (target = task .target_name )
6361
6462 class_labels = task .class_labels
6563 if class_labels is None :
@@ -68,19 +66,19 @@ def run_task(task, model, avoid_duplicate_runs=True):
6866
6967 # execute the run
7068 run = OpenMLRun (task_id = task .task_id , flow_id = None , dataset_id = dataset .dataset_id , model = model )
71- run .data_content , run .trace_content = _run_task_get_arffcontent (model , task , class_labels )
69+ run .data_content , run .trace_content , run .trace_attributes = _run_task_get_arffcontent (model , task , class_labels )
70+
7271
7372 if flow_id == False :
74- # means the flow did not exists.
75- # As we could run it, publish it now
73+ # means the flow did not exists. As we could run it, publish it now
7674 flow = flow .publish ()
7775 else :
7876 # flow already existed, download it from server
7977 # TODO (neccessary? is this a post condition of this function)
8078 flow = get_flow (flow_id )
8179
8280 run .flow_id = flow .flow_id
83- config .logger .info ('Executed Task %d with Flow id: %d' % (task .task_id , run .flow_id ))
81+ config .logger .info ('Executed Task %d with Flow id: %d' % (task .task_id , run .flow_id ))
8482
8583 return run
8684
@@ -160,22 +158,27 @@ def _run_task_get_arffcontent(model, task, class_labels):
160158 for rep in task .iterate_repeats ():
161159 fold_no = 0
162160 for fold in rep :
161+ model_fold = sklearn .base .clone (model , safe = True )
163162 train_indices , test_indices = fold
164163 trainX = X [train_indices ]
165164 trainY = Y [train_indices ]
166165 testX = X [test_indices ]
167166 testY = Y [test_indices ]
168167
169- model .fit (trainX , trainY )
168+ try :
169+ model_fold .fit (trainX , trainY )
170170
171- if isinstance (model , BaseSearchCV ):
172- _add_results_to_arfftrace (arff_tracecontent , fold_no , model , rep_no )
173- model_classes = model .best_estimator_ .classes_
174- else :
175- model_classes = model .classes_
171+ if isinstance (model_fold , sklearn .model_selection ._search .BaseSearchCV ):
172+ arff_tracecontent .extend (_extract_arfftrace (model_fold , rep_no , fold_no ))
173+ model_classes = model_fold .best_estimator_ .classes_
174+ else :
175+ model_classes = model_fold .classes_
176+ except AttributeError as e :
177+ # typically happens when training a regressor on classification task
178+ raise PyOpenMLError (str (e ))
176179
177- ProbaY = model .predict_proba (testX )
178- PredY = model .predict (testX )
180+ ProbaY = model_fold .predict_proba (testX )
181+ PredY = model_fold .predict (testX )
179182 if ProbaY .shape [1 ] != len (class_labels ):
180183 warnings .warn ("Repeat %d Fold %d: estimator only predicted for %d/%d classes!" % (rep_no , fold_no , ProbaY .shape [1 ], len (class_labels )))
181184
@@ -186,13 +189,18 @@ def _run_task_get_arffcontent(model, task, class_labels):
186189 fold_no = fold_no + 1
187190 rep_no = rep_no + 1
188191
189- if not isinstance (model , BaseSearchCV ):
192+ if isinstance (model_fold , sklearn .model_selection ._search .BaseSearchCV ):
193+ # arff_tracecontent is already set
194+ arff_trace_attributes = _extract_arfftrace_attributes (model_fold )
195+ else :
190196 arff_tracecontent = None
197+ arff_trace_attributes = None
191198
192- return arff_datacontent , arff_tracecontent
199+ return arff_datacontent , arff_tracecontent , arff_trace_attributes
193200
194201
195- def _add_results_to_arfftrace (arff_tracecontent , fold_no , model , rep_no ):
202+ def _extract_arfftrace (model , rep_no , fold_no ):
203+ arff_tracecontent = []
196204 for itt_no in range (0 , len (model .cv_results_ ['mean_test_score' ])):
197205 # we use the string values for True and False, as it is defined in this way by the OpenML server
198206 selected = 'false'
@@ -204,6 +212,30 @@ def _add_results_to_arfftrace(arff_tracecontent, fold_no, model, rep_no):
204212 if key .startswith ("param_" ):
205213 arff_line .append (str (model .cv_results_ [key ][itt_no ]))
206214 arff_tracecontent .append (arff_line )
215+ return arff_tracecontent
216+
217+ def _extract_arfftrace_attributes (model ):
218+ # attributes that will be in trace arff, regardless of the model
219+ trace_attributes = [('repeat' , 'NUMERIC' ),
220+ ('fold' , 'NUMERIC' ),
221+ ('iteration' , 'NUMERIC' ),
222+ ('evaluation' , 'NUMERIC' ),
223+ ('selected' , ['true' , 'false' ])]
224+
225+ # model dependent attributes for trace arff
226+ for key in model .cv_results_ :
227+ if key .startswith ("param_" ):
228+ if all (isinstance (i , (bool )) for i in model .cv_results_ [key ]):
229+ type = ['True' , 'False' ]
230+ elif all (isinstance (i , (int , float )) for i in model .cv_results_ [key ]):
231+ type = 'NUMERIC'
232+ else :
233+ values = list (set (model .cv_results_ [key ])) # unique values
234+ type = [str (i ) for i in values ]
235+
236+ attribute = ("parameter_" + key [6 :], type )
237+ trace_attributes .append (attribute )
238+ return trace_attributes
207239
208240
209241def get_runs (run_ids ):
@@ -306,9 +338,16 @@ def _create_run_from_xml(xml):
306338 dataset_id = int (run ['oml:input_data' ]['oml:dataset' ]['oml:did' ])
307339
308340 predictions_url = None
309- for file_dict in run ['oml:output_data' ]['oml:file' ]:
341+ if isinstance (run ['oml:output_data' ]['oml:file' ], dict ):
342+ # only one result.. probably due to an upload error
343+ file_dict = run ['oml:output_data' ]['oml:file' ]
310344 if file_dict ['oml:name' ] == 'predictions' :
311345 predictions_url = file_dict ['oml:url' ]
346+ else :
347+ # multiple files, the normal case
348+ for file_dict in run ['oml:output_data' ]['oml:file' ]:
349+ if file_dict ['oml:name' ] == 'predictions' :
350+ predictions_url = file_dict ['oml:url' ]
312351 if predictions_url is None :
313352 raise ValueError ('No URL to download predictions for run %d in run '
314353 'description XML' % run_id )
0 commit comments