|
27 | 27 | from sklearn.feature_selection import VarianceThreshold |
28 | 28 | from sklearn.linear_model import LogisticRegression, SGDClassifier, \ |
29 | 29 | LinearRegression |
| 30 | +from sklearn.neural_network import MLPClassifier |
30 | 31 | from sklearn.ensemble import RandomForestClassifier, BaggingClassifier |
31 | 32 | from sklearn.svm import SVC, LinearSVC |
32 | 33 | from sklearn.model_selection import RandomizedSearchCV, GridSearchCV, \ |
@@ -617,18 +618,21 @@ def test__get_seeded_model_raises(self): |
617 | 618 | self.assertRaises(ValueError, _get_seeded_model, model=clf, seed=42) |
618 | 619 |
|
619 | 620 | def test__extract_arfftrace(self): |
620 | | - param_grid = {"max_depth": [3, None], |
621 | | - "max_features": [1, 2, 3, 4], |
622 | | - "bootstrap": [True, False], |
623 | | - "criterion": ["gini", "entropy"]} |
| 621 | + param_grid = {"hidden_layer_sizes": [[5, 5], [10, 10], [20, 20]], |
| 622 | + "activation" : ['identity', 'logistic', 'tanh', 'relu'], |
| 623 | + "learning_rate_init": [0.1, 0.01, 0.001, 0.0001], |
| 624 | + "max_iter": [10, 20, 40, 80]} |
624 | 625 | num_iters = 10 |
625 | 626 | task = openml.tasks.get_task(20) |
626 | | - clf = RandomizedSearchCV(RandomForestClassifier(), param_grid, num_iters) |
| 627 | + clf = RandomizedSearchCV(MLPClassifier(), param_grid, num_iters) |
627 | 628 | # just run the task |
628 | 629 | train, _ = task.get_train_test_split_indices(0, 0) |
629 | 630 | X, y = task.get_X_and_y() |
630 | 631 | clf.fit(X[train], y[train]) |
631 | 632 |
|
| 633 | + # check num layers of MLP |
| 634 | + self.assertIn(clf.best_estimator_.hidden_layer_sizes, param_grid['hidden_layer_sizes']) |
| 635 | + |
632 | 636 | trace_attribute_list = _extract_arfftrace_attributes(clf) |
633 | 637 | trace_list = _extract_arfftrace(clf, 0, 0) |
634 | 638 | self.assertIsInstance(trace_attribute_list, list) |
@@ -662,7 +666,6 @@ def test__extract_arfftrace(self): |
662 | 666 | else: # att_type = real |
663 | 667 | self.assertIsInstance(trace_list[line_idx][att_idx], float) |
664 | 668 |
|
665 | | - |
666 | 669 | self.assertEqual(set(param_grid.keys()), optimized_params) |
667 | 670 |
|
668 | 671 | def test__prediction_to_row(self): |
|
0 commit comments