Skip to content

Commit db65bdc

Browse files
committed
added list of integers to set of accepted parameter types for arff traces
1 parent c93dd94 commit db65bdc

2 files changed

Lines changed: 14 additions & 7 deletions

File tree

openml/runs/functions.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -594,11 +594,16 @@ def _extract_arfftrace_attributes(model):
594594
for key in model.cv_results_:
595595
if key.startswith('param_'):
596596
# supported types should include all types, including bool, int float
597-
supported_types = (bool, int, float, six.string_types, tuple)
598-
if all(isinstance(i, supported_types) or i is None for i in model.cv_results_[key]):
599-
type = 'STRING'
600-
else:
601-
raise TypeError('Unsupported param type in param grid')
597+
supported_basic_types = (bool, int, float, six.string_types)
598+
for param_value in model.cv_results_[key]:
599+
if isinstance(param_value, supported_basic_types) or param_value is None:
600+
# basic string values
601+
type = 'STRING'
602+
elif isinstance(param_value, list) and all(isinstance(i, int) for i in param_value):
603+
# list of integers
604+
type = 'STRING'
605+
else:
606+
raise TypeError('Unsupported param type in param grid: %s' %key)
602607

603608
# we renamed the attribute param to parameter, as this is a required
604609
# OpenML convention

tests/test_runs/test_run_functions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ def test__get_seeded_model_raises(self):
615615
self.assertRaises(ValueError, _get_seeded_model, model=clf, seed=42)
616616

617617
def test__extract_arfftrace(self):
618-
param_grid = {"hidden_layer_sizes": [(5, 5), (10, 10), (20, 20)],
618+
param_grid = {"hidden_layer_sizes": [[5, 5], [10, 10], [20, 20]],
619619
"activation" : ['identity', 'logistic', 'tanh', 'relu'],
620620
"learning_rate_init": [0.1, 0.01, 0.001, 0.0001],
621621
"max_iter": [10, 20, 40, 80]}
@@ -627,6 +627,9 @@ def test__extract_arfftrace(self):
627627
X, y = task.get_X_and_y()
628628
clf.fit(X[train], y[train])
629629

630+
# check num layers of MLP
631+
self.assertIn(clf.best_estimator_.hidden_layer_sizes, param_grid['hidden_layer_sizes'])
632+
630633
trace_attribute_list = _extract_arfftrace_attributes(clf)
631634
trace_list = _extract_arfftrace(clf, 0, 0)
632635
self.assertIsInstance(trace_attribute_list, list)
@@ -660,7 +663,6 @@ def test__extract_arfftrace(self):
660663
else: # att_type = real
661664
self.assertIsInstance(trace_list[line_idx][att_idx], float)
662665

663-
664666
self.assertEqual(set(param_grid.keys()), optimized_params)
665667

666668
def test__prediction_to_row(self):

0 commit comments

Comments
 (0)