Skip to content

Commit 9ec141c

Browse files
committed
requested changes for pullrequest #213
1 parent f9bf4f2 commit 9ec141c

5 files changed

Lines changed: 43 additions & 45 deletions

File tree

openml/datasets/data_feature.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ class OpenMLDataFeature(object):
66
----------
77
index : int
88
The index of this feature
9-
name : string
9+
name : str
1010
Name of the feature
11-
data_type : string
11+
data_type : str
1212
can be nominal, numeric, string, date (corresponds to arff)
1313
nominal_values : list(str)
1414
list of the possible values, in case of nominal attribute
@@ -17,11 +17,14 @@ class OpenMLDataFeature(object):
1717
LEGAL_DATA_TYPES = ['nominal', 'numeric', 'string', 'date']
1818

1919
def __init__(self, index, name, data_type, nominal_values, number_missing_values):
20-
assert type(index) is int, "Index is of wrong datatype"
21-
assert data_type in self.LEGAL_DATA_TYPES, "data type should be in %s" %str(self.LEGAL_DATA_TYPES)
22-
if nominal_values is not None:
23-
assert type(nominal_values) is list, "Nominal_values is of wrong datatype"
24-
assert type(number_missing_values) is int, "number_missing_values is of wrong datatype"
20+
if type(index) != int:
21+
raise ValueError('Index is of wrong datatype')
22+
if data_type not in self.LEGAL_DATA_TYPES:
23+
raise ValueError('data type should be in %s, found: %s' %(str(self.LEGAL_DATA_TYPES),data_type))
24+
if nominal_values is not None and type(nominal_values) != list:
25+
raise ValueError('Nominal_values is of wrong datatype')
26+
if type(number_missing_values) != int:
27+
raise ValueError('number_missing_values is of wrong datatype')
2528

2629
self.index = index
2730
self.name = str(name)

