Skip to content

Commit edbad39

Browse files
committed
updated to work with unit tests (+ small bugfixes)
1 parent b8fdd17 commit edbad39

5 files changed

Lines changed: 54 additions & 63 deletions

File tree

openml/runs/functions.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,8 @@ def _create_run_from_xml(xml):
594594

595595
files = dict()
596596
evaluations = dict()
597-
detailed_evaluations = defaultdict(lambda: defaultdict(dict))
597+
fold_evaluations = defaultdict(lambda: defaultdict(dict))
598+
sample_evaluations = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
598599
if 'oml:output_data' not in run:
599600
raise ValueError('Run does not contain output_data (OpenML server error?)')
600601
else:
@@ -621,11 +622,18 @@ def _create_run_from_xml(xml):
621622
else:
622623
raise ValueError('Could not find keys "value" or "array_data" '
623624
'in %s' % str(evaluation_dict.keys()))
624-
625-
if '@repeat' in evaluation_dict and '@fold' in evaluation_dict:
625+
if '@repeat' in evaluation_dict and '@fold' in evaluation_dict and '@sample' in evaluation_dict:
626+
repeat = int(evaluation_dict['@repeat'])
627+
fold = int(evaluation_dict['@fold'])
628+
sample = int(evaluation_dict['@sample'])
629+
repeat_dict = sample_evaluations[key]
630+
fold_dict = repeat_dict[repeat]
631+
sample_dict = fold_dict[fold]
632+
sample_dict[sample] = value
633+
elif '@repeat' in evaluation_dict and '@fold' in evaluation_dict:
626634
repeat = int(evaluation_dict['@repeat'])
627635
fold = int(evaluation_dict['@fold'])
628-
repeat_dict = detailed_evaluations[key]
636+
repeat_dict = fold_evaluations[key]
629637
fold_dict = repeat_dict[repeat]
630638
fold_dict[fold] = value
631639
else:
@@ -652,7 +660,9 @@ def _create_run_from_xml(xml):
652660
parameter_settings=parameters,
653661
dataset_id=dataset_id, output_files=files,
654662
evaluations=evaluations,
655-
detailed_evaluations=detailed_evaluations, tags=tags)
663+
fold_evaluations=fold_evaluations,
664+
sample_evaluations=sample_evaluations,
665+
tags=tags)
656666

657667
def _create_trace_from_description(xml):
658668
result_dict = xmltodict.parse(xml)['oml:trace']

openml/tasks/split.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,13 @@ def __eq__(self, other):
4848
return False
4949
else:
5050
for fold in self.split[repetition]:
51-
if np.all(self.split[repetition][fold].test !=
52-
other.split[repetition][fold].test)\
53-
and \
54-
np.all(self.split[repetition][fold].train
55-
!= other.split[repetition][fold].train):
56-
return False
51+
for sample in self.split[repetition][fold]:
52+
if np.all(self.split[repetition][fold][sample].test !=
53+
other.split[repetition][fold][sample].test)\
54+
and \
55+
np.all(self.split[repetition][fold][sample].train
56+
!= other.split[repetition][fold][sample].train):
57+
return False
5758
return True
5859

5960
@classmethod

tests/test_runs/test_run_functions.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def _remove_random_state(flow):
158158
return run
159159

160160

161-
def _check_detailed_evaluations(self, detailed_evaluations, num_repeats, num_folds, max_time_allowed=60000):
161+
def _check_fold_evaluations(self, fold_evaluations, num_repeats, num_folds, max_time_allowed=60000):
162162
'''
163163
Checks whether the right timing measures are attached to the run (before upload).
164164
Test is only performed for versions >= Python3.3
@@ -169,17 +169,17 @@ def _check_detailed_evaluations(self, detailed_evaluations, num_repeats, num_fol
169169
'''
170170
timing_measures = {'usercpu_time_millis_testing', 'usercpu_time_millis_training', 'usercpu_time_millis'}
171171

