@@ -110,9 +110,9 @@ def _compare_predictions(self, predictions, predictions_prime):
110110
111111 return True
112112
113- def _rerun_model_and_compare_predictions (self , run_id , model_prime , seed ):
113+ def _rerun_model_and_compare_predictions (self , run_id , model_prime , seed ,
114+ create_task_obj ):
114115 run = openml .runs .get_run (run_id )
115- task = openml .tasks .get_task (run .task_id )
116116
117117 # TODO: assert holdout task
118118
@@ -121,12 +121,24 @@ def _rerun_model_and_compare_predictions(self, run_id, model_prime, seed):
121121 predictions_url = openml ._api_calls ._file_id_to_url (file_id )
122122 response = openml ._api_calls ._download_text_file (predictions_url )
123123 predictions = arff .loads (response )
124- run_prime = openml .runs .run_model_on_task (
125- model = model_prime ,
126- task = task ,
127- avoid_duplicate_runs = False ,
128- seed = seed ,
129- )
124+
125+ # if create_task_obj=False, task argument in run_model_on_task is specified task_id
126+ if create_task_obj :
127+ task = openml .tasks .get_task (run .task_id )
128+ run_prime = openml .runs .run_model_on_task (
129+ model = model_prime ,
130+ task = task ,
131+ avoid_duplicate_runs = False ,
132+ seed = seed ,
133+ )
134+ else :
135+ run_prime = openml .runs .run_model_on_task (
136+ model = model_prime ,
137+ task = run .task_id ,
138+ avoid_duplicate_runs = False ,
139+ seed = seed ,
140+ )
141+
130142 predictions_prime = run_prime ._generate_arff_dict ()
131143
132144 self ._compare_predictions (predictions , predictions_prime )
@@ -425,13 +437,17 @@ def determine_grid_size(param_grid):
425437 raise e
426438
427439 self ._rerun_model_and_compare_predictions (run .run_id , model_prime ,
428- seed )
440+ seed , create_task_obj = True )
441+ self ._rerun_model_and_compare_predictions (run .run_id , model_prime ,
442+ seed , create_task_obj = False )
429443 else :
430444 run_downloaded = openml .runs .get_run (run .run_id )
431445 sid = run_downloaded .setup_id
432446 model_prime = openml .setups .initialize_model (sid )
433- self ._rerun_model_and_compare_predictions (run .run_id ,
434- model_prime , seed )
447+ self ._rerun_model_and_compare_predictions (run .run_id , model_prime ,
448+ seed , create_task_obj = True )
449+ self ._rerun_model_and_compare_predictions (run .run_id , model_prime ,
450+ seed , create_task_obj = False )
435451
436452 # todo: check if runtime is present
437453 self ._check_fold_timing_evaluations (run .fold_evaluations , 1 , num_folds ,
0 commit comments