openml/datasets/dataset.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ def __init__(self, dataset_id=None, name=None, version=None, description=None,
6969
self.ignore_attributes = [ignore_attribute]
7070
elif isinstance(ignore_attribute, list):
7171
self.ignore_attributes = ignore_attribute
72+
elif ignore_attribute is None:
73+
pass
74+
else:
75+
raise ValueError('wrong data type for ignore_attribute. Should be list. ')
7276
self.version_label = version_label
7377
self.citation = citation
7478
self.tag = tag
@@ -88,7 +92,8 @@ def __init__(self, dataset_id=None, name=None, version=None, description=None,
8892
xmlfeature['oml:data_type'],
8993
None, #todo add nominal values (currently not in database)
9094
int(xmlfeature['oml:number_of_missing_values']))
91-
assert idx == feature.index, "Data features not provided in right order"
95+
if idx != feature.index:
96+
raise ValueError('Data features not provided in right order')
9297
self.features[feature.index] = feature
9398

9499

@@ -313,7 +318,21 @@ def retrieve_class_labels(self, target_name='class'):
313318
return None
314319

315320

316-
def get_features_by_type(self, data_type, exclude=None, exclude_ignore_attributes=True, exclude_row_id_attribute=True):
321+
def get_features_by_type(self, data_type, exclude=None,
322+
exclude_ignore_attributes=True, exclude_row_id_attribute=True):
323+
'''
324+
Returns indices of features of a given type, e.g., all nominal features.
325+
Can use additional parameters to exclude various features by index or ontology.
326+
327+
:param data_type: The data type to return (e.g., nominal, numeric, date, string)
328+
:param exclude: Indices to exclude (and adapt the return values as if these indices
329+
are not present)
330+
:param exclude_ignore_attributes: Whether to exclude the defined ignore attributes
331+
(and adapt the return values as if these indices are not present)
332+
:param exclude_row_id_attribute:Whether to exclude the defined row id attributes
333+
(and adapt the return values as if these indices are not present)
334+
:return: a list of indices that have the specified data type
335+
'''
317336
assert data_type in OpenMLDataFeature.LEGAL_DATA_TYPES, "Illegal feature type requested"
318337
if self.ignore_attributes is not None:
319338
assert type(self.ignore_attributes) is list, "ignore_attributes should be a list"

openml/runs/functions.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,10 @@ def run_task(task, model):
7474
return run
7575

7676

77-
def _prediction_to_row(rep_no, fold_no, row_id, correct_label, predicted_label, predicted_probabilities, class_labels, model_classes_mapping):
78-
"""Complicated util function that turns probability estimates of a classifier for a given instance into the right arff format to upload to openml.
77+
def _prediction_to_row(rep_no, fold_no, row_id, correct_label, predicted_label,
78+
predicted_probabilities, class_labels, model_classes_mapping):
79+
"""Util function that turns probability estimates of a classifier for a given
80+
instance into the right arff format to upload to openml.
7981
8082
Parameters
8183
----------
@@ -90,6 +92,9 @@ def _prediction_to_row(rep_no, fold_no, row_id, correct_label, predicted_label,
9092
predicted_probabilities : array (size=num_classes)
9193
probabilities per class
9294
class_labels : array (size=num_classes)
95+
model_classes_mapping : list
96+
A list of classes the model produced.
97+
Obtained by BaseEstimator.classes_
9398
9499
Returns
95100
-------

tests/test_flows/test_flow.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ def test_publish_flow(self):
156156

157157
def test_semi_legal_flow(self):
158158
# TODO: Test if parameters are set correctly!
159+
# should not throw error as it contains two differentiable forms of Bagging
160+
# i.e., Bagging(Bagging(J48)) and Bagging(J48)
159161
sentinel = get_sentinel()
160162
semi_legal = sklearn.ensemble.BaggingClassifier(
161163
base_estimator=sklearn.ensemble.BaggingClassifier(
@@ -166,6 +168,7 @@ def test_semi_legal_flow(self):
166168
flow.publish()
167169

168170
def test_illegal_flow(self):
171+
# should throw error as it contains two imputers
169172
illegal = sklearn.pipeline.Pipeline(steps=[('imputer1', sklearn.preprocessing.Imputer()),
170173
('imputer2', sklearn.preprocessing.Imputer()),
171174
('classif', sklearn.tree.DecisionTreeClassifier())])

tests/test_runs/test_run_functions.py

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def test_get_runs_list_by_tag(self):
278278
self.assertGreaterEqual(len(runs), 1)
279279

280280
def test_run_on_dataset_with_missing_labels(self):
281-
from openml.runs.functions import _prediction_to_row
281+
from openml.runs.functions import _run_task_get_arffcontent
282282
from sklearn.tree import DecisionTreeClassifier
283283
from sklearn.preprocessing.imputation import Imputer
284284
task = openml.tasks.get_task(2)
@@ -287,37 +287,5 @@ def test_run_on_dataset_with_missing_labels(self):
287287
model = Pipeline(steps=[('Imputer', Imputer(strategy='median')),
288288
('Estimator', DecisionTreeClassifier())])
289289

290-
X, Y = task.get_X_and_y()
291-
rep_no = 0
292-
# TODO use different iterator to only provide a single iterator (less
293-
# methods, less maintenance, less confusion)
294-
for rep in task.iterate_repeats():
295-
fold_no = 0
296-
for fold in rep:
297-
train_indices, test_indices = fold
298-
trainX = X[train_indices]
299-
trainY = Y[train_indices]
300-
testX = X[test_indices]
301-
testY = Y[test_indices]
302-
303-
model.fit(trainX, trainY)
304-
305-
ProbaY = model.predict_proba(testX)
306-
PredY = model.predict(testX)
307-
308-
missing_label_idx = [3]
309-
310-
for i in range(0, len(test_indices)):
311-
arff_line = _prediction_to_row(rep_no, fold_no, test_indices[i], class_labels[testY[i]], PredY[i],
312-
ProbaY[i], class_labels, model.classes_)
313-
314-
offset = 0
315-
for idx, proba in enumerate(arff_line[3:-2]):
316-
if idx in missing_label_idx:
317-
offset += 1
318-
else:
319-
assert proba == ProbaY[i][idx-offset]
320-
321-
fold_no = fold_no + 1
322-
rep_no = rep_no + 1
290+
_run_task_get_arffcontent(model, task, class_labels)
323291

0 commit comments

Comments
 (0)