99
1010import numpy as np
1111import scipy .sparse
12+ from six .moves import cPickle as pickle
1213import xmltodict
1314
1415from .data_feature import OpenMLDataFeature
1516from ..exceptions import PyOpenMLError
16-
17- if sys .version_info [0 ] >= 3 :
18- import pickle
19- else :
20- try :
21- import cPickle as pickle
22- except :
23- import pickle
24-
25-
26- from ..util import is_string
2717from .._api_calls import _perform_api_call
2818
2919logger = logging .getLogger (__name__ )
@@ -49,7 +39,7 @@ def __init__(self, dataset_id=None, name=None, version=None, description=None,
4939 row_id_attribute = None , ignore_attribute = None ,
5040 version_label = None , citation = None , tag = None , visibility = None ,
5141 original_data_url = None , paper_url = None , update_comment = None ,
52- md5_checksum = None , data_file = None , features = None ):
42+ md5_checksum = None , data_file = None , features = None , qualities = None ):
5343 # Attributes received by querying the RESTful API
5444 self .dataset_id = int (dataset_id ) if dataset_id is not None else None
5545 self .name = name
@@ -84,6 +74,7 @@ def __init__(self, dataset_id=None, name=None, version=None, description=None,
8474 self .md5_cheksum = md5_checksum
8575 self .data_file = data_file
8676 self .features = None
77+ self .qualities = None
8778
8879 if features is not None :
8980 self .features = {}
@@ -97,6 +88,12 @@ def __init__(self, dataset_id=None, name=None, version=None, description=None,
9788 raise ValueError ('Data features not provided in right order' )
9889 self .features [feature .index ] = feature
9990
91+ if qualities is not None :
92+ self .qualities = {}
93+ for idx , xmlquality in enumerate (qualities ['oml:quality' ]):
94+ name = xmlquality ['oml:name' ]
95+ value = xmlquality ['oml:value' ]
96+ self .qualities [name ] = value
10097
10198 if data_file is not None :
10299 if self ._data_features_supported ():
@@ -219,7 +216,7 @@ def get_data(self, target=None, target_dtype=int, include_row_id=False,
219216 if not self .row_id_attribute :
220217 pass
221218 else :
222- if is_string (self .row_id_attribute ):
219+ if isinstance (self .row_id_attribute , six . string_types ):
223220 to_exclude .append (self .row_id_attribute )
224221 else :
225222 to_exclude .extend (self .row_id_attribute )
@@ -243,7 +240,7 @@ def get_data(self, target=None, target_dtype=int, include_row_id=False,
243240 if target is None :
244241 rval .append (data )
245242 else :
246- if is_string (target ):
243+ if isinstance (target , six . string_types ):
247244 target = [target ]
248245 targets = np .array ([True if column in target else False
249246 for column in attribute_names ])
0 commit comments