Skip to content

Commit c878872

Browse files
committed
added unit test
1 parent f664396 commit c878872

2 files changed

Lines changed: 47 additions & 13 deletions

File tree

openml/runs/functions.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ def obtain_field(xml_obj, fieldname, from_server, cast=None):
681681
else:
682682
raise AttributeError('Run XML does not contain required (server) field: ', fieldname)
683683

684-
run = xmltodict.parse(xml)["oml:run"]
684+
run = xmltodict.parse(xml, force_dict=['oml:file', 'oml:evaluation'])["oml:run"]
685685
run_id = obtain_field(run, 'oml:run_id', from_server, cast=int)
686686
uploader = obtain_field(run, 'oml:uploader', from_server, cast=int)
687687
uploader_name = obtain_field(run, 'oml:uploader_name', from_server)
@@ -722,17 +722,9 @@ def obtain_field(xml_obj, fieldname, from_server, cast=None):
722722
else:
723723
output_data = run['oml:output_data']
724724
if 'oml:file' in output_data:
725-
if isinstance(output_data['oml:file'], dict):
726-
# only one result.. probably due to an upload error
727-
file_dict = output_data['oml:file']
728-
files[file_dict['oml:name']] = int(file_dict['oml:file_id'])
729-
elif isinstance(output_data['oml:file'], list):
730-
# multiple files, the normal case
731-
for file_dict in output_data['oml:file']:
725+
# multiple files, the normal case
726+
for file_dict in output_data['oml:file']:
732727
files[file_dict['oml:name']] = int(file_dict['oml:file_id'])
733-
else:
734-
raise TypeError(type(output_data['oml:file']))
735-
736728
if 'oml:evaluation' in output_data:
737729
# in normal cases there should be evaluations, but in case there
738730
# was an error these could be absent

tests/test_runs/test_run_functions.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import arff
2+
import collections
23
import json
34
import random
45
import time
@@ -714,9 +715,49 @@ def test_run_with_classifiers_in_param_grid(self):
714715

715716
def test__run_task_get_arffcontent(self):
716717
task = openml.tasks.get_task(7)
717-
class_labels = task.class_labels
718+
num_instances = 320
719+
num_folds = 1
720+
num_repeats = 1
721+
722+
clf = SGDClassifier(loss='log', random_state=1)
723+
res = openml.runs.functions._run_model_on_fold(clf, task, 0, 0, 0, True)
724+
725+
arff_datacontent, arff_tracecontent, user_defined_measures, model = res
726+
# predictions
727+
self.assertIsInstance(arff_datacontent, list)
728+
# trace. SGD does not produce any
729+
self.assertIsInstance(arff_tracecontent, list)
730+
self.assertEquals(len(arff_tracecontent), 0)
731+
732+
fold_evaluations = collections.defaultdict(lambda: collections.defaultdict(dict))
733+
for measure in user_defined_measures:
734+
fold_evaluations[measure][0][0] = user_defined_measures[measure]
735+
736+
self._check_fold_evaluations(fold_evaluations, num_repeats, num_folds)
737+
738+
# 10 times 10 fold CV of 150 samples
739+
self.assertEqual(len(arff_datacontent), num_instances * num_repeats)
740+
for arff_line in arff_datacontent:
741+
# check number columns
742+
self.assertEqual(len(arff_line), 8)
743+
# check repeat
744+
self.assertGreaterEqual(arff_line[0], 0)
745+
self.assertLessEqual(arff_line[0], num_repeats - 1)
746+
# check fold
747+
self.assertGreaterEqual(arff_line[1], 0)
748+
self.assertLessEqual(arff_line[1], num_folds - 1)
749+
# check row id
750+
self.assertGreaterEqual(arff_line[2], 0)
751+
self.assertLessEqual(arff_line[2], num_instances - 1)
752+
# check confidences
753+
self.assertAlmostEqual(sum(arff_line[4:6]), 1.0)
754+
self.assertIn(arff_line[6], ['won', 'nowin'])
755+
self.assertIn(arff_line[7], ['won', 'nowin'])
756+
757+
def test__run_model_on_fold(self):
758+
task = openml.tasks.get_task(11)
718759
num_instances = 3196
719-
num_folds = 10
760+
num_folds = 1
720761
num_repeats = 1
721762

722763
clf = SGDClassifier(loss='log', random_state=1)
@@ -748,6 +789,7 @@ def test__run_task_get_arffcontent(self):
748789
self.assertIn(arff_line[6], ['won', 'nowin'])
749790
self.assertIn(arff_line[7], ['won', 'nowin'])
750791

792+
751793
def test__create_trace_from_arff(self):
752794
with open(self.static_cache_dir + '/misc/trace.arff', 'r') as arff_file:
753795
trace_arff = arff.load(arff_file)

0 commit comments

Comments
 (0)