Skip to content

Commit cfa5f2a

Browse files
committed
Merge branch 'develop' of github.com:openml/openml-python into task_tagging
2 parents 9c66b06 + 1fff169 commit cfa5f2a

7 files changed

Lines changed: 253 additions & 368 deletions

File tree

doc/usage.rst

Lines changed: 159 additions & 327 deletions
Large diffs are not rendered by default.

examples/OpenML_Tutorial.ipynb

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@
2424
},
2525
{
2626
"cell_type": "raw",
27-
"metadata": {
28-
"collapsed": true
29-
},
27+
"metadata": {},
3028
"source": [
3129
"# Install OpenML (developer version)\n",
3230
"# 'pip install openml' coming up (october 2017) \n",
@@ -842,8 +840,9 @@
842840
],
843841
"source": [
844842
"X, y, attribute_names = dataset.get_data(\n",
845-
" target=dataset.default_target_attribute, \n",
846-
" return_attribute_names=True)\n",
843+
" target=dataset.default_target_attribute,\n",
844+
" return_attribute_names=True,\n",
845+
")\n",
847846
"eeg = pd.DataFrame(X, columns=attribute_names)\n",
848847
"eeg['class'] = y\n",
849848
"print(eeg[:10])"
@@ -989,7 +988,8 @@
989988
"dataset = oml.datasets.get_dataset(10)\n",
990989
"X, y, categorical = dataset.get_data(\n",
991990
" target=dataset.default_target_attribute,\n",
992-
" return_categorical_indicator=True)\n",
991+
" return_categorical_indicator=True,\n",
992+
")\n",
993993
"print(\"Categorical features: %s\" % categorical)\n",
994994
"enc = preprocessing.OneHotEncoder(categorical_features=categorical)\n",
995995
"X = enc.fit_transform(X)\n",
@@ -1547,7 +1547,7 @@
15471547
"name": "python",
15481548
"nbconvert_exporter": "python",
15491549
"pygments_lexer": "ipython3",
1550-
"version": "3.6.0"
1550+
"version": "3.6.1"
15511551
}
15521552
},
15531553
"nbformat": 4,

