Skip to content

Commit 2228059

Browse files
committed
fix test & pep8 & mypy
1 parent 4e971f4 commit 2228059

6 files changed

Lines changed: 19 additions & 11 deletions

File tree

openml/extensions/extension_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
if TYPE_CHECKING:
1111
from openml.flows import OpenMLFlow
1212
from openml.tasks.task import OpenMLTask
13-
from openml.runs.trace import OpenMLRunTrace, OpenMLTraceIteration
13+
from openml.runs.trace import OpenMLRunTrace, OpenMLTraceIteration # noqa F401
1414

1515

1616
class Extension(ABC):

openml/extensions/sklearn/extension.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1268,8 +1268,8 @@ def _prediction_to_probabilities(
12681268

12691269
if proba_y.shape[1] != len(task.class_labels):
12701270
message = "Estimator only predicted for {}/{} classes!".format(
1271-
proba_y.shape[1], len(task.class_labels),
1272-
)
1271+
proba_y.shape[1], len(task.class_labels),
1272+
)
12731273
warnings.warn(message)
12741274
openml.config.logger.warn(message)
12751275

@@ -1284,7 +1284,7 @@ def _prediction_to_probabilities(
12841284

12851285
if self._is_hpo_class(model_copy):
12861286
trace_data = self._extract_trace_data(model_copy, rep_no, fold_no)
1287-
trace = self._obtain_arff_trace(model_copy, trace_data)
1287+
trace = self._obtain_arff_trace(model_copy, trace_data) # type: Optional[OpenMLRunTrace] # noqa E501
12881288
else:
12891289
trace = None
12901290

openml/runs/functions.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import Any, List, Optional, Set, Tuple, Union, TYPE_CHECKING # noqa F401
55
import warnings
66

7-
import numpy as np
87
import sklearn.metrics
98
import xmltodict
109

@@ -382,7 +381,6 @@ def _run_task_get_arffcontent(
382381
'OrderedDict[str, OrderedDict]',
383382
]:
384383
arff_datacontent = [] # type: List[List]
385-
arff_tracecontent = [] # type: List[List]
386384
traces = [] # type: List[OpenMLRunTrace]
387385
# stores fold-based evaluation measures. In case of a sample based task,
388386
# this information is multiple times overwritten, but due to the ordering

openml/runs/trace.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from collections import OrderedDict
22
import json
33
import os
4-
from typing import List
4+
from typing import List, Tuple # noqa F401
55

66
import arff
77
import xmltodict
@@ -346,23 +346,22 @@ def trace_from_xml(cls, xml):
346346
)
347347
trace[(repeat, fold, iteration)] = current
348348

349-
return cls(None, trace)
349+
return cls(run_id, trace)
350350

351351
@classmethod
352352
def merge_traces(cls, traces: List['OpenMLRunTrace']):
353353
for i in range(1, len(traces)):
354354
if traces[i] != traces[i - 1]:
355355
raise ValueError('Cannot merge traces!')
356356

357-
merged_trace = OrderedDict()
357+
merged_trace = OrderedDict() # type: OrderedDict[Tuple[int, int, int], OpenMLTraceIteration] # noqa E501
358358

359359
for trace in traces:
360360
for iteration in trace:
361361
merged_trace[(iteration.repeat, iteration.fold, iteration.iteration)] = iteration
362362

363363
return cls(None, merged_trace)
364364

365-
366365
def __str__(self):
367366
return '[Run id: %d, %d trace iterations]' % (
368367
-1 if self.run_id is None else self.run_id,

tests/test_extensions/test_sklearn_extension/test_sklearn_extension.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1264,7 +1264,7 @@ def test_run_model_on_fold(self):
12641264
# TODO add some mocking here to actually test the innards of this function, too!
12651265
res = self.extension._run_model_on_fold(
12661266
clf, task, 0, 0, 0,
1267-
add_local_measures=True)
1267+
)
12681268

12691269
arff_datacontent, arff_tracecontent, user_defined_measures, model = res
12701270
# predictions

tests/test_runs/test_run_functions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import random
55
import time
66
import sys
7+
import unittest.mock
78

89
import numpy as np
910

@@ -1052,8 +1053,11 @@ def test__run_task_get_arffcontent(self):
10521053
num_folds = 10
10531054
num_repeats = 1
10541055

1056+
flow = unittest.mock.Mock()
1057+
flow.name = 'dummy'
10551058
clf = SGDClassifier(loss='log', random_state=1)
10561059
res = openml.runs.functions._run_task_get_arffcontent(
1060+
flow=flow,
10571061
extension=self.extension,
10581062
model=clf,
10591063
task=task,
@@ -1246,12 +1250,15 @@ def test_run_on_dataset_with_missing_labels(self):
12461250
# labels only declared in the arff file, but is not present in the
12471251
# actual data
12481252

1253+
flow = unittest.mock.Mock()
1254+
flow.name = 'dummy'
12491255
task = openml.tasks.get_task(2)
12501256

12511257
model = Pipeline(steps=[('Imputer', Imputer(strategy='median')),
12521258
('Estimator', DecisionTreeClassifier())])
12531259

12541260
data_content, _, _, _ = _run_task_get_arffcontent(
1261+
flow=flow,
12551262
model=model,
12561263
task=task,
12571264
extension=self.extension,
@@ -1267,6 +1274,8 @@ def test_run_on_dataset_with_missing_labels(self):
12671274
def test_predict_proba_hardclassifier(self):
12681275
# task 1 (test server) is important: it is a task with an unused class
12691276
tasks = [1, 3, 115]
1277+
flow = unittest.mock.Mock()
1278+
flow.name = 'dummy'
12701279

12711280
for task_id in tasks:
12721281
task = openml.tasks.get_task(task_id)
@@ -1280,12 +1289,14 @@ def test_predict_proba_hardclassifier(self):
12801289
])
12811290

12821291
arff_content1, _, _, _ = _run_task_get_arffcontent(
1292+
flow=flow,
12831293
model=clf1,
12841294
task=task,
12851295
extension=self.extension,
12861296
add_local_measures=True,
12871297
)
12881298
arff_content2, _, _, _ = _run_task_get_arffcontent(
1299+
flow=flow,
12891300
model=clf2,
12901301
task=task,
12911302
extension=self.extension,

0 commit comments

Comments
 (0)