22import io
33import logging
44import os
5+ import six
56import sys
67
78import arff
1011import scipy .sparse
1112import xmltodict
1213
14+ from .data_feature import OpenMLDataFeature
1315from ..exceptions import PyOpenMLError
1416
1517if sys .version_info [0 ] >= 3 :
@@ -63,7 +65,15 @@ def __init__(self, dataset_id=None, name=None, version=None, description=None,
6365 self .url = url
6466 self .default_target_attribute = default_target_attribute
6567 self .row_id_attribute = row_id_attribute
66- self .ignore_attributes = ignore_attribute
68+ self .ignore_attributes = None
69+ if isinstance (ignore_attribute , six .string_types ):
70+ self .ignore_attributes = [ignore_attribute ]
71+ elif isinstance (ignore_attribute , list ):
72+ self .ignore_attributes = ignore_attribute
73+ elif ignore_attribute is None :
74+ pass
75+ else :
76+ raise ValueError ('wrong data type for ignore_attribute. Should be list. ' )
6777 self .version_label = version_label
6878 self .citation = citation
6979 self .tag = tag
@@ -73,7 +83,20 @@ def __init__(self, dataset_id=None, name=None, version=None, description=None,
7383 self .update_comment = update_comment
7484 self .md5_cheksum = md5_checksum
7585 self .data_file = data_file
76- self .features = features
86+ self .features = None
87+
88+ if features is not None :
89+ self .features = {}
90+ for idx , xmlfeature in enumerate (features ['oml:feature' ]):
91+ feature = OpenMLDataFeature (int (xmlfeature ['oml:index' ]),
92+ xmlfeature ['oml:name' ],
93+ xmlfeature ['oml:data_type' ],
94+ None , #todo add nominal values (currently not in database)
95+ int (xmlfeature ['oml:number_of_missing_values' ]))
96+ if idx != feature .index :
97+ raise ValueError ('Data features not provided in right order' )
98+ self .features [feature .index ] = feature
99+
77100
78101 if data_file is not None :
79102 if self ._data_features_supported ():
@@ -205,10 +228,7 @@ def get_data(self, target=None, target_dtype=int, include_row_id=False,
205228 if not self .ignore_attributes :
206229 pass
207230 else :
208- if is_string (self .ignore_attributes ):
209- to_exclude .append (self .ignore_attributes )
210- else :
211- to_exclude .extend (self .ignore_attributes )
231+ to_exclude .extend (self .ignore_attributes )
212232
213233 if len (to_exclude ) > 0 :
214234 logger .info ("Going to remove the following attributes:"
@@ -298,6 +318,61 @@ def retrieve_class_labels(self, target_name='class'):
298318 else :
299319 return None
300320
321+
def get_features_by_type(self, data_type, exclude=None,
                         exclude_ignore_attributes=True,
                         exclude_row_id_attribute=True):
    """Return the positions of all features that have a given data type.

    The returned positions are computed as if every excluded feature had
    been removed from the dataset: each excluded feature that precedes a
    matching feature shifts that feature's position down by one.

    Parameters
    ----------
    data_type : str
        The data type to look for. Must be one of
        ``OpenMLDataFeature.LEGAL_DATA_TYPES``.
    exclude : list, optional
        Names of features to exclude. Features are matched against this
        list by name, not by index.
    exclude_ignore_attributes : bool
        Whether to also exclude the features named in
        ``self.ignore_attributes``.
    exclude_row_id_attribute : bool
        Whether to also exclude the feature named by
        ``self.row_id_attribute``.

    Returns
    -------
    result : list
        Positions (feature index minus the number of excluded features
        preceding it) of the non-excluded features whose data type equals
        ``data_type``.

    Raises
    ------
    ValueError
        If ``data_type`` is not a legal data type.
    TypeError
        If ``exclude`` or ``self.ignore_attributes`` is not a list, or if
        ``self.row_id_attribute`` is not a string.
    """
    # Explicit raises instead of ``assert``: assertions are removed when
    # Python runs with -O, which would silently disable the validation.
    if data_type not in OpenMLDataFeature.LEGAL_DATA_TYPES:
        raise ValueError("Illegal feature type requested")
    if (self.ignore_attributes is not None
            and not isinstance(self.ignore_attributes, list)):
        raise TypeError("ignore_attributes should be a list")
    # six.string_types also accepts ``unicode`` on Python 2.
    if (self.row_id_attribute is not None
            and not isinstance(self.row_id_attribute, six.string_types)):
        raise TypeError("row id attribute should be a str")
    if exclude is not None and not isinstance(exclude, list):
        raise TypeError("Exclude should be a list")

    # Collect the names of every feature that is treated as removed.
    to_exclude = []
    if exclude is not None:
        to_exclude.extend(exclude)
    if exclude_ignore_attributes and self.ignore_attributes is not None:
        to_exclude.extend(self.ignore_attributes)
    if exclude_row_id_attribute and self.row_id_attribute is not None:
        to_exclude.append(self.row_id_attribute)

    result = []
    offset = 0  # number of excluded features seen so far
    # Visit the features in ascending index order: the offset bookkeeping
    # is only correct in that order, and plain dict iteration order is not
    # guaranteed on older Python versions.
    for idx in sorted(self.features):
        feature = self.features[idx]
        if feature.name in to_exclude:
            offset += 1
        elif feature.data_type == data_type:
            result.append(idx - offset)
    return result
375+
301376 def publish (self ):
302377 """Publish the dataset on the OpenML server.
303378
@@ -349,8 +424,8 @@ def _to_xml(self):
349424
def _data_features_supported(self):
    """Tell whether every known feature is numeric or nominal.

    Returns True when ``self.features`` is None (nothing to check) or when
    each feature's ``data_type`` is 'numeric' or 'nominal'; returns False
    as soon as a feature with any other data type is found.
    """
    if self.features is None:
        return True
    supported_types = ['numeric', 'nominal']
    return all(self.features[key].data_type in supported_types
               for key in self.features)
0 commit comments