|
1 | | -import sys |
2 | 1 | import arff |
3 | | -import time |
| 2 | +import json |
4 | 3 | import random |
| 4 | +import time |
| 5 | +import sys |
5 | 6 |
|
6 | 7 | import numpy as np |
7 | 8 |
|
|
12 | 13 |
|
13 | 14 | from openml.testing import TestBase |
14 | 15 | from openml.runs.functions import _run_task_get_arffcontent, \ |
15 | | - _get_seeded_model, _run_exists |
| 16 | + _get_seeded_model, _run_exists, _extract_arfftrace, \ |
| 17 | + _extract_arfftrace_attributes |
16 | 18 |
|
17 | 19 | from sklearn.naive_bayes import GaussianNB |
18 | 20 | from sklearn.model_selection._search import BaseSearchCV |
@@ -144,7 +146,7 @@ def test_run_and_upload(self): |
144 | 146 | # - openml.runs.run_task() |
145 | 147 | # - openml.runs.OpenMLRun.publish() |
146 | 148 | # - openml.runs.initialize_model() |
147 | | - # - [implicitly] openml.setups.initialize_model_from_setup() |
| 149 | + # - [implicitly] openml.setups.initialize_model() |
148 | 150 | # - openml.runs.initialize_model_from_trace() |
149 | 151 | task_id = 119 # diabates dataset |
150 | 152 | num_test_instances = 253 # 33% holdout task |
@@ -326,6 +328,56 @@ def test_get_seeded_model_raises(self): |
326 | 328 | for clf in randomized_clfs: |
327 | 329 | self.assertRaises(ValueError, _get_seeded_model, model=clf, seed=42) |
328 | 330 |
|
| 331 | + def test__extract_arfftrace(self): |
| 332 | + param_grid = {"max_depth": [3, None], |
| 333 | + "max_features": [1, 2, 3, 4], |
| 334 | + "bootstrap": [True, False], |
| 335 | + "criterion": ["gini", "entropy"]} |
| 336 | + num_iters = 10 |
| 337 | + task = openml.tasks.get_task(20) |
| 338 | + clf = RandomizedSearchCV(RandomForestClassifier(), param_grid, num_iters) |
| 339 | + # just run the task |
| 340 | + train, _ = task.get_train_test_split_indices(0, 0) |
| 341 | + X, y = task.get_X_and_y() |
| 342 | + clf.fit(X[train], y[train]) |
| 343 | + |
| 344 | + trace_attribute_list = _extract_arfftrace_attributes(clf) |
| 345 | + trace_list = _extract_arfftrace(clf, 0, 0) |
| 346 | + self.assertIsInstance(trace_attribute_list, list) |
| 347 | + self.assertEquals(len(trace_attribute_list), 5 + len(param_grid)) |
| 348 | + self.assertIsInstance(trace_list, list) |
| 349 | + self.assertEquals(len(trace_list), num_iters) |
| 350 | + |
| 351 | + # found parameters |
| 352 | + optimized_params = set() |
| 353 | + |
| 354 | + for att_idx in range(len(trace_attribute_list)): |
| 355 | + att_type = trace_attribute_list[att_idx][1] |
| 356 | + att_name = trace_attribute_list[att_idx][0] |
| 357 | + if att_name.startswith("parameter_"): |
| 358 | + # add this to the found parameters |
| 359 | + param_name = att_name[len("parameter_"):] |
| 360 | + optimized_params.add(param_name) |
| 361 | + |
| 362 | + for line_idx in range(len(trace_list)): |
| 363 | + val = json.loads(trace_list[line_idx][att_idx]) |
| 364 | + legal_values = param_grid[param_name] |
| 365 | + self.assertIn(val, legal_values) |
| 366 | + else: |
| 367 | + # repeat, fold, itt, bool |
| 368 | + for line_idx in range(len(trace_list)): |
| 369 | + val = trace_list[line_idx][att_idx] |
| 370 | + if isinstance(att_type, list): |
| 371 | + self.assertIn(val, att_type) |
| 372 | + elif att_name in ['repeat', 'fold', 'iteration']: |
| 373 | + self.assertIsInstance(trace_list[line_idx][att_idx], int) |
| 374 | + else: # att_type = real |
| 375 | + self.assertIsInstance(trace_list[line_idx][att_idx], float) |
| 376 | + |
| 377 | + |
| 378 | + self.assertEqual(param_grid.keys(), optimized_params) |
| 379 | + |
| 380 | + |
329 | 381 | def test_run_with_classifiers_in_param_grid(self): |
330 | 382 | task = openml.tasks.get_task(115) |
331 | 383 |
|
|
0 commit comments