@@ -122,37 +122,42 @@ def initialize_model_from_trace(run_id, repeat, fold, iteration=None):
122122 parameter settings)
123123
124124 Parameters
125- ----------
126- run_id : int
127- The Openml run_id. Should contain a trace file
125+ ----------
126+ run_id : int
127+ The Openml run_id. Should contain a trace file
128128
129- repeat: int
130- The repeat nr (column in trace file)
129+ repeat: int
130+ The repeat nr (column in trace file)
131131
132- fold: int
133- The fold nr (column in trace file)
132+ fold: int
133+ The fold nr (column in trace file)
134134
135- iteration: int
136- The iteration nr (column in trace file)
135+ iteration: int
136+ The iteration nr (column in trace file)
137137
138- Returns
139- -------
140- model : sklearn model
141- the scikitlearn model with all parameters initailized
142- '''
138+ Returns
139+ -------
140+ model : sklearn model
141+ the scikit-learn model with all parameters initialized
142+ '''
143143 run = get_run (run_id )
144144 if 'trace' not in run .output_files :
145145 raise PyOpenMLError ('Run does not contain trace file' )
146- trace_url = fileid_to_url (run .output_files ['trace' ], 'trace.arff' )
147- #print(trace_url)
148146 trace_xml = _perform_api_call ('run/trace/%d' % run_id )
149147 run_trace = _create_trace_from_description (trace_xml )
150148
151149 request = (repeat , fold , iteration )
152150 if request not in run_trace .trace_iterations :
153151 raise ValueError ('Combination repeat, fold, iteration not availavle' )
154152 current = run_trace .trace_iterations [(repeat , fold , iteration )]
155-
153+
154+ search_model = initialize_model_from_run (run_id )
155+ if not isinstance (search_model , sklearn .model_selection ._search .BaseSearchCV ):
156+ raise ValueError ('Deserialized flow not instance of ' \
157+ 'sklearn.model_selection._search.BaseSearchCV' )
158+ base_estimator = search_model .estimator
159+ base_estimator .set_params (** current .get_parameters ())
160+ return base_estimator
156161
157162def _run_exists (task_id , setup_id ):
158163 '''
@@ -347,8 +352,9 @@ def _extract_arfftrace(model, rep_no, fold_no):
347352 test_score = model .cv_results_ ['mean_test_score' ][itt_no ]
348353 arff_line = [rep_no , fold_no , itt_no , test_score , selected ]
349354 for key in model .cv_results_ :
350- if key .startswith ("param_" ):
351- arff_line .append (sklearn_to_flow (model .cv_results_ [key ][itt_no ]))
355+ if key .startswith ('param_' ):
356+ serialized_value = json .dumps (model .cv_results_ [key ][itt_no ])
357+ arff_line .append (serialized_value )
352358 arff_tracecontent .append (arff_line )
353359 return arff_tracecontent
354360
@@ -371,11 +377,7 @@ def _extract_arfftrace_attributes(model):
371377 if key .startswith ('param_' ):
372378 # supported types should include all types, including bool, int, float
373379 supported_types = (bool , int , float , six .string_types )
374- if all (isinstance (i , (bool )) for i in model .cv_results_ [key ]):
375- type = ['True' , 'False' ]
376- elif all (isinstance (i , (int , float )) for i in model .cv_results_ [key ]):
377- type = 'NUMERIC'
378- elif all (isinstance (i , supported_types ) or i is None for i in model .cv_results_ [key ]):
380+ if all (isinstance (i , supported_types ) or i is None for i in model .cv_results_ [key ]):
379381 type = 'STRING'
380382 else :
381383 raise TypeError ('Unsupported param type in param grid' )
0 commit comments