Skip to content

Commit 0dbe3ce

Browse files
committed
unit test for extract_arff_trace(_attributes)
1 parent 04d5677 commit 0dbe3ce

2 files changed

Lines changed: 58 additions & 4 deletions

File tree

tests/test_flows/test_sklearn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,8 @@ def test_gaussian_process(self):
461461
sklearn_to_flow, gp)
462462

463463
def test_error_on_adding_component_multiple_times_to_flow(self):
464+
# this function implicitly checks
465+
# - openml.flows._check_multiple_occurence_of_component_in_flow()
464466
pca = sklearn.decomposition.PCA()
465467
pca2 = sklearn.decomposition.PCA()
466468
pipeline = sklearn.pipeline.Pipeline((('pca1', pca), ('pca2', pca2)))

tests/test_runs/test_run_functions.py

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
import sys
21
import arff
3-
import time
2+
import json
43
import random
4+
import time
5+
import sys
56

67
import numpy as np
78

@@ -12,7 +13,8 @@
1213

1314
from openml.testing import TestBase
1415
from openml.runs.functions import _run_task_get_arffcontent, \
15-
_get_seeded_model, _run_exists
16+
_get_seeded_model, _run_exists, _extract_arfftrace, \
17+
_extract_arfftrace_attributes
1618

1719
from sklearn.naive_bayes import GaussianNB
1820
from sklearn.model_selection._search import BaseSearchCV
@@ -144,7 +146,7 @@ def test_run_and_upload(self):
144146
# - openml.runs.run_task()
145147
# - openml.runs.OpenMLRun.publish()
146148
# - openml.runs.initialize_model()
147-
# - [implicitly] openml.setups.initialize_model_from_setup()
149+
# - [implicitly] openml.setups.initialize_model()
148150
# - openml.runs.initialize_model_from_trace()
149151
task_id = 119 # diabetes dataset
150152
num_test_instances = 253 # 33% holdout task
@@ -326,6 +328,56 @@ def test_get_seeded_model_raises(self):
326328
for clf in randomized_clfs:
327329
self.assertRaises(ValueError, _get_seeded_model, model=clf, seed=42)
328330

331+
def test__extract_arfftrace(self):
    """Check the ARFF trace extracted from a fitted randomized search.

    Fits a ``RandomizedSearchCV`` over a ``RandomForestClassifier`` on one
    train split of task 20, then verifies that

    - ``_extract_arfftrace_attributes`` returns a list holding five
      attributes plus one per entry of the parameter grid,
    - ``_extract_arfftrace`` returns a list with one line per search
      iteration,
    - every ``parameter_*`` column holds a JSON-encoded value drawn from the
      corresponding grid entry,
    - every other column matches its declared attribute type, and
    - the set of ``parameter_*`` attributes equals the grid's keys.
    """
    param_grid = {"max_depth": [3, None],
                  "max_features": [1, 2, 3, 4],
                  "bootstrap": [True, False],
                  "criterion": ["gini", "entropy"]}
    num_iters = 10
    task = openml.tasks.get_task(20)
    # n_iter is passed by keyword: recent scikit-learn releases reject it
    # as a positional argument.
    clf = RandomizedSearchCV(RandomForestClassifier(), param_grid,
                             n_iter=num_iters)
    # just run the task on a single split
    train, _ = task.get_train_test_split_indices(0, 0)
    X, y = task.get_X_and_y()
    clf.fit(X[train], y[train])

    trace_attribute_list = _extract_arfftrace_attributes(clf)
    # NOTE(review): the two zeros are presumably repeat and fold - confirm
    # against the signature of _extract_arfftrace.
    trace_list = _extract_arfftrace(clf, 0, 0)
    self.assertIsInstance(trace_attribute_list, list)
    self.assertEqual(len(trace_attribute_list), 5 + len(param_grid))
    self.assertIsInstance(trace_list, list)
    self.assertEqual(len(trace_list), num_iters)

    # names of the hyperparameters found among the trace attributes
    optimized_params = set()
    prefix = "parameter_"

    for att_idx, attribute in enumerate(trace_attribute_list):
        att_name = attribute[0]
        att_type = attribute[1]
        if att_name.startswith(prefix):
            # add this to the found parameters
            param_name = att_name[len(prefix):]
            optimized_params.add(param_name)
            legal_values = param_grid[param_name]

            for trace_line in trace_list:
                # parameter values are stored as JSON strings
                val = json.loads(trace_line[att_idx])
                self.assertIn(val, legal_values)
        else:
            # repeat, fold, iteration, nominal or real valued columns
            for trace_line in trace_list:
                val = trace_line[att_idx]
                if isinstance(att_type, list):
                    # attribute type given as a list of allowed values
                    self.assertIn(val, att_type)
                elif att_name in ['repeat', 'fold', 'iteration']:
                    self.assertIsInstance(val, int)
                else:  # att_type = real
                    self.assertIsInstance(val, float)

    # set(...) keeps the comparison valid where dict.keys() returns a list
    self.assertEqual(set(param_grid), optimized_params)
379+
380+
329381
def test_run_with_classifiers_in_param_grid(self):
330382
task = openml.tasks.get_task(115)
331383

0 commit comments

Comments
 (0)