|
1 | 1 | import arff |
| 2 | +import collections |
2 | 3 | import json |
3 | 4 | import random |
4 | 5 | import time |
@@ -714,9 +715,49 @@ def test_run_with_classifiers_in_param_grid(self): |
714 | 715 |
|
715 | 716 | def test__run_task_get_arffcontent(self): |
716 | 717 | task = openml.tasks.get_task(7) |
717 | | - class_labels = task.class_labels |
| 718 | + num_instances = 320 |
| 719 | + num_folds = 1 |
| 720 | + num_repeats = 1 |
| 721 | + |
| 722 | + clf = SGDClassifier(loss='log', random_state=1) |
| 723 | + res = openml.runs.functions._run_model_on_fold(clf, task, 0, 0, 0, True) |
| 724 | + |
| 725 | + arff_datacontent, arff_tracecontent, user_defined_measures, model = res |
| 726 | + # predictions |
| 727 | + self.assertIsInstance(arff_datacontent, list) |
| 728 | + # trace: SGD does not produce one, so the list must be empty |
| 729 | + self.assertIsInstance(arff_tracecontent, list) |
| 730 | + self.assertEquals(len(arff_tracecontent), 0) |
| 731 | + |
| 732 | + fold_evaluations = collections.defaultdict(lambda: collections.defaultdict(dict)) |
| 733 | + for measure in user_defined_measures: |
| 734 | + fold_evaluations[measure][0][0] = user_defined_measures[measure] |
| 735 | + |
| 736 | + self._check_fold_evaluations(fold_evaluations, num_repeats, num_folds) |
| 737 | + |
| 738 | + # one repeat, one fold: expect num_instances * num_repeats prediction rows |
| 739 | + self.assertEqual(len(arff_datacontent), num_instances * num_repeats) |
| 740 | + for arff_line in arff_datacontent: |
| 741 | + # check number columns |
| 742 | + self.assertEqual(len(arff_line), 8) |
| 743 | + # check repeat |
| 744 | + self.assertGreaterEqual(arff_line[0], 0) |
| 745 | + self.assertLessEqual(arff_line[0], num_repeats - 1) |
| 746 | + # check fold |
| 747 | + self.assertGreaterEqual(arff_line[1], 0) |
| 748 | + self.assertLessEqual(arff_line[1], num_folds - 1) |
| 749 | + # check row id |
| 750 | + self.assertGreaterEqual(arff_line[2], 0) |
| 751 | + self.assertLessEqual(arff_line[2], num_instances - 1) |
| 752 | + # check confidences |
| 753 | + self.assertAlmostEqual(sum(arff_line[4:6]), 1.0) |
| 754 | + self.assertIn(arff_line[6], ['won', 'nowin']) |
| 755 | + self.assertIn(arff_line[7], ['won', 'nowin']) |
| 756 | + |
| 757 | + def test__run_model_on_fold(self): |
| 758 | + task = openml.tasks.get_task(11) |
718 | 759 | num_instances = 3196 |
719 | | - num_folds = 10 |
| 760 | + num_folds = 1 |
720 | 761 | num_repeats = 1 |
721 | 762 |
|
722 | 763 | clf = SGDClassifier(loss='log', random_state=1) |
@@ -748,6 +789,7 @@ def test__run_task_get_arffcontent(self): |
748 | 789 | self.assertIn(arff_line[6], ['won', 'nowin']) |
749 | 790 | self.assertIn(arff_line[7], ['won', 'nowin']) |
750 | 791 |
|
| 792 | + |
751 | 793 | def test__create_trace_from_arff(self): |
752 | 794 | with open(self.static_cache_dir + '/misc/trace.arff', 'r') as arff_file: |
753 | 795 | trace_arff = arff.load(arff_file) |
|
0 commit comments