Skip to content

Commit c6f85b6

Browse files
committed
Remove argument whose value can be inferred from the data
1 parent 4181c4a commit c6f85b6

4 files changed

Lines changed: 26 additions & 34 deletions

File tree

examples/OpenML_Tutorial.ipynb

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -841,7 +841,6 @@
841841
"source": [
842842
"X, y, attribute_names = dataset.get_data(\n",
843843
" target=dataset.default_target_attribute,\n",
844-
" target_dtype=int,\n",
845844
" return_attribute_names=True,\n",
846845
")\n",
847846
"eeg = pd.DataFrame(X, columns=attribute_names)\n",
@@ -932,10 +931,7 @@
932931
"from sklearn import neighbors\n",
933932
"\n",
934933
"dataset = oml.datasets.get_dataset(1471)\n",
935-
"X, y = dataset.get_data(\n",
936-
" target=dataset.default_target_attribute,\n",
937-
" target_dtype=int,\n",
938-
")\n",
934+
"X, y = dataset.get_data(target=dataset.default_target_attribute)\n",
939935
"clf = neighbors.KNeighborsClassifier(n_neighbors=1)\n",
940936
"clf.fit(X, y)"
941937
]
@@ -992,8 +988,8 @@
992988
"dataset = oml.datasets.get_dataset(10)\n",
993989
"X, y, categorical = dataset.get_data(\n",
994990
" target=dataset.default_target_attribute,\n",
995-
" target_dtype=int,\n",
996-
" return_categorical_indicator=True)\n",
991+
" return_categorical_indicator=True,\n",
992+
")\n",
997993
"print(\"Categorical features: %s\" % categorical)\n",
998994
"enc = preprocessing.OneHotEncoder(categorical_features=categorical)\n",
999995
"X = enc.fit_transform(X)\n",

openml/datasets/dataset.py

Lines changed: 19 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -184,10 +184,12 @@ def decode_arff(fh):
184184
with io.open(filename, encoding='utf8') as fh:
185185
return decode_arff(fh)
186186

187-
def get_data(self, target=None, target_dtype=None, include_row_id=False,
187+
def get_data(self, target=None,
188+
include_row_id=False,
188189
include_ignore_attributes=False,
189190
return_categorical_indicator=False,
190-
return_attribute_names=False):
191+
return_attribute_names=False
192+
):
191193
"""Returns dataset content as numpy arrays / sparse matrices.
192194
193195
Parameters
@@ -225,7 +227,10 @@ def get_data(self, target=None, target_dtype=None, include_row_id=False,
225227
if not self.ignore_attributes:
226228
pass
227229
else:
228-
to_exclude.extend(self.ignore_attributes)
230+
if isinstance(self.ignore_attributes, six.string_types):
231+
to_exclude.append(self.ignore_attributes)
232+
else:
233+
to_exclude.extend(self.ignore_attributes)
229234

230235
if len(to_exclude) > 0:
231236
logger.info("Going to remove the following attributes:"
@@ -242,14 +247,19 @@ def get_data(self, target=None, target_dtype=None, include_row_id=False,
242247
else:
243248
if isinstance(target, six.string_types):
244249
target = [target]
245-
legal_target_types = (int, float, np.float32, np.float64)
246-
if target_dtype not in legal_target_types:
247-
raise ValueError(
248-
"%s is not a legal target type. Legal target types are %s" %
249-
(target_dtype, legal_target_types)
250-
)
251250
targets = np.array([True if column in target else False
252251
for column in attribute_names])
252+
if np.sum(targets) > 1:
253+
raise NotImplementedError(
254+
"Number of requested targets %d is not implemented." %
255+
np.sum(targets)
256+
)
257+
target_categorical = [
258+
cat for cat, column in
259+
six.moves.zip(categorical, attribute_names)
260+
if column in target
261+
]
262+
target_dtype = int if target_categorical[0] else float
253263

254264
try:
255265
x = data[:, ~targets]

openml/tasks/task.py

Lines changed: 2 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -44,17 +44,9 @@ def get_X_and_y(self):
4444
4545
"""
4646
dataset = self.get_dataset()
47-
# Replace with retrieve from cache
48-
if self.task_type_id == 1: # Supervised classification
49-
target_dtype = int
50-
elif self.task_type_id == 2: # Supervised regression
51-
target_dtype = float
52-
elif self.task_type_id == 3: # Learning curves task for classification
53-
target_dtype = int
54-
else:
47+
if self.task_type_id not in (1, 2, 3):
5548
raise NotImplementedError(self.task_type)
56-
X_and_y = dataset.get_data(target=self.target_name,
57-
target_dtype=target_dtype)
49+
X_and_y = dataset.get_data(target=self.target_name)
5850
return X_and_y
5951

