44import logging
55import os
66import pickle
7- from typing import List , Optional , Union
7+ from typing import List , Optional , Union , Tuple , Iterable
88
99import arff
1010import numpy as np
@@ -419,29 +419,31 @@ def _download_data(self) -> None:
419419 from .functions import _get_dataset_arff
420420 self .data_file = _get_dataset_arff (self )
421421
422- def get_data (self , target : Optional [Union [List [str ], str ]] = None ,
423- include_row_id : bool = False ,
424- include_ignore_attributes : bool = False ,
425- return_categorical_indicator : bool = False ,
426- return_attribute_names : bool = False ,
427- dataset_format : str = None ):
422+ def get_data (
423+ self ,
424+ target : Optional [Union [List [str ], str ]] = None ,
425+ include_row_id : bool = False ,
426+ include_ignore_attributes : bool = False ,
427+ dataset_format : str = "dataframe" ,
428+ ) -> Tuple [
429+ Union [np .ndarray , pd .DataFrame , scipy .sparse .csr_matrix ],
430+ Optional [Union [np .ndarray , pd .DataFrame ]],
431+ List [bool ],
432+ List [str ]
433+ ]:
428434 """ Returns dataset content as dataframes or sparse matrices.
429435
430436 Parameters
431437 ----------
432- target : string, list of strings or None (default=None)
433- Name of target column(s) to separate from the data.
438+ target : string, List[str] or None (default=None)
439+ Name of target column to separate from the data.
440+ Splitting multiple columns is currently not supported.
434441 include_row_id : boolean (default=False)
435442 Whether to include row ids in the returned dataset.
436443 include_ignore_attributes : boolean (default=False)
437444 Whether to include columns that are marked as "ignore"
438445 on the server in the dataset.
439- return_categorical_indicator : boolean (default=False)
440- Whether to return a boolean mask indicating which features are
441- categorical.
442- return_attribute_names : boolean (default=False)
443- Whether to return attribute names.
444- dataset_format : string, optional
446+ dataset_format : string (default='dataframe')
445447 The format of returned dataset.
446448 If ``array``, the returned dataset will be a NumPy array or a SciPy sparse matrix.
447449 If ``dataframe``, the returned dataset will be a Pandas DataFrame or SparseDataFrame.
@@ -450,22 +452,13 @@ def get_data(self, target: Optional[Union[List[str], str]] = None,
450452 -------
451453 X : ndarray, dataframe, or sparse matrix, shape (n_samples, n_columns)
452454 Dataset
453- y : ndarray or series , shape (n_samples,)
454- Target column(s). Only returned if target is not None.
455+ y : ndarray or pd.Series, shape (n_samples,), or None
456+ Target column. ``None`` if ``target`` is None.
455457 categorical_indicator : boolean ndarray
456458 Mask that indicates which features are categorical.
457- Only returned if return_categorical_indicator is True.
458- return_attribute_names : list of strings
459+ attribute_names : List[str]
459460 List of attribute names.
460- Only returned if return_attribute_names is True.
461461 """
462- if dataset_format is None :
463- warn ('The default of "dataset_format" will change from "array" to'
464- ' "dataframe" in 0.9' , FutureWarning )
465- dataset_format = 'array'
466-
467- rval = []
468-
469462 if self .data_pickle_file is None :
470463 if self .data_file is None :
471464 self ._download_data ()
@@ -480,23 +473,17 @@ def get_data(self, target: Optional[Union[List[str], str]] = None,
480473 data , categorical , attribute_names = pickle .load (fh )
481474
482475 to_exclude = []
483- if include_row_id is False :
484- if not self .row_id_attribute :
485- pass
486- else :
487- if isinstance (self .row_id_attribute , str ):
488- to_exclude .append (self .row_id_attribute )
489- else :
490- to_exclude .extend (self .row_id_attribute )
491-
492- if include_ignore_attributes is False :
493- if not self .ignore_attributes :
494- pass
495- else :
496- if isinstance (self .ignore_attributes , str ):
497- to_exclude .append (self .ignore_attributes )
498- else :
499- to_exclude .extend (self .ignore_attributes )
476+ if not include_row_id and self .row_id_attribute is not None :
477+ if isinstance (self .row_id_attribute , str ):
478+ to_exclude .append (self .row_id_attribute )
479+ elif isinstance (self .row_id_attribute , Iterable ):
480+ to_exclude .extend (self .row_id_attribute )
481+
482+ if not include_ignore_attributes and self .ignore_attributes is not None :
483+ if isinstance (self .ignore_attributes , str ):
484+ to_exclude .append (self .ignore_attributes )
485+ elif isinstance (self .ignore_attributes , Iterable ):
486+ to_exclude .extend (self .ignore_attributes )
500487
501488 if len (to_exclude ) > 0 :
502489 logger .info ("Going to remove the following attributes:"
@@ -514,7 +501,7 @@ def get_data(self, target: Optional[Union[List[str], str]] = None,
514501 if target is None :
515502 data = self ._convert_array_format (data , dataset_format ,
516503 attribute_names )
517- rval . append ( data )
504+ targets = None
518505 else :
519506 if isinstance (target , str ):
520507 if ',' in target :
@@ -552,19 +539,9 @@ def get_data(self, target: Optional[Union[List[str], str]] = None,
552539 y = y .squeeze ()
553540 y = self ._convert_array_format (y , dataset_format , attribute_names )
554541 y = y .astype (target_dtype ) if dataset_format == 'array' else y
542+ data , targets = x , y
555543
556- rval .append (x )
557- rval .append (y )
558-
559- if return_categorical_indicator :
560- rval .append (categorical )
561- if return_attribute_names :
562- rval .append (attribute_names )
563-
564- if len (rval ) == 1 :
565- return rval [0 ]
566- else :
567- return rval
544+ return data , targets , categorical , attribute_names
568545
569546 def retrieve_class_labels (self , target_name : str = 'class' ) -> Union [None , List [str ]]:
570547 """Reads the datasets arff to determine the class-labels.
0 commit comments