Skip to content

Commit be15112

Browse files
committed
small updates for unit testing
1 parent 8435921 commit be15112

3 files changed

Lines changed: 20 additions & 7 deletions

File tree

openml/runs/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .run import OpenMLRun
22
from .trace import OpenMLRunTrace, OpenMLTraceIteration
3-
from .functions import (run_task, get_run, list_runs, get_runs,
3+
from .functions import (run_task, get_run, list_runs, get_runs, get_run_trace,
44
initialize_model_from_run, initialize_model_from_trace)
55

66
__all__ = ['OpenMLRun', 'run_task', 'get_run', 'list_runs', 'get_runs']

openml/runs/functions.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,13 @@ def run_task(task, model, avoid_duplicate_runs=True, flow_tags=None, seed=None):
9797

9898
return run
9999

100+
101+
def get_run_trace(run_id):
102+
trace_xml = _perform_api_call('run/trace/%d' % run_id)
103+
run_trace = _create_trace_from_description(trace_xml)
104+
return run_trace
105+
106+
100107
def initialize_model_from_run(run_id):
101108
'''
102109
Initialized a model based on a run_id (i.e., using the exact
@@ -133,18 +140,18 @@ def initialize_model_from_trace(run_id, repeat, fold, iteration=None):
133140
The fold nr (column in trace file)
134141
135142
iteration: int
136-
The iteration nr (column in trace file)
143+
The iteration nr (column in trace file). If None, the
144+
best (selected) iteration will be searched (slow)
137145
138146
Returns
139147
-------
140148
model : sklearn model
141149
the scikitlearn model with all parameters initailized
142150
'''
143-
run = get_run(run_id)
144-
if 'trace' not in run.output_files:
145-
raise PyOpenMLError('Run does not contain trace file')
146-
trace_xml = _perform_api_call('run/trace/%d' %run_id)
147-
run_trace = _create_trace_from_description(trace_xml)
151+
run_trace = get_run_trace(run_id)
152+
153+
if iteration is None:
154+
iteration = run_trace.get_selected_iteration(repeat, fold)
148155

149156
request = (repeat, fold, iteration)
150157
if request not in run_trace.trace_iterations:

openml/runs/trace.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ def __init__(self, run_id, trace_iterations):
1313
self.run_id = run_id
1414
self.trace_iterations = trace_iterations
1515

16+
def get_selected_iteration(self, fold, repeat):
17+
for (r, f, i) in self.trace_iterations:
18+
if r == repeat and f == fold and self.trace_iterations[(r, f, i)].selected == True:
19+
return i
20+
raise ValueError('Could not find the selected iteration for rep/fold %d/%d' %(repeat,fold))
21+
1622
def __str__(self):
1723
return '[Run id: %d, %d trace iterations]' %(self.run_id, len(self.trace_iterations))
1824

0 commit comments

Comments
 (0)