Skip to content

Commit a89ec70

Browse files
committed
Fix bug in dataset.py; fix test about split
1 parent 9c6950f commit a89ec70

4 files changed

Lines changed: 7 additions & 6 deletions

File tree

openml/entities/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def get_dataset(self, target=None, include_row_id=False,
134134
raise ValueError("Cannot find a ndarray file for dataset %s at"
135135
"location %s " % (self.name, path))
136136
else:
137-
with open(path) as fh:
137+
with open(path, "rb") as fh:
138138
data, categorical, attribute_names = pickle.load(fh)
139139

140140
to_exclude = []

openml/entities/split.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,11 @@ def __eq__(self, other):
4343
return False
4444
else:
4545
for fold in self.split[repetition]:
46-
if all(self.split[repetition][fold][0] != \
47-
other.split[repetition][fold][0]) and \
48-
all(self.split[repetition][fold][1] != \
49-
other.split[repetition][fold][1]):
46+
if np.all(self.split[repetition][fold].test != \
47+
other.split[repetition][fold].test)\
48+
and \
49+
np.all(self.split[repetition][fold].train
50+
!= other.split[repetition][fold].train):
5051
return False
5152
return True
5253

tests/entities/test_split.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_eq(self):
4343
self.assertNotEqual(split, split2)
4444

4545
split2 = OpenMLSplit.from_arff_file(self.arff_filename)
46-
split2.split[0][0] = (np.zeros((80)), np.zeros((9)))
46+
split2.split[0][0] = Split(np.zeros((80)), np.zeros((9)))
4747
self.assertNotEqual(split, split2)
4848

4949
def test_from_arff_file(self):

tests/entities/tmp.pkl

Whitespace-only changes.

0 commit comments

Comments
 (0)