Skip to content

Commit 9f0ada9

Browse files
committed
fix unittest (?)
1 parent c878872 commit 9f0ada9

1 file changed

Lines changed: 17 additions & 18 deletions

File tree

tests/test_runs/test_run_functions.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -714,24 +714,18 @@ def test_run_with_classifiers_in_param_grid(self):
714714
task=task, model=clf, avoid_duplicate_runs=False)
715715

716716
def test__run_task_get_arffcontent(self):
717-
task = openml.tasks.get_task(7)
718-
num_instances = 320
719-
num_folds = 1
717+
task = openml.tasks.get_task(11)
718+
num_instances = 3196
719+
num_folds = 10
720720
num_repeats = 1
721721

722722
clf = SGDClassifier(loss='log', random_state=1)
723-
res = openml.runs.functions._run_model_on_fold(clf, task, 0, 0, 0, True)
724-
725-
arff_datacontent, arff_tracecontent, user_defined_measures, model = res
723+
res = openml.runs.functions._run_task_get_arffcontent(clf, task)
724+
arff_datacontent, arff_tracecontent, _, fold_evaluations, sample_evaluations = res
726725
# predictions
727726
self.assertIsInstance(arff_datacontent, list)
728727
# trace. SGD does not produce any
729-
self.assertIsInstance(arff_tracecontent, list)
730-
self.assertEquals(len(arff_tracecontent), 0)
731-
732-
fold_evaluations = collections.defaultdict(lambda: collections.defaultdict(dict))
733-
for measure in user_defined_measures:
734-
fold_evaluations[measure][0][0] = user_defined_measures[measure]
728+
self.assertIsInstance(arff_tracecontent, type(None))
735729

736730
self._check_fold_evaluations(fold_evaluations, num_repeats, num_folds)
737731

@@ -755,18 +749,24 @@ def test__run_task_get_arffcontent(self):
755749
self.assertIn(arff_line[7], ['won', 'nowin'])
756750

757751
def test__run_model_on_fold(self):
758-
task = openml.tasks.get_task(11)
759-
num_instances = 3196
752+
task = openml.tasks.get_task(7)
753+
num_instances = 1054
760754
num_folds = 1
761755
num_repeats = 1
762756

763757
clf = SGDClassifier(loss='log', random_state=1)
764-
res = openml.runs.functions._run_task_get_arffcontent(clf, task)
765-
arff_datacontent, arff_tracecontent, _, fold_evaluations, sample_evaluations = res
758+
res = openml.runs.functions._run_model_on_fold(clf, task, 0, 0, 0, True)
759+
760+
arff_datacontent, arff_tracecontent, user_defined_measures, model = res
766761
# predictions
767762
self.assertIsInstance(arff_datacontent, list)
768763
# trace. SGD does not produce any
769-
self.assertIsInstance(arff_tracecontent, type(None))
764+
self.assertIsInstance(arff_tracecontent, list)
765+
self.assertEquals(len(arff_tracecontent), 0)
766+
767+
fold_evaluations = collections.defaultdict(lambda: collections.defaultdict(dict))
768+
for measure in user_defined_measures:
769+
fold_evaluations[measure][0][0] = user_defined_measures[measure]
770770

771771
self._check_fold_evaluations(fold_evaluations, num_repeats, num_folds)
772772

@@ -789,7 +789,6 @@ def test__run_model_on_fold(self):
789789
self.assertIn(arff_line[6], ['won', 'nowin'])
790790
self.assertIn(arff_line[7], ['won', 'nowin'])
791791

792-
793792
def test__create_trace_from_arff(self):
794793
with open(self.static_cache_dir + '/misc/trace.arff', 'r') as arff_file:
795794
trace_arff = arff.load(arff_file)

0 commit comments

Comments
 (0)