Skip to content

Commit 72f131a

Browse files
PGijsbers authored and mfeurer committed
[MRG] Fix402 (#677)
* Make more explicit splitting. * Always return four values. * Update function signature. Update dataformat to expected 0.9 behavior. * Stashing changes. WIP update tests. * PEP8 says not to test boolean values with 'is'. * Fix ignore_row_attribute test. * Streamline if-else flow for excluding attributes. * Update doc to reflect multiple targets is not supported. * Updated all tests. * Updated other calls. * Fix sparse tests. * Flake8. * Feedback mfeurer. * Parameter not Optional.
1 parent 813daeb commit 72f131a

6 files changed

Lines changed: 163 additions & 223 deletions

File tree

examples/datasets_tutorial.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,9 @@
6060
# controlled with the parameter ``dataset_format`` which can be either 'array'
6161
# (default) or 'dataframe'. Let's first build our dataset from a NumPy array
6262
# and manually create a dataframe.
63-
X, y, attribute_names = dataset.get_data(
63+
X, y, categorical_indicator, attribute_names = dataset.get_data(
6464
dataset_format='array',
65-
target=dataset.default_target_attribute,
66-
return_attribute_names=True,
65+
target=dataset.default_target_attribute
6766
)
6867
eeg = pd.DataFrame(X, columns=attribute_names)
6968
eeg['class'] = y
@@ -72,8 +71,10 @@
7271
############################################################################
7372
# Instead of manually creating the dataframe, you can already request a
7473
# dataframe with the correct dtypes.
75-
X, y = dataset.get_data(target=dataset.default_target_attribute,
76-
dataset_format='dataframe')
74+
X, y, categorical_indicator, attribute_names = dataset.get_data(
75+
target=dataset.default_target_attribute,
76+
dataset_format='dataframe'
77+
)
7778
print(X.head())
7879
print(X.info())
7980

examples/flows_and_runs_tutorial.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
openml.config.start_using_configuration_for_example()
2323
# NOTE: We are using dataset 68 from the test server: https://test.openml.org/d/68
2424
dataset = openml.datasets.get_dataset(68)
25-
X, y = dataset.get_data(
25+
X, y, categorical_indicator, attribute_names = dataset.get_data(
2626
dataset_format='array',
2727
target=dataset.default_target_attribute
2828
)
@@ -34,13 +34,12 @@
3434
#
3535
# * e.g. categorical features -> do feature encoding
3636
dataset = openml.datasets.get_dataset(17)
37-
X, y, categorical = dataset.get_data(
37+
X, y, categorical_indicator, attribute_names = dataset.get_data(
3838
dataset_format='array',
39-
target=dataset.default_target_attribute,
40-
return_categorical_indicator=True,
39+
target=dataset.default_target_attribute
4140
)
42-
print("Categorical features: %s" % categorical)
43-
enc = preprocessing.OneHotEncoder(categorical_features=categorical)
41+
print("Categorical features: {}".format(categorical_indicator))
42+
enc = preprocessing.OneHotEncoder(categorical_features=categorical_indicator)
4443
X = enc.fit_transform(X)
4544
clf.fit(X, y)
4645

openml/datasets/dataset.py

Lines changed: 34 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
import os
66
import pickle
7-
from typing import List, Optional, Union
7+
from typing import List, Optional, Union, Tuple, Iterable
88

99
import arff
1010
import numpy as np
@@ -419,29 +419,31 @@ def _download_data(self) -> None:
419419
from .functions import _get_dataset_arff
420420
self.data_file = _get_dataset_arff(self)
421421

422-
def get_data(self, target: Optional[Union[List[str], str]] = None,
423-
include_row_id: bool = False,
424-
include_ignore_attributes: bool = False,
425-
return_categorical_indicator: bool = False,
426-
return_attribute_names: bool = False,
427-
dataset_format: str = None):
422+
def get_data(
423+
self,
424+
target: Optional[Union[List[str], str]] = None,
425+
include_row_id: bool = False,
426+
include_ignore_attributes: bool = False,
427+
dataset_format: str = "dataframe",
428+
) -> Tuple[
429+
Union[np.ndarray, pd.DataFrame, scipy.sparse.csr_matrix],
430+
Optional[Union[np.ndarray, pd.DataFrame]],
431+
List[bool],
432+
List[str]
433+
]:
428434
""" Returns dataset content as dataframes or sparse matrices.
429435
430436
Parameters
431437
----------
432-
target : string, list of strings or None (default=None)
433-
Name of target column(s) to separate from the data.
438+
target : string, List[str] or None (default=None)
439+
Name of target column to separate from the data.
440+
Splitting multiple columns is currently not supported.
434441
include_row_id : boolean (default=False)
435442
Whether to include row ids in the returned dataset.
436443
include_ignore_attributes : boolean (default=False)
437444
Whether to include columns that are marked as "ignore"
438445
on the server in the dataset.
439-
return_categorical_indicator : boolean (default=False)
440-
Whether to return a boolean mask indicating which features are
441-
categorical.
442-
return_attribute_names : boolean (default=False)
443-
Whether to return attribute names.
444-
dataset_format : string, optional
446+
dataset_format : string (default='dataframe')
445447
The format of returned dataset.
446448
If ``array``, the returned dataset will be a NumPy array or a SciPy sparse matrix.
447449
If ``dataframe``, the returned dataset will be a Pandas DataFrame or SparseDataFrame.
@@ -450,22 +452,13 @@ def get_data(self, target: Optional[Union[List[str], str]] = None,
450452
-------
451453
X : ndarray, dataframe, or sparse matrix, shape (n_samples, n_columns)
452454
Dataset
453-
y : ndarray or series, shape (n_samples,)
454-
Target column(s). Only returned if target is not None.
455+
y : ndarray or pd.Series, shape (n_samples, ) or None
456+
Target column
455457
categorical_indicator : boolean ndarray
456458
Mask that indicate categorical features.
457-
Only returned if return_categorical_indicator is True.
458-
return_attribute_names : list of strings
459+
attribute_names : List[str]
459460
List of attribute names.
460-
Only returned if return_attribute_names is True.
461461
"""
462-
if dataset_format is None:
463-
warn('The default of "dataset_format" will change from "array" to'
464-
' "dataframe" in 0.9', FutureWarning)
465-
dataset_format = 'array'
466-
467-
rval = []
468-
469462
if self.data_pickle_file is None:
470463
if self.data_file is None:
471464
self._download_data()
@@ -480,23 +473,17 @@ def get_data(self, target: Optional[Union[List[str], str]] = None,
480473
data, categorical, attribute_names = pickle.load(fh)
481474

482475
to_exclude = []
483-
if include_row_id is False:
484-
if not self.row_id_attribute:
485-
pass
486-
else:
487-
if isinstance(self.row_id_attribute, str):
488-
to_exclude.append(self.row_id_attribute)
489-
else:
490-
to_exclude.extend(self.row_id_attribute)
491-
492-
if include_ignore_attributes is False:
493-
if not self.ignore_attributes:
494-
pass
495-
else:
496-
if isinstance(self.ignore_attributes, str):
497-
to_exclude.append(self.ignore_attributes)
498-
else:
499-
to_exclude.extend(self.ignore_attributes)
476+
if not include_row_id and self.row_id_attribute is not None:
477+
if isinstance(self.row_id_attribute, str):
478+
to_exclude.append(self.row_id_attribute)
479+
elif isinstance(self.row_id_attribute, Iterable):
480+
to_exclude.extend(self.row_id_attribute)
481+
482+
if not include_ignore_attributes and self.ignore_attributes is not None:
483+
if isinstance(self.ignore_attributes, str):
484+
to_exclude.append(self.ignore_attributes)
485+
elif isinstance(self.ignore_attributes, Iterable):
486+
to_exclude.extend(self.ignore_attributes)
500487

501488
if len(to_exclude) > 0:
502489
logger.info("Going to remove the following attributes:"
@@ -514,7 +501,7 @@ def get_data(self, target: Optional[Union[List[str], str]] = None,
514501
if target is None:
515502
data = self._convert_array_format(data, dataset_format,
516503
attribute_names)
517-
rval.append(data)
504+
targets = None
518505
else:
519506
if isinstance(target, str):
520507
if ',' in target:
@@ -552,19 +539,9 @@ def get_data(self, target: Optional[Union[List[str], str]] = None,
552539
y = y.squeeze()
553540
y = self._convert_array_format(y, dataset_format, attribute_names)
554541
y = y.astype(target_dtype) if dataset_format == 'array' else y
542+
data, targets = x, y
555543

556-
rval.append(x)
557-
rval.append(y)
558-
559-
if return_categorical_indicator:
560-
rval.append(categorical)
561-
if return_attribute_names:
562-
rval.append(attribute_names)
563-
564-
if len(rval) == 1:
565-
return rval[0]
566-
else:
567-
return rval
544+
return data, targets, categorical, attribute_names
568545

569546
def retrieve_class_labels(self, target_name: str = 'class') -> Union[None, List[str]]:
570547
"""Reads the datasets arff to determine the class-labels.

openml/tasks/task.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,10 @@ def __init__(
230230
def get_X_and_y(
231231
self,
232232
dataset_format: str = 'array',
233-
) -> Union[np.ndarray, pd.DataFrame, scipy.sparse.spmatrix]:
233+
) -> Tuple[
234+
Union[np.ndarray, pd.DataFrame, scipy.sparse.spmatrix],
235+
Union[np.ndarray, pd.Series]
236+
]:
234237
"""Get data associated with the current task.
235238
236239
Parameters
@@ -247,10 +250,10 @@ def get_X_and_y(
247250
dataset = self.get_dataset()
248251
if self.task_type_id not in (1, 2, 3):
249252
raise NotImplementedError(self.task_type)
250-
X_and_y = dataset.get_data(
253+
X, y, _, _ = dataset.get_data(
251254
dataset_format=dataset_format, target=self.target_name,
252255
)
253-
return X_and_y
256+
return X, y
254257

255258
def _to_dict(self) -> 'OrderedDict[str, OrderedDict]':
256259

@@ -393,10 +396,10 @@ def get_X(
393396
394397
"""
395398
dataset = self.get_dataset()
396-
X_and_y = dataset.get_data(
399+
data, *_ = dataset.get_data(
397400
dataset_format=dataset_format, target=None,
398401
)
399-
return X_and_y
402+
return data
400403

401404
def _to_dict(self) -> 'OrderedDict[str, OrderedDict]':
402405

0 commit comments

Comments (0)