Skip to content

Commit f4387d6

Browse files
authored
Merge pull request #399 from openml/fix373
fixes #373 + unit test
2 parents a1e9368 + 02560d7 commit f4387d6

3 files changed

Lines changed: 20 additions & 12 deletions

File tree

openml/runs/functions.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -594,11 +594,16 @@ def _extract_arfftrace_attributes(model):
594594
for key in model.cv_results_:
595595
if key.startswith('param_'):
596596
# supported types should include all types, including bool, int float
597-
supported_types = (bool, int, float, six.string_types)
598-
if all(isinstance(i, supported_types) or i is None for i in model.cv_results_[key]):
599-
type = 'STRING'
600-
else:
601-
raise TypeError('Unsupported param type in param grid')
597+
supported_basic_types = (bool, int, float, six.string_types)
598+
for param_value in model.cv_results_[key]:
599+
if isinstance(param_value, supported_basic_types) or param_value is None:
600+
# basic string values
601+
type = 'STRING'
602+
elif isinstance(param_value, list) and all(isinstance(i, int) for i in param_value):
603+
# list of integers
604+
type = 'STRING'
605+
else:
606+
raise TypeError('Unsupported param type in param grid: %s' %key)
602607

603608
# we renamed the attribute param to parameter, as this is a required
604609
# OpenML convention

tests/test_runs/test_run_functions.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from sklearn.feature_selection import VarianceThreshold
2828
from sklearn.linear_model import LogisticRegression, SGDClassifier, \
2929
LinearRegression
30+
from sklearn.neural_network import MLPClassifier
3031
from sklearn.ensemble import RandomForestClassifier, BaggingClassifier
3132
from sklearn.svm import SVC, LinearSVC
3233
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV, \
@@ -617,18 +618,21 @@ def test__get_seeded_model_raises(self):
617618
self.assertRaises(ValueError, _get_seeded_model, model=clf, seed=42)
618619

619620
def test__extract_arfftrace(self):
620-
param_grid = {"max_depth": [3, None],
621-
"max_features": [1, 2, 3, 4],
622-
"bootstrap": [True, False],
623-
"criterion": ["gini", "entropy"]}
621+
param_grid = {"hidden_layer_sizes": [[5, 5], [10, 10], [20, 20]],
622+
"activation" : ['identity', 'logistic', 'tanh', 'relu'],
623+
"learning_rate_init": [0.1, 0.01, 0.001, 0.0001],
624+
"max_iter": [10, 20, 40, 80]}
624625
num_iters = 10
625626
task = openml.tasks.get_task(20)
626-
clf = RandomizedSearchCV(RandomForestClassifier(), param_grid, num_iters)
627+
clf = RandomizedSearchCV(MLPClassifier(), param_grid, num_iters)
627628
# just run the task
628629
train, _ = task.get_train_test_split_indices(0, 0)
629630
X, y = task.get_X_and_y()
630631
clf.fit(X[train], y[train])
631632

633+
# check num layers of MLP
634+
self.assertIn(clf.best_estimator_.hidden_layer_sizes, param_grid['hidden_layer_sizes'])
635+
632636
trace_attribute_list = _extract_arfftrace_attributes(clf)
633637
trace_list = _extract_arfftrace(clf, 0, 0)
634638
self.assertIsInstance(trace_attribute_list, list)
@@ -662,7 +666,6 @@ def test__extract_arfftrace(self):
662666
else: # att_type = real
663667
self.assertIsInstance(trace_list[line_idx][att_idx], float)
664668

665-
666669
self.assertEqual(set(param_grid.keys()), optimized_params)
667670

668671
def test__prediction_to_row(self):

tests/test_setups/test_setup_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def test_setup_list_filter_flow(self):
138138
self.assertEquals(setups[setup_id].flow_id, flow_id)
139139

140140
def test_list_setups_empty(self):
141-
setups = openml.setups.list_setups(setup=[-1])
141+
setups = openml.setups.list_setups(setup=[0])
142142
if len(setups) > 0:
143143
raise ValueError('UnitTest Outdated, got somehow results')
144144

0 commit comments

Comments
 (0)