Skip to content

Commit 8fadddc

Browse files
committed
split function run_task into two functions
also parse parameters when running a flow on a task fix publishing error
1 parent ea4c9be commit 8fadddc

11 files changed

Lines changed: 230 additions & 135 deletions

File tree

openml/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from . import tasks
2222
from . import runs
2323
from . import flows
24+
from . import setups
2425
from .runs import OpenMLRun
2526
from .tasks import OpenMLTask, OpenMLSplit
2627
from .flows import OpenMLFlow
@@ -66,4 +67,4 @@ def populate_cache(task_ids=None, dataset_ids=None, flow_ids=None,
6667

6768
__all__ = ['OpenMLDataset', 'OpenMLDataFeature', 'OpenMLRun',
6869
'OpenMLSplit', 'datasets', 'OpenMLTask', 'OpenMLFlow',
69-
'config', 'runs', 'flows', 'tasks']
70+
'config', 'runs', 'flows', 'tasks', 'setups']

openml/flows/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .flow import OpenMLFlow
1+
from .flow import OpenMLFlow, _copy_server_fields
22

33
from .sklearn_converter import sklearn_to_flow, flow_to_sklearn, _check_n_jobs
44
from .functions import get_flow, list_flows, flow_exists, assert_flows_equal

openml/flows/flow.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,16 @@ def _from_dict(cls, xml_dict):
322322
arguments['tags'] = tags
323323

324324
arguments['model'] = None
325-
return cls(**arguments)
325+
flow = cls(**arguments)
326+
327+
if 'sklearn' in arguments['external_version']:
328+
from .sklearn_converter import flow_to_sklearn
329+
model = flow_to_sklearn(flow)
330+
else:
331+
model = None
332+
flow.model = model
333+
334+
return flow
326335

327336
def publish(self):
328337
"""Publish flow to OpenML server.
@@ -332,32 +341,38 @@ def publish(self):
332341
self : OpenMLFlow
333342
334343
"""
344+
import openml.flows.functions
335345

336346
xml_description = self._to_xml()
337347

338348
file_elements = {'description': xml_description}
339349
return_value = _perform_api_call("flow/", file_elements=file_elements)
340-
self.flow_id = int(xmltodict.parse(return_value)['oml:upload_flow']['oml:id'])
350+
flow_id = int(xmltodict.parse(return_value)['oml:upload_flow']['oml:id'])
351+
flow = openml.flows.functions.get_flow(flow_id)
341352
try:
342-
_check_flow(self)
353+
openml.flows.functions.assert_flows_equal(self, flow)
343354
except ValueError as e:
344355
message = e.args[0]
345356
raise ValueError("Flow was not stored correctly on the server. "
346357
"New flow ID is %d. Please check manually and "
347358
"remove the flow if necessary! Error is:\n'%s'" %
348-
(self.flow_id, message))
359+
(flow_id, message))
360+
_copy_server_fields(flow, self)
349361
return self
350362

351363

364+
def _copy_server_fields(source_flow, target_flow):
365+
fields_added_by_the_server = ['flow_id', 'uploader', 'version',
366+
'upload_date']
367+
for field in fields_added_by_the_server:
368+
setattr(target_flow, field, getattr(source_flow, field))
369+
370+
for name, component in source_flow.components.items():
371+
assert name in target_flow.components
372+
_copy_server_fields(component, target_flow.components[name])
373+
374+
352375
def _add_if_nonempty(dic, key, value):
353376
if value is not None:
354377
dic[key] = value
355378

356-
357-
def _check_flow(flow):
358-
# Import is not possible at the top of the file as this would cause an
359-
# ImportError due to an import cycle.
360-
import openml.flows.functions
361-
362-
flow_copy = openml.flows.functions.get_flow(flow.flow_id)
363-
openml.flows.functions.assert_flows_equal(flow, flow_copy)

openml/flows/functions.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import six
33

