Skip to content

Commit 6147c4f

Browse files
committed
fixes #373 + unit test
1 parent 194706d commit 6147c4f

2 files changed

Lines changed: 7 additions & 6 deletions

File tree

openml/runs/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ def _extract_arfftrace_attributes(model):
594594
for key in model.cv_results_:
595595
if key.startswith('param_'):
596596
# supported types should include all types, including bool, int float
597-
supported_types = (bool, int, float, six.string_types)
597+
supported_types = (bool, int, float, six.string_types, tuple)
598598
if all(isinstance(i, supported_types) or i is None for i in model.cv_results_[key]):
599599
type = 'STRING'
600600
else:

tests/test_runs/test_run_functions.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from sklearn.feature_selection import VarianceThreshold
2828
from sklearn.linear_model import LogisticRegression, SGDClassifier, \
2929
LinearRegression
30+
from sklearn.neural_network import MLPClassifier
3031
from sklearn.ensemble import RandomForestClassifier, BaggingClassifier
3132
from sklearn.svm import SVC, LinearSVC
3233
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV, \
@@ -614,13 +615,13 @@ def test__get_seeded_model_raises(self):
614615
self.assertRaises(ValueError, _get_seeded_model, model=clf, seed=42)
615616

616617
def test__extract_arfftrace(self):
617-
param_grid = {"max_depth": [3, None],
618-
"max_features": [1, 2, 3, 4],
619-
"bootstrap": [True, False],
620-
"criterion": ["gini", "entropy"]}
618+
param_grid = {"hidden_layer_sizes": [(5, 5), (10, 10), (20, 20)],
619+
"activation" : ['identity', 'logistic', 'tanh', 'relu'],
620+
"learning_rate_init": [0.1, 0.01, 0.001, 0.0001],
621+
"max_iter": [10, 20, 40, 80]}
621622
num_iters = 10
622623
task = openml.tasks.get_task(20)
623-
clf = RandomizedSearchCV(RandomForestClassifier(), param_grid, num_iters)
624+
clf = RandomizedSearchCV(MLPClassifier(), param_grid, num_iters)
624625
# just run the task
625626
train, _ = task.get_train_test_split_indices(0, 0)
626627
X, y = task.get_X_and_y()

0 commit comments

Comments
 (0)