6052
def get_train_test_split_indices(self, fold=0, repeat=0, sample=0):

tests/test_datasets/test_dataset.py

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -47,14 +47,13 @@ def test_get_data_with_rowid(self):
4747
self.assertEqual(len(categorical), 38)
4848

4949
def test_get_data_with_target(self):
50-
X, y = self.dataset.get_data(target="class", target_dtype=int)
50+
X, y = self.dataset.get_data(target="class")
5151
self.assertIsInstance(X, np.ndarray)
5252
self.assertEqual(X.dtype, np.float32)
5353
self.assertIn(y.dtype, [np.int32, np.int64])
5454
self.assertEqual(X.shape, (898, 38))
5555
X, y, attribute_names = self.dataset.get_data(
5656
target="class",
57-
target_dtype=int,
5857
return_attribute_names=True
5958
)
6059
self.assertEqual(len(attribute_names), 38)
@@ -66,7 +65,6 @@ def test_get_data_rowid_and_ignore_and_target(self):
6665
self.dataset.row_id_attribute = ["hardness"]
6766
X, y = self.dataset.get_data(
6867
target="class",
69-
target_dtype=int,
7068
include_row_id=False,
7169
include_ignore_attributes=False
7270
)
@@ -75,7 +73,6 @@ def test_get_data_rowid_and_ignore_and_target(self):
7573
self.assertEqual(X.shape, (898, 36))
7674
X, y, categorical = self.dataset.get_data(
7775
target="class",
78-
target_dtype=int,
7976
return_categorical_indicator=True,
8077
)
8178
self.assertEqual(len(categorical), 36)
@@ -110,15 +107,14 @@ def setUp(self):
110107
self.sparse_dataset = openml.datasets.get_dataset(4136)
111108

112109
def test_get_sparse_dataset_with_target(self):
113-
X, y = self.sparse_dataset.get_data(target="class", target_dtype=int)
110+
X, y = self.sparse_dataset.get_data(target="class")
114111
self.assertTrue(sparse.issparse(X))
115112
self.assertEqual(X.dtype, np.float32)
116113
self.assertIsInstance(y, np.ndarray)
117114
self.assertIn(y.dtype, [np.int32, np.int64])
118115
self.assertEqual(X.shape, (600, 20000))
119116
X, y, attribute_names = self.sparse_dataset.get_data(
120117
target="class",
121-
target_dtype=int,
122118
return_attribute_names=True,
123119
)
124120
self.assertTrue(sparse.issparse(X))
@@ -184,7 +180,6 @@ def test_get_sparse_dataset_rowid_and_ignore_and_target(self):
184180
self.sparse_dataset.row_id_attribute = ["V512"]
185181
X, y = self.sparse_dataset.get_data(
186182
target="class",
187-
target_dtype=int,
188183
include_row_id=False,
189184
include_ignore_attributes=False,
190185
)
@@ -194,7 +189,6 @@ def test_get_sparse_dataset_rowid_and_ignore_and_target(self):
194189
self.assertEqual(X.shape, (600, 19998))
195190
X, y, categorical = self.sparse_dataset.get_data(
196191
target="class",
197-
target_dtype=int,
198192
return_categorical_indicator=True,
199193
)
200194
self.assertTrue(sparse.issparse(X))

0 commit comments

Comments (0)