Skip to content

Commit 4257c48

Browse files
Neeratyoymfeurer
authored andcommitted
Removing dependency on scipy.io.arff (#693)
* Removing dependency on scipy arff * Cleaning code * Loading arff as generator object * Removing redundant decode * PEP8
1 parent eec86a9 commit 4257c48

1 file changed

Lines changed: 13 additions & 11 deletions

File tree

openml/tasks/split.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pickle
44

55
import numpy as np
6-
import scipy.io.arff
6+
import arff
77

88

99
Split = namedtuple("Split", ["train", "test"])
@@ -77,20 +77,22 @@ def _from_arff_file(cls, filename: str) -> 'OpenMLSplit':
7777
raise FileNotFoundError(
7878
'Split arff %s does not exist!' % filename
7979
)
80-
splits, meta = scipy.io.arff.loadarff(filename)
81-
name = meta.name
80+
file_data = arff.load(open(filename), return_type=arff.DENSE_GEN)
81+
splits = file_data['data']
82+
name = file_data['relation']
83+
attrnames = [attr[0] for attr in file_data['attributes']]
8284

8385
repetitions = OrderedDict()
8486

85-
type_idx = meta._attrnames.index('type')
86-
rowid_idx = meta._attrnames.index('rowid')
87-
repeat_idx = meta._attrnames.index('repeat')
88-
fold_idx = meta._attrnames.index('fold')
87+
type_idx = attrnames.index('type')
88+
rowid_idx = attrnames.index('rowid')
89+
repeat_idx = attrnames.index('repeat')
90+
fold_idx = attrnames.index('fold')
8991
sample_idx = (
90-
meta._attrnames.index('sample')
91-
if 'sample' in meta._attrnames
92+
attrnames.index('sample')
93+
if 'sample' in attrnames
9294
else None
93-
) # can be None
95+
)
9496

9597
for line in splits:
9698
# A line looks like type, rowid, repeat, fold
@@ -108,7 +110,7 @@ def _from_arff_file(cls, filename: str) -> 'OpenMLSplit':
108110
repetitions[repetition][fold][sample] = ([], [])
109111
split = repetitions[repetition][fold][sample]
110112

111-
type_ = line[type_idx].decode('utf-8')
113+
type_ = line[type_idx]
112114
if type_ == 'TRAIN':
113115
split[0].append(line[rowid_idx])
114116
elif type_ == 'TEST':

0 commit comments

Comments
 (0)