Skip to content

Commit e4e385b

Browse files
PGijsbers authored and mfeurer committed
Fix59 (#683)
* Start method description. * Include version in listing. Refactor number parsing. * Towards retrieving by name. * Finalize _name_to_id. * Adapt get_dataset(s). * Address feedback. * Add two unit tests for retrieving by name. Extract shared code to new function. * Unit tests name to id. * Add test get_dataset_by_name * flake8
1 parent c7db122 commit e4e385b

2 files changed

Lines changed: 185 additions & 84 deletions

File tree

openml/datasets/functions.py

Lines changed: 79 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import io
22
import os
33
import re
4-
from typing import List, Dict, Union
4+
from typing import List, Dict, Union, Optional
55

66
import numpy as np
77
import arff
@@ -247,19 +247,20 @@ def __list_datasets(api_call):
247247

248248
datasets = dict()
249249
for dataset_ in datasets_dict['oml:data']['oml:dataset']:
250-
did = int(dataset_['oml:did'])
251-
dataset = {'did': did,
252-
'name': dataset_['oml:name'],
253-
'format': dataset_['oml:format'],
254-
'status': dataset_['oml:status']}
250+
ignore_attributes = ['oml:file_id', 'oml:quality']
251+
dataset = {k.replace('oml:', ''): v
252+
for (k, v) in dataset_.items()
253+
if k not in ignore_attributes}
254+
dataset['did'] = int(dataset['did'])
255+
dataset['version'] = int(dataset['version'])
255256

256257
# The number of qualities can range from 0 to infinity
257258
for quality in dataset_.get('oml:quality', list()):
258-
quality['#text'] = float(quality['#text'])
259-
if abs(int(quality['#text']) - quality['#text']) < 0.0000001:
260-
quality['#text'] = int(quality['#text'])
261-
dataset[quality['@name']] = quality['#text']
262-
datasets[did] = dataset
259+
try:
260+
dataset[quality['@name']] = int(quality['#text'])
261+
except ValueError:
262+
dataset[quality['@name']] = float(quality['#text'])
263+
datasets[dataset['did']] = dataset
263264

264265
return datasets
265266

@@ -298,6 +299,47 @@ def check_datasets_active(dataset_ids: List[int]) -> Dict[int, bool]:
298299
return active
299300

300301

302+
def _name_to_id(
303+
dataset_name: str,
304+
version: Optional[int] = None,
305+
error_if_multiple: bool = False
306+
) -> int:
307+
""" Attempt to find the dataset id of the dataset with the given name.
308+
309+
If multiple datasets with the name exist, and ``error_if_multiple`` is ``False``,
310+
then return the least recent still active dataset.
311+
312+
Raises an error if no dataset with the name is found.
313+
Raises an error if a version is specified but it could not be found.
314+
315+
Parameters
316+
----------
317+
dataset_name : str
318+
The name of the dataset for which to find its id.
319+
version : int
320+
Version to retrieve. If not specified, the oldest active version is returned.
321+
error_if_multiple : bool (default=False)
322+
If `False`, if multiple datasets match, return the least recent active dataset.
323+
If `True`, if multiple datasets match, raise an error.
324+
325+
Returns
326+
-------
327+
int
328+
The id of the dataset.
329+
"""
330+
status = None if version is not None else 'active'
331+
candidates = list_datasets(data_name=dataset_name, status=status, data_version=version)
332+
if error_if_multiple and len(candidates) > 1:
333+
raise ValueError("Multiple active datasets exist with name {}".format(dataset_name))
334+
if len(candidates) == 0:
335+
no_dataset_for_name = "No active datasets exist with name {}".format(dataset_name)
336+
and_version = " and version {}".format(version) if version is not None else ""
337+
raise RuntimeError(no_dataset_for_name + and_version)
338+
339+
# Dataset ids are chronological so we can just sort based on ids (instead of version)
340+
return sorted(candidates)[0]
341+
342+
301343
def get_datasets(
302344
dataset_ids: List[Union[str, int]],
303345
download_data: bool = True,
@@ -309,7 +351,8 @@ def get_datasets(
309351
Parameters
310352
----------
311353
dataset_ids : iterable
312-
Integers or strings representing dataset ids.
354+
Integers or strings representing dataset ids or dataset names.
355+
If dataset names are specified, the least recent still active dataset version is returned.
313356
download_data : bool, optional
314357
If True, also download the data file. Beware that some datasets are large and it might
315358
make the operation noticeably slower. Metadata is also still retrieved.
@@ -328,13 +371,23 @@ def get_datasets(
328371

329372

330373
@openml.utils.thread_safe_if_oslo_installed
331-
def get_dataset(dataset_id: Union[int, str], download_data: bool = True) -> OpenMLDataset:
374+
def get_dataset(
375+
dataset_id: Union[int, str],
376+
download_data: bool = True,
377+
version: int = None,
378+
error_if_multiple: bool = False
379+
) -> OpenMLDataset:
332380
""" Download the OpenML dataset representation, optionally also download actual data file.
333381
334382
This function is thread/multiprocessing safe.
335383
This function uses caching. A check will be performed to determine if the information has
336384
previously been downloaded, and if so be loaded from disk instead of retrieved from the server.
337385
386+
If dataset is retrieved by name, a version may be specified.
387+
If no version is specified and multiple versions of the dataset exist,
388+
the earliest version of the dataset that is still active will be returned.
389+
This scenario will raise an error instead if `error_if_multiple` is `True`.
390+
338391
Parameters
339392
----------
340393
dataset_id : int or str
@@ -344,16 +397,24 @@ def get_dataset(dataset_id: Union[int, str], download_data: bool = True) -> Open
344397
make the operation noticeably slower. Metadata is also still retrieved.
345398
If False, create the OpenMLDataset and only populate it with the metadata.
346399
The data may later be retrieved through the `OpenMLDataset.get_data` method.
400+
version : int, optional (default=None)
401+
Specifies the version if `dataset_id` is specified by name.
402+
If no version is specified, retrieve the least recent still active version.
403+
error_if_multiple : bool, optional (default=False)
404+
If `True` raise an error if multiple datasets are found with matching criteria.
347405
348406
Returns
349407
-------
350408
dataset : :class:`openml.OpenMLDataset`
351409
The downloaded dataset."""
352-
try:
353-
dataset_id = int(dataset_id)
354-
except (ValueError, TypeError):
355-
raise ValueError("Dataset ID is neither an Integer nor can be "
356-
"cast to an Integer.")
410+
if isinstance(dataset_id, str):
411+
try:
412+
dataset_id = int(dataset_id)
413+
except ValueError:
414+
dataset_id = _name_to_id(dataset_id, version, error_if_multiple) # type: ignore
415+
elif not isinstance(dataset_id, int):
416+
raise TypeError("`dataset_id` must be one of `str` or `int`, not {}."
417+
.format(type(dataset_id)))
357418

358419
did_cache_dir = _create_cache_directory_for_id(
359420
DATASETS_CACHE_DIR_NAME, dataset_id,

tests/test_datasets/test_dataset_functions.py

Lines changed: 106 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -219,70 +219,120 @@ def test_check_datasets_active(self):
219219
)
220220
openml.config.server = self.test_server
221221

222+
def _datasets_retrieved_successfully(self, dids, metadata_only=True):
223+
""" Checks that all files for the given dids have been downloaded.
224+
225+
This includes:
226+
- description
227+
- qualities
228+
- features
229+
- absence of data arff if metadata_only, else it must be present too.
230+
"""
231+
for did in dids:
232+
self.assertTrue(os.path.exists(os.path.join(
233+
openml.config.get_cache_directory(), "datasets", str(did), "description.xml")))
234+
self.assertTrue(os.path.exists(os.path.join(
235+
openml.config.get_cache_directory(), "datasets", str(did), "qualities.xml")))
236+
self.assertTrue(os.path.exists(os.path.join(
237+
openml.config.get_cache_directory(), "datasets", str(did), "features.xml")))
238+
239+
data_assert = self.assertFalse if metadata_only else self.assertTrue
240+
data_assert(os.path.exists(os.path.join(
241+
openml.config.get_cache_directory(), "datasets", str(did), "dataset.arff")))
242+
243+
def test__name_to_id_with_deactivated(self):
244+
""" Check that an activated dataset is returned if an earlier deactivated one exists. """
245+
openml.config.server = self.production_server
246+
# /d/1 was deactivated
247+
self.assertEqual(openml.datasets.functions._name_to_id('anneal'), 2)
248+
openml.config.server = self.test_server
249+
250+
def test__name_to_id_with_multiple_active(self):
251+
""" With multiple active datasets, retrieve the least recent active. """
252+
self.assertEqual(openml.datasets.functions._name_to_id('iris'), 128)
253+
254+
def test__name_to_id_with_version(self):
255+
""" With multiple active datasets, retrieve the least recent active. """
256+
self.assertEqual(openml.datasets.functions._name_to_id('iris', version=3), 151)
257+
258+
def test__name_to_id_with_multiple_active_error(self):
259+
""" With multiple active datasets, retrieve the least recent active. """
260+
self.assertRaisesRegex(
261+
ValueError,
262+
"Multiple active datasets exist with name iris",
263+
openml.datasets.functions._name_to_id,
264+
dataset_name='iris',
265+
error_if_multiple=True
266+
)
267+
268+
def test__name_to_id_name_does_not_exist(self):
269+
""" With multiple active datasets, retrieve the least recent active. """
270+
self.assertRaisesRegex(
271+
RuntimeError,
272+
"No active datasets exist with name does_not_exist",
273+
openml.datasets.functions._name_to_id,
274+
dataset_name='does_not_exist'
275+
)
276+
277+
def test__name_to_id_version_does_not_exist(self):
278+
""" With multiple active datasets, retrieve the least recent active. """
279+
self.assertRaisesRegex(
280+
RuntimeError,
281+
"No active datasets exist with name iris and version 100000",
282+
openml.datasets.functions._name_to_id,
283+
dataset_name='iris',
284+
version=100000
285+
)
286+
287+
def test_get_datasets_by_name(self):
288+
# did 1 and 2 on the test server:
289+
dids = ['anneal', 'kr-vs-kp']
290+
datasets = openml.datasets.get_datasets(dids, download_data=False)
291+
self.assertEqual(len(datasets), 2)
292+
self._datasets_retrieved_successfully([1, 2])
293+
294+
def test_get_datasets_by_mixed(self):
295+
# did 1 and 2 on the test server:
296+
dids = ['anneal', 2]
297+
datasets = openml.datasets.get_datasets(dids, download_data=False)
298+
self.assertEqual(len(datasets), 2)
299+
self._datasets_retrieved_successfully([1, 2])
300+
222301
def test_get_datasets(self):
223302
dids = [1, 2]
224303
datasets = openml.datasets.get_datasets(dids)
225304
self.assertEqual(len(datasets), 2)
226-
self.assertTrue(os.path.exists(os.path.join(
227-
openml.config.get_cache_directory(), "datasets", "1", "description.xml")))
228-
self.assertTrue(os.path.exists(os.path.join(
229-
openml.config.get_cache_directory(), "datasets", "2", "description.xml")))
230-
self.assertTrue(os.path.exists(os.path.join(
231-
openml.config.get_cache_directory(), "datasets", "1", "dataset.arff")))
232-
self.assertTrue(os.path.exists(os.path.join(
233-
openml.config.get_cache_directory(), "datasets", "2", "dataset.arff")))
234-
self.assertTrue(os.path.exists(os.path.join(
235-
openml.config.get_cache_directory(), "datasets", "1", "features.xml")))
236-
self.assertTrue(os.path.exists(os.path.join(
237-
openml.config.get_cache_directory(), "datasets", "2", "features.xml")))
238-
self.assertTrue(os.path.exists(os.path.join(
239-
openml.config.get_cache_directory(), "datasets", "1", "qualities.xml")))
240-
self.assertTrue(os.path.exists(os.path.join(
241-
openml.config.get_cache_directory(), "datasets", "2", "qualities.xml")))
305+
self._datasets_retrieved_successfully([1, 2], metadata_only=False)
242306

243307
def test_get_datasets_lazy(self):
244308
dids = [1, 2]
245309
datasets = openml.datasets.get_datasets(dids, download_data=False)
246310
self.assertEqual(len(datasets), 2)
247-
self.assertTrue(os.path.exists(os.path.join(
248-
openml.config.get_cache_directory(), "datasets", "1", "description.xml")))
249-
self.assertTrue(os.path.exists(os.path.join(
250-
openml.config.get_cache_directory(), "datasets", "2", "description.xml")))
251-
self.assertTrue(os.path.exists(os.path.join(
252-
openml.config.get_cache_directory(), "datasets", "1", "features.xml")))
253-
self.assertTrue(os.path.exists(os.path.join(
254-
openml.config.get_cache_directory(), "datasets", "2", "features.xml")))
255-
self.assertTrue(os.path.exists(os.path.join(
256-
openml.config.get_cache_directory(), "datasets", "1", "qualities.xml")))
257-
self.assertTrue(os.path.exists(os.path.join(
258-
openml.config.get_cache_directory(), "datasets", "2", "qualities.xml")))
259-
260-
self.assertFalse(os.path.exists(os.path.join(
261-
openml.config.get_cache_directory(), "datasets", "1", "dataset.arff")))
262-
self.assertFalse(os.path.exists(os.path.join(
263-
openml.config.get_cache_directory(), "datasets", "2", "dataset.arff")))
311+
self._datasets_retrieved_successfully([1, 2], metadata_only=True)
264312

265313
datasets[0].get_data()
266-
self.assertTrue(os.path.exists(os.path.join(
267-
openml.config.get_cache_directory(), "datasets", "1", "dataset.arff")))
268-
269314
datasets[1].get_data()
270-
self.assertTrue(os.path.exists(os.path.join(
271-
openml.config.get_cache_directory(), "datasets", "2", "dataset.arff")))
315+
self._datasets_retrieved_successfully([1, 2], metadata_only=False)
316+
317+
def test_get_dataset_by_name(self):
318+
dataset = openml.datasets.get_dataset('anneal')
319+
self.assertEqual(type(dataset), OpenMLDataset)
320+
self.assertEqual(dataset.dataset_id, 1)
321+
self._datasets_retrieved_successfully([1], metadata_only=False)
322+
323+
self.assertGreater(len(dataset.features), 1)
324+
self.assertGreater(len(dataset.qualities), 4)
325+
326+
# Issue324 Properly handle private datasets when trying to access them
327+
openml.config.server = self.production_server
328+
self.assertRaises(OpenMLPrivateDatasetError, openml.datasets.get_dataset, 45)
272329

273330
def test_get_dataset(self):
274331
# This is the only non-lazy load to ensure default behaviour works.
275332
dataset = openml.datasets.get_dataset(1)
276333
self.assertEqual(type(dataset), OpenMLDataset)
277334
self.assertEqual(dataset.name, 'anneal')
278-
self.assertTrue(os.path.exists(os.path.join(
279-
openml.config.get_cache_directory(), "datasets", "1", "description.xml")))
280-
self.assertTrue(os.path.exists(os.path.join(
281-
openml.config.get_cache_directory(), "datasets", "1", "dataset.arff")))
282-
self.assertTrue(os.path.exists(os.path.join(
283-
openml.config.get_cache_directory(), "datasets", "1", "features.xml")))
284-
self.assertTrue(os.path.exists(os.path.join(
285-
openml.config.get_cache_directory(), "datasets", "1", "qualities.xml")))
335+
self._datasets_retrieved_successfully([1], metadata_only=False)
286336

287337
self.assertGreater(len(dataset.features), 1)
288338
self.assertGreater(len(dataset.qualities), 4)
@@ -295,22 +345,13 @@ def test_get_dataset_lazy(self):
295345
dataset = openml.datasets.get_dataset(1, download_data=False)
296346
self.assertEqual(type(dataset), OpenMLDataset)
297347
self.assertEqual(dataset.name, 'anneal')
298-
self.assertTrue(os.path.exists(os.path.join(
299-
openml.config.get_cache_directory(), "datasets", "1", "description.xml")))
300-
self.assertTrue(os.path.exists(os.path.join(
301-
openml.config.get_cache_directory(), "datasets", "1", "features.xml")))
302-
self.assertTrue(os.path.exists(os.path.join(
303-
openml.config.get_cache_directory(), "datasets", "1", "qualities.xml")))
304-
305-
self.assertFalse(os.path.exists(os.path.join(
306-
openml.config.get_cache_directory(), "datasets", "1", "dataset.arff")))
348+
self._datasets_retrieved_successfully([1], metadata_only=True)
307349

308350
self.assertGreater(len(dataset.features), 1)
309351
self.assertGreater(len(dataset.qualities), 4)
310352

311353
dataset.get_data()
312-
self.assertTrue(os.path.exists(os.path.join(
313-
openml.config.get_cache_directory(), "datasets", "1", "dataset.arff")))
354+
self._datasets_retrieved_successfully([1], metadata_only=False)
314355

315356
# Issue324 Properly handle private datasets when trying to access them
316357
openml.config.server = self.production_server
@@ -321,27 +362,26 @@ def test_get_dataset_lazy_all_functions(self):
321362
dataset = openml.datasets.get_dataset(1, download_data=False)
322363
# We only tests functions as general integrity is tested by test_get_dataset_lazy
323364

365+
def ensure_absence_of_real_data():
366+
self.assertFalse(os.path.exists(os.path.join(
367+
openml.config.get_cache_directory(), "datasets", "1", "dataset.arff")))
368+
324369
tag = 'test_lazy_tag_%d' % random.randint(1, 1000000)
325370
dataset.push_tag(tag)
326-
self.assertFalse(os.path.exists(os.path.join(
327-
openml.config.get_cache_directory(), "datasets", "1", "dataset.arff")))
371+
ensure_absence_of_real_data()
328372

329373
dataset.remove_tag(tag)
330-
self.assertFalse(os.path.exists(os.path.join(
331-
openml.config.get_cache_directory(), "datasets", "1", "dataset.arff")))
374+
ensure_absence_of_real_data()
332375

333376
nominal_indices = dataset.get_features_by_type('nominal')
334-
self.assertFalse(os.path.exists(os.path.join(
335-
openml.config.get_cache_directory(), "datasets", "1", "dataset.arff")))
336377
correct = [0, 1, 2, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
337378
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 35, 36, 37, 38]
338379
self.assertEqual(nominal_indices, correct)
380+
ensure_absence_of_real_data()
339381

340382
classes = dataset.retrieve_class_labels()
341383
self.assertEqual(classes, ['1', '2', '3', '4', '5', 'U'])
342-
343-
self.assertFalse(os.path.exists(os.path.join(
344-
openml.config.get_cache_directory(), "datasets", "1", "dataset.arff")))
384+
ensure_absence_of_real_data()
345385

346386
def test_get_dataset_sparse(self):
347387
dataset = openml.datasets.get_dataset(102, download_data=False)

0 commit comments

Comments
 (0)