Skip to content

Commit f55f58e

Browse files
committed
MAINT improve test coverage for run.py
1 parent 244c585 commit f55f58e

2 files changed

Lines changed: 53 additions & 18 deletions

File tree

openml/runs/run.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ class OpenMLRun(object):
2323
FIXME
2424
2525
"""
26-
def __init__(self, task_id, flow_id, setup_string, dataset_id, files=None,
27-
setup_id=None, tags=None, run_id=None, uploader=None,
28-
uploader_name=None, evaluations=None,
26+
def __init__(self, task_id, flow_id, dataset_id, setup_string=None,
27+
files=None, setup_id=None, tags=None, run_id=None,
28+
uploader=None, uploader_name=None, evaluations=None,
2929
detailed_evaluations=None, data_content=None,
3030
model=None, task_type=None, task_evaluation_measure=None,
3131
flow_name=None, parameter_settings=None, predictions_url=None):
@@ -47,8 +47,8 @@ def __init__(self, task_id, flow_id, setup_string, dataset_id, files=None,
4747
self.data_content = data_content
4848
self.model = model
4949

50-
def _generate_arff(self):
51-
"""Generates an arff for upload to server.
50+
def _generate_arff_header_dict(self):
51+
"""Generates the arff header dictionary for upload to the server.
5252
5353
Returns
5454
-------
@@ -78,7 +78,7 @@ def publish(self):
7878
7979
Uploads the results of a run to OpenML.
8080
"""
81-
predictions = arff.dumps(self._generate_arff())
81+
predictions = arff.dumps(self._generate_arff_header_dict())
8282
description_xml = self._create_description_xml()
8383
file_elements = {'predictions': ("predictions.csv", predictions),
8484
'description': ("description.xml", description_xml)}
@@ -152,8 +152,18 @@ def run_task(task, model):
152152
setup_string = _create_setup_string(model)
153153

154154
run = OpenMLRun(task.task_id, flow_id, setup_string, dataset.id)
155+
run.data_content = _run_task_get_arffcontent(model, task, class_labels)
155156

156-
train_times = []
157+
# The model will not be uploaded at the moment, but used to get the
158+
# hyperparameter values when uploading the run
159+
X, Y = task.get_X_and_y()
160+
run.model = model.fit(X, Y)
161+
return run
162+
163+
164+
def _run_task_get_arffcontent(model, task, class_labels):
165+
X, Y = task.get_X_and_y()
166+
arff_datacontent = []
157167

158168
rep_no = 0
159169
# TODO use different iterator to only provide a single iterator (less
@@ -167,26 +177,21 @@ def run_task(task, model):
167177
testX = X[test_indices]
168178
testY = Y[test_indices]
169179

170-
start_time = time.time()
171180
model.fit(trainX, trainY)
172181
ProbaY = model.predict_proba(testX)
173182
PredY = model.predict(testX)
174-
end_time = time.time()
175-
176-
train_times.append(end_time - start_time)
177183

178184
for i in range(0, len(test_indices)):
179-
arff_line = [rep_no, fold_no, test_indices[i],
180-
class_labels[PredY[i]], class_labels[testY[i]]]
181-
arff_line[3:3] = ProbaY[i]
185+
arff_line = [rep_no, fold_no, test_indices[i]]
186+
arff_line.extend(ProbaY[i])
187+
arff_line.append(class_labels[PredY[i]])
188+
arff_line.append(class_labels[testY[i]])
182189
arff_datacontent.append(arff_line)
183190

184191
fold_no = fold_no + 1
185192
rep_no = rep_no + 1
186193

187-
run.data_content = arff_datacontent
188-
run.model = model.fit(X, Y)
189-
return run
194+
return arff_datacontent
190195

191196

192197
def _to_dict(taskid, flow_id, setup_string, parameter_settings, tags):

tests/runs/test_runs.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sklearn.linear_model import LogisticRegression
1+
from sklearn.linear_model import LogisticRegression, SGDClassifier
22
import openml
33
from openml.testing import TestBase
44

@@ -12,6 +12,36 @@ def test_run_iris(self):
1212
self.assertEqual(return_code, 200)
1313
# self.assertTrue("This is a read-only account" in return_value)
1414

15+
def test__run_task_get_arffcontent(self):
16+
task = openml.tasks.get_task(1939)
17+
class_labels = task.class_labels
18+
19+
clf = SGDClassifier(loss='hinge', random_state=1)
20+
self.assertRaisesRegex(AttributeError,
21+
"probability estimates are not available for loss='hinge'",
22+
openml.runs.run._run_task_get_arffcontent,
23+
clf, task, class_labels)
24+
25+
clf = SGDClassifier(loss='log', random_state=1)
26+
arff_datacontent = openml.runs.run._run_task_get_arffcontent(
27+
clf, task, class_labels)
28+
self.assertIsInstance(arff_datacontent, list)
29+
# 10 times 10 fold CV of 150 samples
30+
self.assertEqual(len(arff_datacontent), 1500)
31+
for arff_line in arff_datacontent:
32+
self.assertEqual(len(arff_line), 8)
33+
self.assertGreaterEqual(arff_line[0], 0)
34+
self.assertLessEqual(arff_line[0], 9)
35+
self.assertGreaterEqual(arff_line[1], 0)
36+
self.assertLessEqual(arff_line[1], 9)
37+
self.assertGreaterEqual(arff_line[2], 0)
38+
self.assertLessEqual(arff_line[2], 149)
39+
self.assertAlmostEqual(sum(arff_line[3:6]), 1.0)
40+
self.assertIn(arff_line[6], ['Iris-setosa', 'Iris-versicolor',
41+
'Iris-virginica'])
42+
self.assertIn(arff_line[7], ['Iris-setosa', 'Iris-versicolor',
43+
'Iris-virginica'])
44+
1545
def test_get_run(self):
1646
run = openml.runs.get_run(473350)
1747
self.assertEqual(run.dataset_id, 1167)

0 commit comments

Comments
 (0)