33import pickle
44
55import numpy as np
6- import scipy . io . arff
6+ import arff
77
88
99Split = namedtuple ("Split" , ["train" , "test" ])
@@ -77,20 +77,22 @@ def _from_arff_file(cls, filename: str) -> 'OpenMLSplit':
7777 raise FileNotFoundError (
7878 'Split arff %s does not exist!' % filename
7979 )
80- splits , meta = scipy .io .arff .loadarff (filename )
81- name = meta .name
80+ file_data = arff .load (open (filename ), return_type = arff .DENSE_GEN )
81+ splits = file_data ['data' ]
82+ name = file_data ['relation' ]
83+ attrnames = [attr [0 ] for attr in file_data ['attributes' ]]
8284
8385 repetitions = OrderedDict ()
8486
85- type_idx = meta . _attrnames .index ('type' )
86- rowid_idx = meta . _attrnames .index ('rowid' )
87- repeat_idx = meta . _attrnames .index ('repeat' )
88- fold_idx = meta . _attrnames .index ('fold' )
87+ type_idx = attrnames .index ('type' )
88+ rowid_idx = attrnames .index ('rowid' )
89+ repeat_idx = attrnames .index ('repeat' )
90+ fold_idx = attrnames .index ('fold' )
8991 sample_idx = (
90- meta . _attrnames .index ('sample' )
91- if 'sample' in meta . _attrnames
92+ attrnames .index ('sample' )
93+ if 'sample' in attrnames
9294 else None
93- ) # can be None
95+ )
9496
9597 for line in splits :
9698 # A line looks like type, rowid, repeat, fold
@@ -108,7 +110,7 @@ def _from_arff_file(cls, filename: str) -> 'OpenMLSplit':
108110 repetitions [repetition ][fold ][sample ] = ([], [])
109111 split = repetitions [repetition ][fold ][sample ]
110112
111- type_ = line [type_idx ]. decode ( 'utf-8' )
113+ type_ = line [type_idx ]
112114 if type_ == 'TRAIN' :
113115 split [0 ].append (line [rowid_idx ])
114116 elif type_ == 'TEST' :
0 commit comments