Skip to content

Commit 0b01581

Browse files
committed
fix prediction indexing
1 parent 292023e commit 0b01581

2 files changed

Lines changed: 16 additions & 8 deletions

File tree

openml/runs/functions.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -449,9 +449,9 @@ def _calculate_local_measure(sklearn_fn, openml_name):
449449

450450
if isinstance(task, (OpenMLClassificationTask, OpenMLLearningCurveTask)):
451451

452-
for i in range(0, len(test_indices)):
452+
for i, tst_idx in enumerate(test_indices):
453453

454-
arff_line = [rep_no, fold_no, sample_no, i] # type: List[Any]
454+
arff_line = [rep_no, fold_no, sample_no, tst_idx] # type: List[Any]
455455
for j, class_label in enumerate(task.class_labels):
456456
arff_line.append(proba_y[i][j])
457457

@@ -545,13 +545,19 @@ def get_runs(run_ids):
545545

546546

547547
@openml.utils.thread_safe_if_oslo_installed
548-
def get_run(run_id):
548+
def get_run(run_id: int, ignore_cache: bool = False) -> OpenMLRun:
549549
"""Gets run corresponding to run_id.
550550
551551
Parameters
552552
----------
553553
run_id : int
554554
555+
ignore_cache : bool
556+
Whether to ignore the cache. If ``true`` this will download and overwrite the run xml
557+
even if the requested run is already cached.
558+
559+
ignore_cache
560+
555561
Returns
556562
-------
557563
run : OpenMLRun
@@ -565,11 +571,13 @@ def get_run(run_id):
565571
os.makedirs(run_dir)
566572

567573
try:
568-
return _get_cached_run(run_id)
574+
if not ignore_cache:
575+
return _get_cached_run(run_id)
576+
else:
577+
raise OpenMLCacheException(message='dummy')
569578

570-
except (OpenMLCacheException):
571-
run_xml = openml._api_calls._perform_api_call("run/%d" % run_id,
572-
'get')
579+
except OpenMLCacheException:
580+
run_xml = openml._api_calls._perform_api_call("run/%d" % run_id, 'get')
573581
with io.open(run_file, "w", encoding='utf8') as fh:
574582
fh.write(run_xml)
575583

tests/test_runs/test_run_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _wait_for_processed_run(self, run_id, max_waiting_time_seconds):
7373
# time.time() works in seconds
7474
start_time = time.time()
7575
while time.time() - start_time < max_waiting_time_seconds:
76-
run = openml.runs.get_run(run_id)
76+
run = openml.runs.get_run(run_id, ignore_cache=True)
7777
if len(run.evaluations) > 0:
7878
return
7979
else:

0 commit comments

Comments
 (0)