Skip to content

Commit 094772c

Browse files
committed
fixed errors that occured by solving merge conflict
1 parent 03d7807 commit 094772c

1 file changed

Lines changed: 14 additions & 6 deletions

File tree

openml/runs/functions.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def run_task(task, model, avoid_duplicate_runs=True, flow_tags=None):
4141
run : OpenMLRun
4242
Result of the run.
4343
"""
44-
if not isinstance(flow_tags, list):
44+
if flow_tags is not None and not isinstance(flow_tags, list):
4545
raise ValueError("flow_tags should be list")
4646
# TODO move this into its onwn module. While it somehow belongs here, it
4747
# adds quite a lot of functionality which is better suited in other places!
@@ -183,10 +183,9 @@ def _run_task_get_arffcontent(model, task, class_labels):
183183
raise PyOpenMLError(str(e))
184184

185185
# extract trace
186-
traceable_model = get_traceble_model(model_fold)
187-
if traceable_model:
188-
arff_tracecontent.extend(_extract_arfftrace(traceable_model, rep_no, fold_no))
189-
model_classes = traceable_model.best_estimator_.classes_
186+
if isinstance(model_fold, sklearn.model_selection._search.BaseSearchCV):
187+
arff_tracecontent.extend(_extract_arfftrace(model_fold, rep_no, fold_no))
188+
model_classes = model_fold.best_estimator_.classes_
190189
else:
191190
model_classes = model_fold.classes_
192191

@@ -204,7 +203,7 @@ def _run_task_get_arffcontent(model, task, class_labels):
204203

205204
if isinstance(model_fold, sklearn.model_selection._search.BaseSearchCV):
206205
# arff_tracecontent is already set
207-
arff_trace_attributes = _extract_arfftrace_attributes(traceable_model)
206+
arff_trace_attributes = _extract_arfftrace_attributes(model_fold)
208207
else:
209208
arff_tracecontent = None
210209
arff_trace_attributes = None
@@ -401,6 +400,15 @@ def _create_run_from_xml(xml):
401400
evaluation_flows[key] = flow_id
402401

403402
evaluation_flows[key] = flow_id
403+
tags = None
404+
if 'oml:tag' in run:
405+
if isinstance(run['oml:tag'], str):
406+
tags = [run['oml:tag']]
407+
elif isinstance(run['oml:tag'], list):
408+
tags = run['oml:tag']
409+
else:
410+
raise ValueError('Received not string and non list as tag item')
411+
404412

405413
return OpenMLRun(run_id=run_id, uploader=uploader,
406414
uploader_name=uploader_name, task_id=task_id,

0 commit comments

Comments
 (0)