|
1 | 1 | import sys |
| 2 | +import arff |
| 3 | +import time |
2 | 4 |
|
3 | 5 | import openml |
4 | 6 | import openml.exceptions |
| 7 | +import openml._api_calls |
5 | 8 |
|
6 | 9 | from openml.testing import TestBase |
7 | 10 | from openml.runs.functions import _run_task_get_arffcontent |
@@ -29,12 +32,35 @@ class TestRun(TestBase): |
29 | 32 | def _check_serialized_optimized_run(self, run_id): |
30 | 33 | run = openml.runs.get_run(run_id) |
31 | 34 | task = openml.tasks.get_task(run.task_id) |
32 | | - trace = openml.runs.get_run_trace(run_id) |
| 35 | + |
33 | 36 | # TODO: assert holdout task |
34 | 37 |
|
35 | | - model = openml.runs.initialize_model_from_trace(run_id, 0, 0) |
36 | | - # TODO: implement testcase |
37 | | - |
| 38 | + # downloads the predictions of the old task |
| 39 | + predictions_url = openml._api_calls.fileid_to_url(run.output_files['predictions']) |
| 40 | + predictions = arff.loads(openml._api_calls._read_url(predictions_url)) |
| 41 | + |
| 42 | + # downloads the best model based on the optimization trace |
| 43 | + print(run_id) |
| 44 | + time.sleep(60) |
| 45 | + model_prime = openml.runs.initialize_model_from_trace(run_id, 0, 0) |
| 46 | + |
| 47 | + run_prime = openml.runs.run_task(task, model_prime, avoid_duplicate_runs=False) |
| 48 | + predictions_prime = run_prime._generate_arff_dict() |
| 49 | + |
| 50 | + print(model_prime) |
| 51 | + |
| 52 | + self.assertEquals(len(predictions_prime['data']), len(predictions['data'])) |
| 53 | + |
| 54 | + # The original search model does not submit confidence bounds, |
| 55 | + # so we can not compare the arff line |
| 56 | + compare_slice = [0, 1, 2, -1, -2] |
| 57 | + for idx in range(len(predictions['data'])): |
| 58 | + # depends on the assumption "predictions are in same order" |
| 59 | + # that does not necessarily hold. |
| 60 | + # But with the current code base, it holds. |
| 61 | + for col_idx in compare_slice: |
| 62 | + self.assertEquals(predictions['data'][idx][col_idx], predictions_prime['data'][idx][col_idx]) |
| 63 | + |
38 | 64 | return True |
39 | 65 |
|
40 | 66 |
|
@@ -118,8 +144,8 @@ def test_run_optimize_randomforest_diabetes(self): |
118 | 144 | run = self._perform_run(task_id, num_test_instances, random_search) |
119 | 145 | self.assertEqual(len(run.trace_content), num_iterations * num_folds) |
120 | 146 |
|
121 | | - # res = self._check_serialized_optimized_run(run.run_id) |
122 | | - # self.assertTrue(res) |
| 147 | + res = self._check_serialized_optimized_run(run.run_id) |
| 148 | + self.assertTrue(res) |
123 | 149 |
|
124 | 150 | def test_run_optimize_bagging_diabetes(self): |
125 | 151 | task_id = 119 |
|
0 commit comments