Skip to content

Commit a972eeb

Browse files
committed
MAINT minor changes to code
1 parent 90a88cc commit a972eeb

4 files changed

Lines changed: 29 additions & 34 deletions

File tree

openml/flows/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .flow import OpenMLFlow
22
from .sklearn_converter import sklearn_to_flow, flow_to_sklearn
3-
from .functions import get_flow, get_flow_dict
3+
from .functions import get_flow
44

55
__all__ = ['OpenMLFlow', 'create_flow_from_model', 'get_flow',
66
'sklearn_to_flow', 'flow_to_sklearn']

openml/flows/functions.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,4 @@ def get_flow(flow_id):
2525
if 'sklearn' in flow.external_version:
2626
flow.model = flow_to_sklearn(flow)
2727

28-
return flow
29-
30-
31-
def get_flow_dict(flow):
32-
"""Returns a dictionary with keys flow name and values flow id.
33-
Parameters
34-
----------
35-
flow : OpenMLFlow
36-
"""
37-
if flow.flow_id is None:
38-
raise PyOpenMLError(
39-
"Can only invoke function 'get_flow_map' on a server downloaded flow. ")
40-
flow_map = {flow.name: flow.flow_id}
41-
for subflow in flow.components:
42-
flow_map.update(get_flow_dict(flow.components[subflow]))
43-
44-
return flow_map
28+
return flow

openml/runs/functions.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -82,17 +82,8 @@ def _run_task_get_arffcontent(model, task, class_labels):
8282

8383
model.fit(trainX, trainY)
8484
if isinstance(model, BaseSearchCV):
85-
for itt_no in range(0, len(model.cv_results_['mean_test_score'])):
86-
# we use the string values for True and False, as it is defined in this way by the OpenML server
87-
selected = 'false'
88-
if itt_no == model.best_index_:
89-
selected = 'true'
90-
test_score = model.cv_results_['mean_test_score'][itt_no]
91-
arff_line = [rep_no, fold_no, itt_no, test_score, selected]
92-
for key in model.cv_results_:
93-
if key.startswith("param_"):
94-
arff_line.append(str(model.cv_results_[key][itt_no]))
95-
arff_tracecontent.append(arff_line)
85+
_add_results_to_arfftrace(arff_tracecontent, fold_no, model,
86+
rep_no)
9687

9788
ProbaY = model.predict_proba(testX)
9889
PredY = model.predict(testX)
@@ -113,6 +104,20 @@ def _run_task_get_arffcontent(model, task, class_labels):
113104
return arff_datacontent, arff_tracecontent
114105

115106

107+
def _add_results_to_arfftrace(arff_tracecontent, fold_no, model, rep_no):
108+
for itt_no in range(0, len(model.cv_results_['mean_test_score'])):
109+
# we use the string values for True and False, as it is defined in this way by the OpenML server
110+
selected = 'false'
111+
if itt_no == model.best_index_:
112+
selected = 'true'
113+
test_score = model.cv_results_['mean_test_score'][itt_no]
114+
arff_line = [rep_no, fold_no, itt_no, test_score, selected]
115+
for key in model.cv_results_:
116+
if key.startswith("param_"):
117+
arff_line.append(str(model.cv_results_[key][itt_no]))
118+
arff_tracecontent.append(arff_line)
119+
120+
116121
def get_runs(run_ids):
117122
"""Gets all runs in run_ids list.
118123

openml/runs/run.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,7 @@ def _generate_trace_arff_dict(self, model):
110110
type = 'NUMERIC'
111111
else:
112112
values = list(set(model.cv_results_[key])) # unique values
113-
if len(values) < 100: # arbitrary number. make it an option?
114-
type = [str(i) for i in values]
113+
type = [str(i) for i in values]
115114
print(key + ": " + str(type))
116115

117116
attribute = ("parameter_" + key[6:], type)
@@ -179,19 +178,26 @@ def _create_description_xml(self):
179178
return description_xml
180179

181180
def _parse_parameters(model, flow):
182-
"""Extracts all parameter settings from an model in OpenML format.
181+
"""Extracts all parameter settings from a model in OpenML format.
183182
184183
Parameters
185184
----------
186185
model
187-
the sci-kit learn model (fitted)
186+
the scikit-learn model (fitted)
188187
flow
189188
openml flow object (containing flow ids, i.e., it has to be downloaded from the server)
190189
191190
"""
192191
python_param_settings = model.get_params()
193192
openml_param_settings = []
194-
flow_dict = openml.flows.get_flow_dict(flow)
193+
194+
def get_flow_dict(_flow):
195+
flow_map = {_flow.name: _flow.flow_id}
196+
for subflow in _flow.components:
197+
flow_map.update(get_flow_dict(_flow.components[subflow]))
198+
return flow_map
199+
200+
flow_dict = get_flow_dict(flow)
195201

196202
for param in python_param_settings:
197203
if "__" in param:

0 commit comments

Comments
 (0)