44
from openml._api_calls import _perform_api_call
5-
from . import OpenMLFlow, flow_to_sklearn
5+
from . import OpenMLFlow
66

77

88
def get_flow(flow_id):
@@ -24,9 +24,6 @@ def get_flow(flow_id):
2424
flow_dict = xmltodict.parse(flow_xml)
2525
flow = OpenMLFlow._from_dict(flow_dict)
2626

27-
if 'sklearn' in flow.external_version:
28-
flow.model = flow_to_sklearn(flow)
29-
3027
return flow
3128

3229

@@ -144,8 +141,7 @@ def assert_flows_equal(flow1, flow2):
144141
raise TypeError('Argument 2 must be of type OpenMLFlow, but is %s' %
145142
type(flow2))
146143

147-
generated_by_the_server = ['flow_id', 'uploader', 'version',
148-
'upload_date', ]
144+
generated_by_the_server = ['flow_id', 'uploader', 'version', 'upload_date']
149145
ignored_by_python_API = ['binary_url', 'binary_format', 'binary_md5',
150146
'model']
151147

openml/runs/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from .run import OpenMLRun
22
from .trace import OpenMLRunTrace, OpenMLTraceIteration
3-
from .functions import (run_task, get_run, list_runs, get_runs, get_run_trace,
4-
initialize_model_from_run, initialize_model_from_trace)
3+
from .functions import (run_model_on_task, run_flow_on_task, get_run, list_runs,
4+
get_runs, get_run_trace, initialize_model_from_run,
5+
initialize_model_from_trace)
56

6-
__all__ = ['OpenMLRun', 'run_task', 'get_run', 'list_runs', 'get_runs']
7+
__all__ = ['OpenMLRun', 'run_model_on_task', 'run_flow_on_task', 'get_run',
8+
'list_runs', 'get_runs']

openml/runs/functions.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313

1414
from ..exceptions import PyOpenMLError
1515
from .. import config
16-
from ..flows import sklearn_to_flow, get_flow, flow_exists, _check_n_jobs
16+
from ..flows import sklearn_to_flow, get_flow, flow_exists, _check_n_jobs, \
17+
_copy_server_fields
1718
from ..setups import setup_exists, initialize_model
1819
from ..exceptions import OpenMLCacheException, OpenMLServerException
19-
from .._api_calls import _perform_api_call, _file_id_to_url
20+
from .._api_calls import _perform_api_call
2021
from .run import OpenMLRun, _get_version_information
2122
from .trace import OpenMLRunTrace, OpenMLTraceIteration
2223

@@ -25,7 +26,31 @@
2526
# circular imports
2627

2728

28-
def run_task(task, model, avoid_duplicate_runs=True, flow_tags=None, seed=None):
29+
def run_model_on_task(task, model, avoid_duplicate_runs=True, flow_tags=None,
30+
seed=None):
31+
flow = sklearn_to_flow(model)
32+
33+
# returns flow id if the flow exists on the server, False otherwise
34+
flow_id = flow_exists(flow.name, flow.external_version)
35+
36+
if flow_id == False:
37+
# TODO this is potential race condition! someone could upload the
38+
# same flow in the meantime!
39+
# means the flow did not exists. As we could run it, publish it now
40+
flow = flow.publish()
41+
else:
42+
# flow already existed, download it from server
43+
# TODO (neccessary? is this a post condition of this function)
44+
flow_from_server = get_flow(flow_id)
45+
_copy_server_fields(flow_from_server, flow)
46+
47+
return run_flow_on_task(task=task, flow=flow,
48+
avoid_duplicate_runs=avoid_duplicate_runs,
49+
flow_tags=flow_tags, seed=seed)
50+
51+
52+
def run_flow_on_task(task, flow, avoid_duplicate_runs=True, flow_tags=None,
53+
seed=None):
2954
"""Performs a CV run on the dataset of the given task, using the split.
3055
3156
Parameters
@@ -51,23 +76,18 @@ def run_task(task, model, avoid_duplicate_runs=True, flow_tags=None, seed=None):
5176
"""
5277
if flow_tags is not None and not isinstance(flow_tags, list):
5378
raise ValueError("flow_tags should be list")
54-
# TODO move this into its onwn module. While it somehow belongs here, it
55-
# adds quite a lot of functionality which is better suited in other places!
56-
# TODO why doesn't this accept a flow as input? - this would make this more flexible!
57-
model = _get_seeded_model(model, seed)
58-
flow = sklearn_to_flow(model)
5979

