Skip to content

Commit d0e0638

Browse files
committed
made dataset.get_features_of_type aware of features that will be removed
1 parent e5b23ed commit d0e0638

2 files changed

Lines changed: 26 additions & 15 deletions

File tree

openml/datasets/dataset.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(self, dataset_id=None, name=None, version=None, description=None,
6464
self.url = url
6565
self.default_target_attribute = default_target_attribute
6666
self.row_id_attribute = row_id_attribute
67-
self.ignore_attributes = ignore_attribute
67+
self.ignore_attributes = list(ignore_attribute) if ignore_attribute is not None else None # TODO: check
6868
self.version_label = version_label
6969
self.citation = citation
7070
self.tag = tag
@@ -218,10 +218,7 @@ def get_data(self, target=None, target_dtype=int, include_row_id=False,
218218
if not self.ignore_attributes:
219219
pass
220220
else:
221-
if is_string(self.ignore_attributes):
222-
to_exclude.append(self.ignore_attributes)
223-
else:
224-
to_exclude.extend(self.ignore_attributes)
221+
to_exclude.extend(self.ignore_attributes)
225222

226223
if len(to_exclude) > 0:
227224
logger.info("Going to remove the following attributes:"
@@ -311,17 +308,29 @@ def retrieve_class_labels(self, target_name='class'):
311308
else:
312309
return None
313310

314-
def get_features_by_type(self, data_type, exclude=None):
311+
312+
def get_features_by_type(self, data_type, exclude=None, exclude_ignore_attributes=True):
315313
assert data_type in OpenMLDataFeature.LEGAL_DATA_TYPES, "Illegal feature type requested"
316314
if exclude is not None:
317-
assert type(exclude) is list, "Exclude should be a list of indeces"
315+
assert type(exclude) is list, "Exclude should be a list"
316+
assert all(isinstance(elem, str) for elem in exclude), "Exclude should be a list of strings"
317+
to_exclude = []
318+
if exclude is not None:
319+
to_exclude.extend(exclude)
320+
if exclude_ignore_attributes and self.ignore_attributes is not None:
321+
to_exclude.extend(self.ignore_attributes)
322+
print(to_exclude)
318323

319324
result = []
325+
offset = 0
326+
# this function assumes that everything in to_exclude will be 'excluded' from the dataset (hence the offset)
320327
for idx in self.features:
321-
# in many cases we want to exclude, for example, the target feature
322-
if exclude is None or idx not in exclude:
328+
name = self.features[idx].name
329+
if name in to_exclude:
330+
offset += 1
331+
else:
323332
if self.features[idx].data_type == data_type:
324-
result.append(idx)
333+
result.append(idx-offset)
325334
return result
326335

327336
def publish(self):

tests/test_utils/test_conditionalimputer.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,14 @@ def _do_test(self, dataset, X, nominal_indices, clf):
3434
return X_prime
3535

3636
def test_impute_indices(self):
37-
task_ids = [2, 24, 42, 59]
37+
task_ids = [2, 24, 41, 42, 59]
3838

3939
for task_id in task_ids:
4040
task = openml.tasks.get_task(task_id)
4141
dataset = task.get_dataset()
4242
X, _ = dataset.get_data(target=task.target_name)
43-
nominal_indices = dataset.get_features_by_type('nominal', exclude=[len(dataset.features)-1])
43+
nominal_indices = dataset.get_features_by_type('nominal', exclude=[task.target_name])
44+
# print("task id %d indices %s" %(task_id, str(nominal_indices)))
4445
clf = ConditionalImputer(strategy="median",
4546
strategy_nominal="most_frequent",
4647
categorical_features=nominal_indices,
@@ -50,13 +51,14 @@ def test_impute_indices(self):
5051

5152

5253
def test_impute_smart(self):
53-
task_ids = [2, 24, 42, 59]
54+
task_ids = [2, 24, 41, 42, 59]
5455

5556
for task_id in task_ids:
5657
task = openml.tasks.get_task(task_id)
5758
dataset = task.get_dataset()
5859
X, _ = dataset.get_data(target=task.target_name)
59-
nominal_indices = dataset.get_features_by_type('nominal', exclude=[len(dataset.features)-1])
60+
nominal_indices = dataset.get_features_by_type('nominal', exclude=[task.target_name])
61+
# print("task id %d indices %s" %(task_id, str(nominal_indices)))
6062
clf = ConditionalImputer(strategy="median",
6163
strategy_nominal="most_frequent",
6264
categorical_features=None,
@@ -71,7 +73,7 @@ def test_impute_with_constant(self):
7173
task = openml.tasks.get_task(task_id)
7274
dataset = task.get_dataset()
7375
X, _ = dataset.get_data(target=task.target_name)
74-
nominal_indices = dataset.get_features_by_type('nominal', exclude=[len(dataset.features) - 1])
76+
nominal_indices = dataset.get_features_by_type('nominal', exclude=[task.target_name])
7577
clf = ConditionalImputer(strategy="median",
7678
strategy_nominal="most_frequent",
7779
categorical_features=None,

0 commit comments

Comments
 (0)