Skip to content

Commit cb55127

Browse files
committed
added setup for testcase
1 parent be15112 commit cb55127

1 file changed

Lines changed: 25 additions & 10 deletions

File tree

tests/test_runs/test_run_functions.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,18 @@
2626

2727
class TestRun(TestBase):
2828

29+
def _check_serialized_optimized_run(self, run_id):
30+
run = openml.runs.get_run(run_id)
31+
task = openml.tasks.get_task(run.task_id)
32+
trace = openml.runs.get_run_trace(run_id)
33+
# TODO: assert holdout task
34+
35+
model = openml.runs.initialize_model_from_trace(run_id, 0, 0)
36+
# TODO: implement testcase
37+
38+
return True
39+
40+
2941
def _perform_run(self, task_id, num_instances, clf, check_setup=True):
3042
task = openml.tasks.get_task(task_id)
3143
run = openml.runs.run_task(task, clf, openml.config.avoid_duplicate_runs)
@@ -86,10 +98,10 @@ def test_run_diabetes(self):
8698
downloaded = openml.runs.get_run(res.run_id)
8799
assert('openml-python' in downloaded.tags)
88100

89-
def test_run_optimize_randomforest_iris(self):
90-
task_id = 115
91-
num_instances = 768
92-
num_folds = 10
101+
def test_run_optimize_randomforest_diabetes(self):
102+
task_id = 119
103+
num_test_instances = 253
104+
num_folds = 1
93105
num_iterations = 5
94106

95107
clf = RandomForestClassifier(n_estimators=5)
@@ -103,21 +115,24 @@ def test_run_optimize_randomforest_iris(self):
103115
random_search = RandomizedSearchCV(clf, param_dist, cv=cv,
104116
n_iter=num_iterations)
105117

106-
run = self._perform_run(task_id, num_instances, random_search)
118+
run = self._perform_run(task_id, num_test_instances, random_search)
107119
self.assertEqual(len(run.trace_content), num_iterations * num_folds)
108120

109-
def test_run_optimize_bagging_iris(self):
110-
task_id = 115
111-
num_instances = 768
112-
num_folds = 10
121+
# res = self._check_serialized_optimized_run(run.run_id)
122+
# self.assertTrue(res)
123+
124+
def test_run_optimize_bagging_diabetes(self):
125+
task_id = 119
126+
num_test_instances = 253
127+
num_folds = 1
113128
num_iterations = 9 # (num values for C times gamma)
114129

115130
bag = BaggingClassifier(base_estimator=SVC())
116131
param_dist = {"base_estimator__C": [0.01, 0.1, 10],
117132
"base_estimator__gamma": [0.01, 0.1, 10]}
118133
grid_search = GridSearchCV(bag, param_dist)
119134

120-
run = self._perform_run(task_id, num_instances, grid_search)
135+
run = self._perform_run(task_id, num_test_instances, grid_search)
121136
self.assertEqual(len(run.trace_content), num_iterations * num_folds)
122137

123138
def test_run_pipeline(self):

0 commit comments

Comments
 (0)