Skip to content

Commit ddace5f

Browse files
committed
finalized reinstantiate from param trace function
1 parent 08f8ecb commit ddace5f

3 files changed

Lines changed: 37 additions & 26 deletions

File tree

openml/runs/functions.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -122,37 +122,42 @@ def initialize_model_from_trace(run_id, repeat, fold, iteration=None):
122122
parameter settings)
123123
124124
Parameters
125-
----------
126-
run_id : int
127-
The Openml run_id. Should contain a trace file
125+
----------
126+
run_id : int
127+
The Openml run_id. Should contain a trace file
128128
129-
repeat: int
130-
The repeat nr (column in trace file)
129+
repeat: int
130+
The repeat nr (column in trace file)
131131
132-
fold: int
133-
The fold nr (column in trace file)
132+
fold: int
133+
The fold nr (column in trace file)
134134
135-
iteration: int
136-
The iteration nr (column in trace file)
135+
iteration: int
136+
The iteration nr (column in trace file)
137137
138-
Returns
139-
-------
140-
model : sklearn model
141-
the scikitlearn model with all parameters initailized
142-
'''
138+
Returns
139+
-------
140+
model : sklearn model
141+
the scikitlearn model with all parameters initailized
142+
'''
143143
run = get_run(run_id)
144144
if 'trace' not in run.output_files:
145145
raise PyOpenMLError('Run does not contain trace file')
146-
trace_url = fileid_to_url(run.output_files['trace'], 'trace.arff')
147-
#print(trace_url)
148146
trace_xml = _perform_api_call('run/trace/%d' %run_id)
149147
run_trace = _create_trace_from_description(trace_xml)
150148

151149
request = (repeat, fold, iteration)
152150
if request not in run_trace.trace_iterations:
153151
raise ValueError('Combination repeat, fold, iteration not availavle')
154152
current = run_trace.trace_iterations[(repeat, fold, iteration)]
155-
153+
154+
search_model = initialize_model_from_run(run_id)
155+
if not isinstance(search_model, sklearn.model_selection._search.BaseSearchCV):
156+
raise ValueError('Deserialized flow not instance of ' \
157+
'sklearn.model_selection._search.BaseSearchCV')
158+
base_estimator = search_model.estimator
159+
base_estimator.set_params(**current.get_parameters())
160+
return base_estimator
156161

157162
def _run_exists(task_id, setup_id):
158163
'''
@@ -347,8 +352,9 @@ def _extract_arfftrace(model, rep_no, fold_no):
347352
test_score = model.cv_results_['mean_test_score'][itt_no]
348353
arff_line = [rep_no, fold_no, itt_no, test_score, selected]
349354
for key in model.cv_results_:
350-
if key.startswith("param_"):
351-
arff_line.append(sklearn_to_flow(model.cv_results_[key][itt_no]))
355+
if key.startswith('param_'):
356+
serialized_value = json.dumps(model.cv_results_[key][itt_no])
357+
arff_line.append(serialized_value)
352358
arff_tracecontent.append(arff_line)
353359
return arff_tracecontent
354360

@@ -371,11 +377,7 @@ def _extract_arfftrace_attributes(model):
371377
if key.startswith('param_'):
372378
# supported types should include all types, including bool, int float
373379
supported_types = (bool, int, float, six.string_types)
374-
if all(isinstance(i, (bool)) for i in model.cv_results_[key]):
375-
type = ['True', 'False']
376-
elif all(isinstance(i, (int, float)) for i in model.cv_results_[key]):
377-
type = 'NUMERIC'
378-
elif all(isinstance(i, supported_types) or i is None for i in model.cv_results_[key]):
380+
if all(isinstance(i, supported_types) or i is None for i in model.cv_results_[key]):
379381
type = 'STRING'
380382
else:
381383
raise TypeError('Unsupported param type in param grid')

openml/runs/trace.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
1+
import json
22

33
class OpenMLRunTrace(object):
44
"""OpenML Run Trace: parsed output from Run Trace call
@@ -33,6 +33,16 @@ def __init__(self, repeat, fold, iteration, setup_string, evaluation, selected):
3333
self.evaluation = evaluation
3434
self.selected = selected
3535

36+
def get_parameters(self):
37+
result = {}
38+
# parameters have prefix 'parameter_'
39+
prefix = 'parameter_'
40+
41+
for param in self.setup_string:
42+
key = param[len(prefix):]
43+
result[key] = json.loads(self.setup_string[param])
44+
return result
45+
3646
def __str__(self):
3747
return '[(%d,%d,%d): %f (%r)]' %(self.repeat, self.fold, self.iteration,
3848
self.evaluation, self.selected)

tests/test_runs/test_run_functions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ def test_run_optimize_randomforest_iris(self):
104104
n_iter=num_iterations)
105105

106106
run = self._perform_run(task_id, num_instances, random_search)
107-
print(run.trace_content)
108107
self.assertEqual(len(run.trace_content), num_iterations * num_folds)
109108

110109
def test_run_optimize_bagging_iris(self):

0 commit comments

Comments
 (0)