Skip to content

Commit 8325c72

Browse files
committed
added test cases
1 parent a6163c3 commit 8325c72

1 file changed

Lines changed: 60 additions & 10 deletions

File tree

tests/test_runs/test_run_functions.py

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sys
22
import arff
33
import time
4+
import random
45

56
import numpy as np
67

@@ -10,7 +11,8 @@
1011
import sklearn
1112

1213
from openml.testing import TestBase
13-
from openml.runs.functions import _run_task_get_arffcontent, _get_seeded_model
14+
from openml.runs.functions import _run_task_get_arffcontent, \
15+
_get_seeded_model, _run_exists
1416

1517
from sklearn.naive_bayes import GaussianNB
1618
from sklearn.model_selection._search import BaseSearchCV
@@ -35,6 +37,21 @@
3537

3638
class TestRun(TestBase):
3739

40+
def _wait_for_processed_run(self, run_id, max_waiting_time_seconds):
41+
# it can take a while for a run to be processed on the OpenML (test) server
42+
# however, sometimes it is good to wait (a bit) for this, to properly test
43+
# a function. In this case, we wait for max_waiting_time_seconds on this
44+
# to happen, probing the server every 10 seconds to speed up the process
45+
46+
# time.time() works in seconds
47+
start_time = time.time()
48+
while time.time() - start_time < max_waiting_time_seconds:
49+
run = openml.runs.get_run(run_id)
50+
if len(run.evaluations) > 0:
51+
return
52+
else:
53+
time.sleep(10)
54+
3855
def _check_serialized_optimized_run(self, run_id):
3956
run = openml.runs.get_run(run_id)
4057
task = openml.tasks.get_task(run.task_id)
@@ -48,15 +65,8 @@ def _check_serialized_optimized_run(self, run_id):
4865
# downloads the best model based on the optimization trace
4966
# suboptimal (slow), and not guaranteed to work if evaluation
5067
# engine is behind. TODO: mock this? We have the arff already on the server
51-
secCount = 0
52-
while secCount < 70:
53-
try:
54-
model_prime = openml.runs.initialize_model_from_trace(run_id, 0, 0)
55-
break
56-
except openml.exceptions.OpenMLServerException:
57-
# probably because openml eval engine has not executed this run yet
58-
time.sleep(10)
59-
secCount += 10
68+
self._wait_for_processed_run(run_id, 80)
69+
model_prime = openml.runs.initialize_model_from_trace(run_id, 0, 0)
6070

6171
run_prime = openml.runs.run_task(task, model_prime, avoid_duplicate_runs=False)
6272
predictions_prime = run_prime._generate_arff_dict()
@@ -202,6 +212,46 @@ def test_initialize_model_from_run(self):
202212
self.assertEquals(flowS.components['VarianceThreshold'].parameters['threshold'], '0.05')
203213
pass
204214

215+
def test_get_run_trace(self):
216+
# get_run_trace is already tested implicitly in test_run_and_publish
217+
# this test is a bit additional.
218+
num_iterations = 10
219+
num_folds = 1
220+
task_id = 119
221+
run_id = None
222+
223+
task = openml.tasks.get_task(task_id)
224+
# IMPORTANT! Do not sentinel this flow. is faster if we don't wait on openml server
225+
clf = RandomizedSearchCV(RandomForestClassifier(random_state=42),
226+
{"max_depth": [3, None],
227+
"max_features": [1, 2, 3, 4],
228+
"bootstrap": [True, False],
229+
"criterion": ["gini", "entropy"]},
230+
num_iterations, random_state=42)
231+
232+
# [START] for speeding up this unit test!
233+
flow = openml.flows.sklearn_to_flow(clf)
234+
flow_exists = openml.flows.flow_exists(flow.name, flow.external_version)
235+
if flow_exists:
236+
flow = openml.flows.get_flow(flow_exists)
237+
setup_exists = openml.setups.setup_exists(flow, clf)
238+
if setup_exists:
239+
# receives a set of runids. These should all be the same.
240+
run_id = random.choice(list(_run_exists(task_id, setup_exists)))
241+
# [END] speeding up unit test
242+
243+
# ensure the run exists ...
244+
if run_id is None:
245+
print("Run not executed yet .. running random search on Random Forest")
246+
# we can be strict about duplicate runs
247+
run = openml.runs.run_task(task, clf, avoid_duplicate_runs=True)
248+
run = run.publish()
249+
self._wait_for_processed_run(run.run_id, 80)
250+
run_id = run.run_id
251+
252+
# now the actual unit test ...
253+
run_trace = openml.runs.get_run_trace(run_id)
254+
self.assertEqual(len(run_trace.trace_iterations), num_iterations * num_folds)
205255

206256
def test_get_seeded_model(self):
207257
# randomized models that are initialized without seeds, can be seeded

0 commit comments

Comments
 (0)