Skip to content

Commit d55a7ad

Browse files
committed
Merge pull request #24 from openml/fix/regression
FIX allow target to be a float
2 parents 93e0535 + b48fa78 commit d55a7ad

2 files changed

Lines changed: 11 additions & 5 deletions

File tree

openml/entities/dataset.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,7 @@ def decode_arff(fh):
122122
return decode_arff(fh)
123123

124124
############################################################################
125-
# pandas related stuff...
126-
def get_dataset(self, target=None, include_row_id=False,
125+
def get_dataset(self, target=None, target_dtype=int, include_row_id=False,
127126
include_ignore_attributes=False,
128127
return_categorical_indicator=False,
129128
return_attribute_names=False):
@@ -176,7 +175,7 @@ def get_dataset(self, target=None, include_row_id=False,
176175

177176
try:
178177
x = data[:,~targets]
179-
y = data[:,targets].astype(np.int32)
178+
y = data[:,targets].astype(target_dtype)
180179

181180
if len(y.shape) == 2 and y.shape[1] == 1:
182181
y = y[:,0]
@@ -191,7 +190,7 @@ def get_dataset(self, target=None, include_row_id=False,
191190
raise e
192191

193192
if scipy.sparse.issparse(y):
194-
y = np.asarray(y.todense()).astype(np.int32).flatten()
193+
y = np.asarray(y.todense()).astype(target_dtype).flatten()
195194

196195
rval.append(x)
197196
rval.append(y)

openml/entities/task.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,14 @@ def get_dataset(self):
4444
def get_X_and_Y(self):
4545
dataset = self.get_dataset()
4646
# Replace with retrieve from cache
47-
X_and_Y = dataset.get_dataset(target=self.target_feature)
47+
if 'Supervised Classification'.lower() in self.task_type.lower():
48+
target_dtype = int
49+
elif 'Supervised Regression'.lower() in self.task_type.lower():
50+
target_dtype = float
51+
else:
52+
raise NotImplementedError(self.task_type)
53+
X_and_Y = dataset.get_dataset(target=self.target_feature,
54+
target_dtype=target_dtype)
4855
return X_and_Y
4956

5057
def evaluate(self, algo):

0 commit comments

Comments
 (0)