Skip to content

Commit d11b0bf

Browse files
committed
implemented test for obtaining model from trace
1 parent cb55127 commit d11b0bf

2 files changed

Lines changed: 41 additions & 7 deletions

File tree

openml/runs/functions.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,15 @@ def _create_trace_from_description(xml):
576576
iteration = int(itt['oml:iteration'])
577577
setup_string = json.loads(itt['oml:setup_string'])
578578
evaluation = float(itt['oml:evaluation'])
579-
selected = bool(itt['oml:selected'])
579+
580+
selectedValue = itt['oml:selected']
581+
if selectedValue == 'true':
582+
selected = True
583+
elif selectedValue == 'false':
584+
selected = False
585+
else:
586+
raise ValueError('expected {"true", "false"} value for '\
587+
'selected field, received: %s' %selectedValue)
580588

581589
current = OpenMLTraceIteration(repeat, fold, iteration,
582590
setup_string, evaluation,

tests/test_runs/test_run_functions.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import sys
2+
import arff
3+
import time
24

35
import openml
46
import openml.exceptions
7+
import openml._api_calls
58

69
from openml.testing import TestBase
710
from openml.runs.functions import _run_task_get_arffcontent
@@ -29,12 +32,35 @@ class TestRun(TestBase):
2932
def _check_serialized_optimized_run(self, run_id):
3033
run = openml.runs.get_run(run_id)
3134
task = openml.tasks.get_task(run.task_id)
32-
trace = openml.runs.get_run_trace(run_id)
35+
3336
# TODO: assert holdout task
3437

35-
model = openml.runs.initialize_model_from_trace(run_id, 0, 0)
36-
# TODO: implement testcase
37-
38+
# downloads the predictions of the original run
39+
predictions_url = openml._api_calls.fileid_to_url(run.output_files['predictions'])
40+
predictions = arff.loads(openml._api_calls._read_url(predictions_url))
41+
42+
# downloads the best model based on the optimization trace
43+
print(run_id)
44+
time.sleep(60)
45+
model_prime = openml.runs.initialize_model_from_trace(run_id, 0, 0)
46+
47+
run_prime = openml.runs.run_task(task, model_prime, avoid_duplicate_runs=False)
48+
predictions_prime = run_prime._generate_arff_dict()
49+
50+
print(model_prime)
51+
52+
self.assertEquals(len(predictions_prime['data']), len(predictions['data']))
53+
54+
# The original search model does not submit confidence bounds,
55+
# so we cannot compare complete ARFF lines
56+
compare_slice = [0, 1, 2, -1, -2]
57+
for idx in range(len(predictions['data'])):
58+
# This relies on the assumption that predictions are in the same order,
59+
# which does not necessarily hold in general.
60+
# With the current code base, however, it does hold.
61+
for col_idx in compare_slice:
62+
self.assertEquals(predictions['data'][idx][col_idx], predictions_prime['data'][idx][col_idx])
63+
3864
return True
3965

4066

@@ -118,8 +144,8 @@ def test_run_optimize_randomforest_diabetes(self):
118144
run = self._perform_run(task_id, num_test_instances, random_search)
119145
self.assertEqual(len(run.trace_content), num_iterations * num_folds)
120146

121-
# res = self._check_serialized_optimized_run(run.run_id)
122-
# self.assertTrue(res)
147+
res = self._check_serialized_optimized_run(run.run_id)
148+
self.assertTrue(res)
123149

124150
def test_run_optimize_bagging_diabetes(self):
125151
task_id = 119

0 commit comments

Comments
 (0)