Skip to content

Commit 61c113c

Browse files
committed
added unit test for learning curve task
1 parent edbad39 commit 61c113c

1 file changed

Lines changed: 67 additions & 1 deletion

File tree

tests/test_runs/test_run_functions.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ def _remove_random_state(flow):
157157

158158
return run
159159

160-
161160
def _check_fold_evaluations(self, fold_evaluations, num_repeats, num_folds, max_time_allowed=60000):
162161
'''
163162
Checks whether the right timing measures are attached to the run (before upload).
@@ -184,6 +183,36 @@ def _check_fold_evaluations(self, fold_evaluations, num_repeats, num_folds, max_
184183
self.assertGreater(evaluation, 0) # should take at least one millisecond (?)
185184
self.assertLess(evaluation, max_time_allowed)
186185

186+
187+
def _check_sample_evaluations(self, sample_evaluations, num_repeats, num_folds, num_samples, max_time_allowed=60000):
188+
'''
189+
Checks whether the right timing measures are attached to the run (before upload).
190+
Test is only performed for versions >= Python3.3
191+
192+
In case of check_n_jobs(clf) == false, please do not perform this check (check this
193+
condition outside of this function. )
194+
default max_time_allowed (per fold, in milli seconds) = 1 minute, quite pessimistic
195+
'''
196+
timing_measures = {'usercpu_time_millis_testing', 'usercpu_time_millis_training', 'usercpu_time_millis'}
197+
198+
self.assertIsInstance(sample_evaluations, dict)
199+
if sys.version_info[:2] >= (3, 3):
200+
self.assertEquals(set(sample_evaluations.keys()), timing_measures)
201+
for measure in timing_measures:
202+
num_rep_entrees = len(sample_evaluations[measure])
203+
self.assertEquals(num_rep_entrees, num_repeats)
204+
for rep in range(num_rep_entrees):
205+
num_fold_entrees = len(sample_evaluations[measure][rep])
206+
self.assertEquals(num_fold_entrees, num_folds)
207+
for fold in range(num_fold_entrees):
208+
num_sample_entrees = len(sample_evaluations[measure][rep][fold])
209+
self.assertEquals(num_sample_entrees, num_samples)
210+
for sample in range(num_sample_entrees):
211+
evaluation = sample_evaluations[measure][rep][fold][sample]
212+
self.assertIsInstance(evaluation, float)
213+
self.assertGreater(evaluation, 0) # should take at least one millisecond (?)
214+
self.assertLess(evaluation, max_time_allowed)
215+
187216
def test_run_regression_on_classif_task(self):
188217
task_id = 115
189218

@@ -295,6 +324,43 @@ def test_run_and_upload(self):
295324
self._check_fold_evaluations(run.fold_evaluations, 1, num_folds)
296325
pass
297326

327+
def test_learning_curve_task(self):
328+
task_id = 801 # diabates dataset
329+
num_test_instances = 6144 # for learning curve
330+
num_repeats = 1
331+
num_folds = 10
332+
num_samples = 8
333+
334+
clfs = []
335+
random_state_fixtures = []
336+
337+
#nb = GaussianNB()
338+
#clfs.append(nb)
339+
#random_state_fixtures.append('62501')
340+
341+
pipeline1 = Pipeline(steps=[('scaler', StandardScaler(with_mean=False)),
342+
('dummy', DummyClassifier(strategy='prior'))])
343+
clfs.append(pipeline1)
344+
random_state_fixtures.append('62501')
345+
346+
pipeline2 = Pipeline(steps=[('Imputer', Imputer(strategy='median')),
347+
('VarianceThreshold', VarianceThreshold()),
348+
('Estimator', RandomizedSearchCV(
349+
DecisionTreeClassifier(),
350+
{'min_samples_split': [2 ** x for x in range(1, 7 + 1)],
351+
'min_samples_leaf': [2 ** x for x in range(0, 6 + 1)]},
352+
cv=3, n_iter=10))])
353+
clfs.append(pipeline2)
354+
random_state_fixtures.append('62501')
355+
356+
357+
for clf, rsv in zip(clfs, random_state_fixtures):
358+
run = self._perform_run(task_id, num_test_instances, clf,
359+
random_state_value=rsv)
360+
361+
# todo: check if runtime is present
362+
self._check_sample_evaluations(run.sample_evaluations, num_repeats, num_folds, num_samples)
363+
298364
def test_initialize_cv_from_run(self):
299365
randomsearch = RandomizedSearchCV(
300366
RandomForestClassifier(n_estimators=5),

0 commit comments

Comments
 (0)