172-
self.assertIsInstance(detailed_evaluations, dict)
172+
self.assertIsInstance(fold_evaluations, dict)
173173
if sys.version_info[:2] >= (3, 3):
174-
self.assertEquals(set(detailed_evaluations.keys()), timing_measures)
174+
self.assertEquals(set(fold_evaluations.keys()), timing_measures)
175175
for measure in timing_measures:
176-
num_rep_entrees = len(detailed_evaluations[measure])
176+
num_rep_entrees = len(fold_evaluations[measure])
177177
self.assertEquals(num_rep_entrees, num_repeats)
178178
for rep in range(num_rep_entrees):
179-
num_fold_entrees = len(detailed_evaluations[measure][rep])
179+
num_fold_entrees = len(fold_evaluations[measure][rep])
180180
self.assertEquals(num_fold_entrees, num_folds)
181181
for fold in range(num_fold_entrees):
182-
evaluation = detailed_evaluations[measure][rep][fold]
182+
evaluation = fold_evaluations[measure][rep][fold]
183183
self.assertIsInstance(evaluation, float)
184184
self.assertGreater(evaluation, 0) # should take at least one millisecond (?)
185185
self.assertLess(evaluation, max_time_allowed)
@@ -292,7 +292,7 @@ def test_run_and_upload(self):
292292
self.assertTrue(check_res)
293293

294294
# todo: check if runtime is present
295-
self._check_detailed_evaluations(run.detailed_evaluations, 1, num_folds)
295+
self._check_fold_evaluations(run.fold_evaluations, 1, num_folds)
296296
pass
297297

298298
def test_initialize_cv_from_run(self):
@@ -523,18 +523,20 @@ def test__prediction_to_row(self):
523523

524524
probaY = clf.predict_proba(test_X)
525525
predY = clf.predict(test_X)
526+
sample_nr = 0 # default for this task
526527
for idx in range(0, len(test_X)):
527-
arff_line = _prediction_to_row(repeat_nr, fold_nr, idx,
528+
arff_line = _prediction_to_row(repeat_nr, fold_nr, sample_nr, idx,
528529
task.class_labels[test_y[idx]],
529530
predY[idx], probaY[idx], task.class_labels, clf.classes_)
530531

531532
self.assertIsInstance(arff_line, list)
532-
self.assertEqual(len(arff_line), 5 + len(task.class_labels))
533+
self.assertEqual(len(arff_line), 6 + len(task.class_labels))
533534
self.assertEqual(arff_line[0], repeat_nr)
534535
self.assertEqual(arff_line[1], fold_nr)
535-
self.assertEqual(arff_line[2], idx)
536+
self.assertEqual(arff_line[2], sample_nr)
537+
self.assertEqual(arff_line[3], idx)
536538
sum = 0.0
537-
for att_idx in range(3, 3 + len(task.class_labels)):
539+
for att_idx in range(4, 4 + len(task.class_labels)):
538540
self.assertIsInstance(arff_line[att_idx], float)
539541
self.assertGreaterEqual(arff_line[att_idx], 0.0)
540542
self.assertLessEqual(arff_line[att_idx], 1.0)
@@ -572,19 +574,19 @@ def test__run_task_get_arffcontent(self):
572574

573575
clf = SGDClassifier(loss='log', random_state=1)
574576
res = openml.runs.functions._run_task_get_arffcontent(clf, task, class_labels)
575-
arff_datacontent, arff_tracecontent, _, detailed_evaluations = res
577+
arff_datacontent, arff_tracecontent, _, fold_evaluations, sample_evaluations = res
576578
# predictions
577579
self.assertIsInstance(arff_datacontent, list)
578580
# trace. SGD does not produce any
579581
self.assertIsInstance(arff_tracecontent, type(None))
580582

581-
self._check_detailed_evaluations(detailed_evaluations, num_repeats, num_folds)
583+
self._check_fold_evaluations(fold_evaluations, num_repeats, num_folds)
582584

583585
# 10 times 10 fold CV of 150 samples
584586
self.assertEqual(len(arff_datacontent), num_instances * num_repeats)
585587
for arff_line in arff_datacontent:
586588
# check number columns
587-
self.assertEqual(len(arff_line), 7)
589+
self.assertEqual(len(arff_line), 8)
588590
# check repeat
589591
self.assertGreaterEqual(arff_line[0], 0)
590592
self.assertLessEqual(arff_line[0], num_repeats - 1)
@@ -595,9 +597,9 @@ def test__run_task_get_arffcontent(self):
595597
self.assertGreaterEqual(arff_line[2], 0)
596598
self.assertLessEqual(arff_line[2], num_instances - 1)
597599
# check confidences
598-
self.assertAlmostEqual(sum(arff_line[3:5]), 1.0)
599-
self.assertIn(arff_line[5], ['won', 'nowin'])
600+
self.assertAlmostEqual(sum(arff_line[4:6]), 1.0)
600601
self.assertIn(arff_line[6], ['won', 'nowin'])
602+
self.assertIn(arff_line[7], ['won', 'nowin'])
601603

602604
def test_get_run(self):
603605
# this run is not available on test
@@ -615,7 +617,7 @@ def test_get_run(self):
615617
(7, 0.666365),
616618
(8, 0.56759),
617619
(9, 0.64621)]:
618-
self.assertEqual(run.detailed_evaluations['f_measure'][0][i], value)
620+
self.assertEqual(run.fold_evaluations['f_measure'][0][i], value)
619621
assert('weka' in run.tags)
620622
assert('stacking' in run.tags)
621623