60-
# returns flow id if the flow exists on the server, False otherwise
61-
flow_id = flow_exists(flow.name, flow.external_version)
80+
flow.model = _get_seeded_model(flow.model, seed=seed)
6281

6382
# skips the run if it already exists and the user opts for this in the config file.
6483
# also, if the flow is not present on the server, the check is not needed.
65-
if avoid_duplicate_runs and flow_id:
66-
flow = get_flow(flow_id)
67-
setup_id = setup_exists(flow, model)
84+
if avoid_duplicate_runs:
85+
flow_from_server = get_flow(flow.flow_id)
86+
setup_id = setup_exists(flow_from_server)
6887
ids = _run_exists(task.task_id, setup_id)
6988
if ids:
7089
raise PyOpenMLError("Run already exists in server. Run id(s): %s" %str(ids))
90+
_copy_server_fields(flow_from_server, flow)
7191

7292
dataset = task.get_dataset()
7393

@@ -79,18 +99,12 @@ def run_task(task, model, avoid_duplicate_runs=True, flow_tags=None, seed=None):
7999
run_environment = _get_version_information()
80100
tags = ['openml-python', run_environment[1]]
81101
# execute the run
82-
run = OpenMLRun(task_id=task.task_id, flow_id=None, dataset_id=dataset.dataset_id, model=model, tags=tags)
83-
res = _run_task_get_arffcontent(model, task, class_labels)
102+
run = OpenMLRun(task_id=task.task_id, flow_id=None, dataset_id=dataset.dataset_id,
103+
model=flow.model, tags=tags)
104+
run.parameter_settings = OpenMLRun._parse_parameters(flow)
105+
res = _run_task_get_arffcontent(flow.model, task, class_labels)
84106
run.data_content, run.trace_content, run.trace_attributes, run.detailed_evaluations = res
85107

86-
if flow_id == False:
87-
# means the flow did not exists. As we could run it, publish it now
88-
flow = flow.publish()
89-
else:
90-
# flow already existed, download it from server
91-
# TODO (neccessary? is this a post condition of this function)
92-
flow = get_flow(flow_id)
93-
94108
run.flow_id = flow.flow_id
95109
config.logger.info('Executed Task %d with Flow id: %d' % (task.task_id, run.flow_id))
96110

openml/runs/run.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections import OrderedDict
2+
import json
23
import sys
34
import time
45

