Skip to content

Commit 385b375

Browse files
authored
Merge branch 'develop' into joaquinvanschoren-patch-1
2 parents 23687ba + 1ded29a commit 385b375

5 files changed

Lines changed: 580 additions & 4 deletions

File tree

openml/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
"""Version information."""
22

33
# The following line *must* be the last in the module, exactly as formatted:
4-
__version__ = "0.5.0dev"
4+
__version__ = "0.6.0dev"

openml/runs/functions.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ def initialize_model_from_trace(run_id, repeat, fold, iteration=None):
221221
base_estimator.set_params(**current.get_parameters())
222222
return base_estimator
223223

224+
224225
def _run_exists(task_id, setup_id):
225226
'''
226227
Checks whether a task/setup combination is already present on the server.
@@ -244,6 +245,7 @@ def _run_exists(task_id, setup_id):
244245
assert(exception.code == 512)
245246
return False
246247

248+
247249
def _get_seeded_model(model, seed=None):
248250
'''Sets all the non-seeded components of a model with a seed.
249251
Models that are already seeded will maintain the seed. In
@@ -356,6 +358,7 @@ def _prediction_to_row(rep_no, fold_no, sample_no, row_id, correct_label,
356358
arff_line.append(correct_label)
357359
return arff_line
358360

361+
359362
# JvR: why is class labels a parameter? could be removed and taken from task object, right?
360363
def _run_task_get_arffcontent(model, task, class_labels):
361364
X, Y = task.get_X_and_y()
@@ -467,7 +470,6 @@ def _calculate_local_measure(sklearn_fn, openml_name):
467470
user_defined_measures_sample
468471

469472

470-
471473
def _extract_arfftrace(model, rep_no, fold_no):
472474
if not isinstance(model, sklearn.model_selection._search.BaseSearchCV):
473475
raise ValueError('model should be instance of'\
@@ -490,6 +492,7 @@ def _extract_arfftrace(model, rep_no, fold_no):
490492
arff_tracecontent.append(arff_line)
491493
return arff_tracecontent
492494

495+
493496
def _extract_arfftrace_attributes(model):
494497
if not isinstance(model, sklearn.model_selection._search.BaseSearchCV):
495498
raise ValueError('model should be instance of'\
@@ -682,6 +685,7 @@ def _create_run_from_xml(xml):
682685
sample_evaluations=sample_evaluations,
683686
tags=tags)
684687

688+
685689
def _create_trace_from_description(xml):
686690
result_dict = xmltodict.parse(xml)['oml:trace']
687691

@@ -714,6 +718,52 @@ def _create_trace_from_description(xml):
714718

715719
return OpenMLRunTrace(run_id, trace)
716720

721+
722+
def _create_trace_from_arff(arff_obj):
723+
"""
724+
Creates a trace file from arff obj (for example, generated by a local run)
725+
726+
Parameters
727+
----------
728+
arff_obj : dict
729+
LIAC arff obj, dict containing attributes, relation, data and description
730+
731+
Returns
732+
-------
733+
run : OpenMLRunTrace
734+
Object containing None for run id and a dict containing the trace iterations
735+
"""
736+
trace = dict()
737+
attribute_idx = {att[0]: idx for idx, att in enumerate(arff_obj['attributes'])}
738+
for required_attribute in ['repeat', 'fold', 'iteration', 'evaluation', 'selected']:
739+
if required_attribute not in attribute_idx:
740+
raise ValueError('arff misses required attribute: %s' %required_attribute)
741+
742+
for itt in arff_obj['data']:
743+
repeat = int(itt[attribute_idx['repeat']])
744+
fold = int(itt[attribute_idx['fold']])
745+
iteration = int(itt[attribute_idx['iteration']])
746+
evaluation = float(itt[attribute_idx['evaluation']])
747+
selectedValue = itt[attribute_idx['selected']]
748+
if selectedValue == 'true':
749+
selected = True
750+
elif selectedValue == 'false':
751+
selected = False
752+
else:
753+
raise ValueError('expected {"true", "false"} value for selected field, received: %s' % selectedValue)
754+
755+
# TODO: if someone needs it, he can use the parameter
756+
# fields to revive the setup_string as well
757+
# However, this is usually done by the OpenML server
758+
# and if we are going to duplicate this functionality
759+
# it needs proper testing
760+
761+
current = OpenMLTraceIteration(repeat, fold, iteration, None, evaluation, selected)
762+
trace[(repeat, fold, iteration)] = current
763+
764+
return OpenMLRunTrace(None, trace)
765+
766+
717767
def _get_cached_run(run_id):
718768
"""Load a run from the cache."""
719769
cache_dir = config.get_cache_directory()

0 commit comments

Comments
 (0)