@@ -742,11 +744,11 @@ def test_run_on_dataset_with_missing_labels(self):
742744
model = Pipeline(steps=[('Imputer', Imputer(strategy='median')),
743745
('Estimator', DecisionTreeClassifier())])
744746

745-
data_content, _, _, _ = _run_task_get_arffcontent(model, task, class_labels)
747+
data_content, _, _, _, _ = _run_task_get_arffcontent(model, task, class_labels)
746748
# 2 folds, 5 repeats; keep in mind that this task comes from the test
747749
# server, the task on the live server is different
748750
self.assertEqual(len(data_content), 4490)
749751
for row in data_content:
750752
# repeat, fold, row_id, 6 confidences, prediction and correct label
751-
self.assertEqual(len(row), 11)
753+
self.assertEqual(len(row), 12)
752754

tests/test_tasks/test_split.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,19 @@ def test_from_arff_file(self):
4646
split = OpenMLSplit._from_arff_file(self.arff_filename)
4747
self.assertIsInstance(split.split, dict)
4848
self.assertIsInstance(split.split[0], dict)
49-
self.assertIsInstance(split.split[0][0][0], np.ndarray)
50-
self.assertIsInstance(split.split[0][0].train, np.ndarray)
51-
self.assertIsInstance(split.split[0][0].train, np.ndarray)
52-
self.assertIsInstance(split.split[0][0][1], np.ndarray)
53-
self.assertIsInstance(split.split[0][0].test, np.ndarray)
54-
self.assertIsInstance(split.split[0][0].test, np.ndarray)
49+
self.assertIsInstance(split.split[0][0], dict)
50+
self.assertIsInstance(split.split[0][0][0][0], np.ndarray)
51+
self.assertIsInstance(split.split[0][0][0].train, np.ndarray)
52+
self.assertIsInstance(split.split[0][0][0].train, np.ndarray)
53+
self.assertIsInstance(split.split[0][0][0][1], np.ndarray)
54+
self.assertIsInstance(split.split[0][0][0].test, np.ndarray)
55+
self.assertIsInstance(split.split[0][0][0].test, np.ndarray)
5556
for i in range(10):
5657
for j in range(10):
57-
self.assertGreaterEqual(split.split[i][j].train.shape[0], 808)
58-
self.assertGreaterEqual(split.split[i][j].test.shape[0], 89)
59-
self.assertEqual(split.split[i][j].train.shape[0] +
60-
split.split[i][j].test.shape[0], 898)
58+
self.assertGreaterEqual(split.split[i][j][0].train.shape[0], 808)
59+
self.assertGreaterEqual(split.split[i][j][0].test.shape[0], 89)
60+
self.assertEqual(split.split[i][j][0].train.shape[0] +
61+
split.split[i][j][0].test.shape[0], 898)
6162

6263
def test_get_split(self):
6364
split = OpenMLSplit._from_arff_file(self.arff_filename)

tests/test_tasks/test_task.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -62,26 +62,3 @@ def test_get_train_and_test_split_indices(self):
6262
self.assertRaisesRegexp(ValueError, "Repeat 10 not known",
6363
task.get_train_test_split_indices, 0, 10)
6464

65-
def test_iterate_repeats(self):
66-
openml.config.set_cache_directory(self.static_cache_dir)
67-
task = openml.tasks.get_task(1882)
68-
69-
num_repeats = 0
70-
for rep in task.iterate_repeats():
71-
num_repeats += 1
72-
self.assertIsInstance(rep, types.GeneratorType)
73-
self.assertEqual(num_repeats, 10)
74-
75-
def test_iterate_all_splits(self):
76-
openml.config.set_cache_directory(self.static_cache_dir)
77-
task = openml.tasks.get_task(1882)
78-
79-
num_splits = 0
80-
for split in task.iterate_all_splits():
81-
num_splits += 1
82-
self.assertIsInstance(split[0], np.ndarray)
83-
self.assertIsInstance(split[1], np.ndarray)
84-
self.assertEqual(split[0].shape[0] + split[1].shape[0], 898)
85-
self.assertEqual(num_splits, 100)
86-
87-

0 commit comments

Comments (0)