@@ -143,11 +144,6 @@ def _create_description_xml(self):
143144
XML description of run.
144145
"""
145146

146-
# TODO: don't we have flow object in data structure? Use this one
147-
downloaded_flow = openml.flows.get_flow(self.flow_id)
148-
149-
openml_param_settings = OpenMLRun._parse_parameters(self.model, downloaded_flow)
150-
151147
# as a tag, it must be of the form ([a-zA-Z0-9_\-\.])+
152148
# so we format time from 'mm/dd/yy hh:mm:ss' to 'mm-dd-yy_hh.mm.ss'
153149
# well_formatted_time = time.strftime("%c").replace(
@@ -156,56 +152,82 @@ def _create_description_xml(self):
156152
# [self.model.__module__ + "." + self.model.__class__.__name__]
157153
description = _to_dict(taskid=self.task_id, flow_id=self.flow_id,
158154
setup_string=_create_setup_string(self.model),
159-
parameter_settings=openml_param_settings,
155+
parameter_settings=self.parameter_settings,
160156
error_message=self.error_message,
161157
detailed_evaluations=self.detailed_evaluations,
162158
tags=self.tags)
163159
description_xml = xmltodict.unparse(description, pretty=True)
164160
return description_xml
165161

166162
@staticmethod
167-
def _parse_parameters(model, server_flow):
168-
"""Extracts all parameter settings from a model in OpenML format.
163+
def _parse_parameters(flow):
164+
"""Extracts all parameter settings from the model inside a flow in
165+
OpenML format.
169166
170167
Parameters
171168
----------
172-
model
173-
the scikit-learn model (fitted)
174169
flow
175170
openml flow object (containing flow ids, i.e., it has to be downloaded from the server)
176171
177172
"""
178-
if server_flow.flow_id is None:
179-
raise ValueError("The flow parameter needs to be downloaded from server")
173+
174+
# Depth-first search to check if all components were uploaded to the
175+
# server before parsing the parameters
176+
stack = list()
177+
stack.append(flow)
178+
while len(stack) > 0:
179+
current = stack.pop()
180+
if current.flow_id is None:
181+
raise ValueError("Flow %s has no flow_id!" % current.name)
182+
else:
183+
for component in current.components.values():
184+
stack.append(component)
180185

181186
def get_flow_dict(_flow):
182187
flow_map = {_flow.name: _flow.flow_id}
183188
for subflow in _flow.components:
184189
flow_map.update(get_flow_dict(_flow.components[subflow]))
185190
return flow_map
186191

187-
def extract_parameters(_flow, _param_dict, _main_call=False, main_id=None):
192+
def extract_parameters(_flow, _flow_dict, _main_call=False, main_id=None):
188193
# _flow is openml flow object, _param dict maps from flow name to flow id
189194
# for the main call, the param dict can be overridden (useful for unit tests / sentinels)
190-
# this way, for flows without subflows we do not have to rely on _param_dict
195+
# this way, for flows without subflows we do not have to rely on _flow_dict
191196
_params = []
192197
for _param_name in _flow.parameters:
193198
_current = OrderedDict()
194199
_current['oml:name'] = _param_name
195-
_current['oml:value'] = _flow.parameters[_param_name]
200+
201+
_tmp = openml.flows.sklearn_to_flow(_flow.model.get_params()[_param_name])
202+
203+
# Try to filter out components which are handled further down!
204+
if isinstance(_tmp, openml.flows.OpenMLFlow):
205+
continue
206+
try:
207+
_tmp = json.dumps(_tmp)
208+
except TypeError as e:
209+
# Python3.5 exception message:
210+
# <openml.flows.flow.OpenMLFlow object at 0x7fed87978160> is not JSON serializable
211+
# Python3.6 exception message:
212+
# Object of type 'OpenMLFlow' is not JSON serializable
213+
if 'OpenMLFlow' in e.args[0] and \
214+
'is not JSON serializable' in e.args[0]:
215+
continue
216+
217+
_current['oml:value'] = _tmp
196218
if _main_call:
197219
_current['oml:component'] = main_id
198220
else:
199-
_current['oml:component'] = _param_dict[_flow.name]
221+
_current['oml:component'] = _flow_dict[_flow.name]
200222
_params.append(_current)
223+
201224
for _identifier in _flow.components:
202-
_params.extend(extract_parameters(_flow.components[_identifier], _param_dict))
225+
_params.extend(extract_parameters(_flow.components[_identifier], _flow_dict))
203226
return _params
204227

205-
flow_dict = get_flow_dict(server_flow)
206-
local_flow = openml.flows.sklearn_to_flow(model)
228+
flow_dict = get_flow_dict(flow)
229+
parameters = extract_parameters(flow, flow_dict, True, flow.flow_id)
207230

208-
parameters = extract_parameters(local_flow, flow_dict, True, server_flow.flow_id)
209231
return parameters
210232

211233
################################################################################

0 commit comments

Comments
 (0)