openml/datasets/dataset.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,7 @@ def __init__(self, dataset_id=None, name=None, version=None, description=None,
8787
raise ValueError('Data features not provided in right order')
8888
self.features[feature.index] = feature
8989

90-
if qualities is not None:
91-
self.qualities = {}
92-
for idx, xmlquality in enumerate(qualities['oml:quality']):
93-
name = xmlquality['oml:name']
94-
value = xmlquality['oml:value']
95-
self.qualities[name] = value
90+
self.qualities = _check_qualities(qualities)
9691

9792
if data_file is not None:
9893
if self._data_features_supported():
@@ -205,10 +200,12 @@ def decode_arff(fh):
205200
with io.open(filename, encoding='utf8') as fh:
206201
return decode_arff(fh)
207202

208-
def get_data(self, target=None, target_dtype=int, include_row_id=False,
203+
def get_data(self, target=None,
204+
include_row_id=False,
209205
include_ignore_attributes=False,
210206
return_categorical_indicator=False,
211-
return_attribute_names=False):
207+
return_attribute_names=False
208+
):
212209
"""Returns dataset content as numpy arrays / sparse matrices.
213210
214211
Parameters
@@ -246,7 +243,10 @@ def get_data(self, target=None, target_dtype=int, include_row_id=False,
246243
if not self.ignore_attributes:
247244
pass
248245
else:
249-
to_exclude.extend(self.ignore_attributes)
246+
if isinstance(self.ignore_attributes, six.string_types):
247+
to_exclude.append(self.ignore_attributes)
248+
else:
249+
to_exclude.extend(self.ignore_attributes)
250250

251251
if len(to_exclude) > 0:
252252
logger.info("Going to remove the following attributes:"
@@ -265,6 +265,17 @@ def get_data(self, target=None, target_dtype=int, include_row_id=False,
265265
target = [target]
266266
targets = np.array([True if column in target else False
267267
for column in attribute_names])
268+
if np.sum(targets) > 1:
269+
raise NotImplementedError(
270+
"Number of requested targets %d is not implemented." %
271+
np.sum(targets)
272+
)
273+
target_categorical = [
274+
cat for cat, column in
275+
six.moves.zip(categorical, attribute_names)
276+
if column in target
277+
]
278+
target_dtype = int if target_categorical[0] else float
268279

269280
try:
270281
x = data[:, ~targets]
@@ -442,3 +453,21 @@ def _data_features_supported(self):
442453
return False
443454
return True
444455
return True
456+
457+
458+
459+
def _check_qualities(qualities):
460+
if qualities is not None:
461+
qualities_ = {}
462+
for xmlquality in qualities:
463+
name = xmlquality['oml:name']
464+
if xmlquality['oml:value'] is None:
465+
value = float('NaN')
466+
elif xmlquality['oml:value'] == 'null':
467+
value = float('NaN')
468+
else:
469+
value = float(xmlquality['oml:value'])
470+
qualities_[name] = value
471+
return qualities_
472+
else:
473+
return None

openml/datasets/functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _get_cached_dataset_qualities(dataset_id):
116116
try:
117117
with io.open(qualities_file, encoding='utf8') as fh:
118118
qualities_xml = fh.read()
119-
return xmltodict.parse(qualities_xml)["oml:data_qualities"]
119+
return xmltodict.parse(qualities_xml)["oml:data_qualities"]['oml:quality']
120120
except (IOError, OSError):
121121
raise OpenMLCacheException("Dataset qualities for dataset id %d not "
122122
"cached" % dataset_id)
@@ -454,7 +454,7 @@ def _get_dataset_qualities(did_cache_dir, dataset_id):
454454
with io.open(qualities_file, "w", encoding='utf8') as fh:
455455
fh.write(qualities_xml)
456456

457-
qualities = xmltodict.parse(qualities_xml, force_list=('oml:quality',))['oml:data_qualities']
457+
qualities = xmltodict.parse(qualities_xml, force_list=('oml:quality',))['oml:data_qualities']['oml:quality']
458458

459459
return qualities
460460

openml/tasks/task.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,21 +36,17 @@ def get_dataset(self):
3636
return datasets.get_dataset(self.dataset_id)
3737

3838
def get_X_and_y(self):
39+
"""Get data associated with the current task.
40+
41+
Returns
42+
-------
43+
tuple - X and y
44+
45+
"""
3946
dataset = self.get_dataset()
40-
# Replace with retrieve from cache
41-
if self.task_type_id == 1:
42-
# if 'Supervised Classification'.lower() in self.task_type.lower():
43-
target_dtype = int
44-
# elif 'Supervised Regression'.lower() in self.task_type.lower():
45-
elif self.task_type_id == 2:
46-
target_dtype = float
47-
# elif ''.lower('Learning Curve') in self.task_type.lower():
48-
elif self.task_type_id == 3:
49-
target_dtype = int
50-
else:
47+
if self.task_type_id not in (1, 2, 3):
5148
raise NotImplementedError(self.task_type)
52-
X_and_y = dataset.get_data(target=self.target_name,
53-
target_dtype=target_dtype)
49+
X_and_y = dataset.get_data(target=self.target_name)
5450
return X_and_y
5551

5652
def get_train_test_split_indices(self, fold=0, repeat=0, sample=0):

tests/test_datasets/test_dataset.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,21 +54,28 @@ def test_get_data_with_target(self):
5454
self.assertIn(y.dtype, [np.int32, np.int64])
5555
self.assertEqual(X.shape, (898, 38))
5656
X, y, attribute_names = self.dataset.get_data(
57-
target="class", return_attribute_names=True)
57+
target="class",
58+
return_attribute_names=True
59+
)
5860
self.assertEqual(len(attribute_names), 38)
5961
self.assertNotIn("class", attribute_names)
6062
self.assertEqual(y.shape, (898, ))
6163

6264
def test_get_data_rowid_and_ignore_and_target(self):
6365
self.dataset.ignore_attributes = ["condition"]
6466
self.dataset.row_id_attribute = ["hardness"]
65-
X, y = self.dataset.get_data(target="class", include_row_id=False,
66-
include_ignore_attributes=False)
67+
X, y = self.dataset.get_data(
68+
target="class",
69+
include_row_id=False,
70+
include_ignore_attributes=False
71+
)
6772
self.assertEqual(X.dtype, np.float32)
6873
self.assertIn(y.dtype, [np.int32, np.int64])
6974
self.assertEqual(X.shape, (898, 36))
7075
X, y, categorical = self.dataset.get_data(
71-
target="class", return_categorical_indicator=True)
76+
target="class",
77+
return_categorical_indicator=True,
78+
)
7279
self.assertEqual(len(categorical), 36)
7380
self.assertListEqual(categorical, [True] * 3 + [False] + [True] * 2 + [
7481
False] + [True] * 23 + [False] * 3 + [True] * 3)
@@ -127,7 +134,9 @@ def test_get_sparse_dataset_with_target(self):
127134
self.assertIn(y.dtype, [np.int32, np.int64])
128135
self.assertEqual(X.shape, (600, 20000))
129136
X, y, attribute_names = self.sparse_dataset.get_data(
130-
target="class", return_attribute_names=True)
137+
target="class",
138+
return_attribute_names=True,
139+
)
131140
self.assertTrue(sparse.issparse(X))
132141
self.assertEqual(len(attribute_names), 20000)
133142
self.assertNotIn("class", attribute_names)
@@ -190,15 +199,34 @@ def test_get_sparse_dataset_rowid_and_ignore_and_target(self):
190199
self.sparse_dataset.ignore_attributes = ["V256"]
191200
self.sparse_dataset.row_id_attribute = ["V512"]
192201
X, y = self.sparse_dataset.get_data(
193-
target="class", include_row_id=False,
194-
include_ignore_attributes=False)
202+
target="class",
203+
include_row_id=False,
204+
include_ignore_attributes=False,
205+
)
195206
self.assertTrue(sparse.issparse(X))
196207
self.assertEqual(X.dtype, np.float32)
197208
self.assertIn(y.dtype, [np.int32, np.int64])
198209
self.assertEqual(X.shape, (600, 19998))
199210
X, y, categorical = self.sparse_dataset.get_data(
200-
target="class", return_categorical_indicator=True)
211+
target="class",
212+
return_categorical_indicator=True,
213+
)
201214
self.assertTrue(sparse.issparse(X))
202215
self.assertEqual(len(categorical), 19998)
203216
self.assertListEqual(categorical, [False] * 19998)
204217
self.assertEqual(y.shape, (600, ))
218+
219+
220+
class OpenMLDatasetQualityTest(TestBase):
221+
def test__check_qualities(self):
222+
qualities = [{'oml:name': 'a', 'oml:value': '0.5'}]
223+
qualities = openml.datasets.dataset._check_qualities(qualities)
224+
self.assertEqual(qualities['a'], 0.5)
225+
226+
qualities = [{'oml:name': 'a', 'oml:value': 'null'}]
227+
qualities = openml.datasets.dataset._check_qualities(qualities)
228+
self.assertNotEqual(qualities['a'], qualities['a'])
229+
230+
qualities = [{'oml:name': 'a', 'oml:value': None}]
231+
qualities = openml.datasets.dataset._check_qualities(qualities)
232+
self.assertNotEqual(qualities['a'], qualities['a'])

tests/test_datasets/test_dataset_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test__get_cached_dataset(self, ):
7373
self.assertIsInstance(dataset, OpenMLDataset)
7474
self.assertTrue(len(dataset.features) > 0)
7575
self.assertTrue(len(dataset.features) == len(features['oml:feature']))
76-
self.assertTrue(len(dataset.qualities) == len(qualities['oml:quality']))
76+
self.assertTrue(len(dataset.qualities) == len(qualities))
7777

7878
def test_get_cached_dataset_description(self):
7979
openml.config.set_cache_directory(self.static_cache_dir)
@@ -209,7 +209,7 @@ def test__get_dataset_features(self):
209209
def test__get_dataset_qualities(self):
210210
# Only a smoke check
211211
qualities = _get_dataset_qualities(self.workdir, 2)
212-
self.assertIsInstance(qualities, dict)
212+
self.assertIsInstance(qualities, list)
213213

214214
def test_deletion_of_cache_dir(self):
215215
# Simple removal

0 commit comments

Comments
 (0)