@@ -64,7 +64,11 @@ def __init__(self, dataset_id=None, name=None, version=None, description=None,
6464 self .url = url
6565 self .default_target_attribute = default_target_attribute
6666 self .row_id_attribute = row_id_attribute
67- self .ignore_attributes = list (ignore_attribute ) if ignore_attribute is not None else None # TODO: check
67+ self .ignore_attributes = None
68+ if isinstance (ignore_attribute , str ):
69+ self .ignore_attributes = [ignore_attribute ]
70+ elif isinstance (ignore_attribute , list ):
71+ self .ignore_attributes = ignore_attribute
6872 self .version_label = version_label
6973 self .citation = citation
7074 self .tag = tag
@@ -311,6 +315,8 @@ def retrieve_class_labels(self, target_name='class'):
311315
312316 def get_features_by_type (self , data_type , exclude = None , exclude_ignore_attributes = True ):
313317 assert data_type in OpenMLDataFeature .LEGAL_DATA_TYPES , "Illegal feature type requested"
318+ if self .ignore_attributes is not None :
319+ assert type (self .ignore_attributes ) is list , "ignore_attributes should be a list"
314320 if exclude is not None :
315321 assert type (exclude ) is list , "Exclude should be a list"
316322 assert all (isinstance (elem , str ) for elem in exclude ), "Exclude should be a list of strings"
@@ -319,7 +325,6 @@ def get_features_by_type(self, data_type, exclude=None, exclude_ignore_attribute
319325 to_exclude .extend (exclude )
320326 if exclude_ignore_attributes and self .ignore_attributes is not None :
321327 to_exclude .extend (self .ignore_attributes )
322- print (to_exclude )
323328
324329 result = []
325330 offset = 0
0 commit comments