Skip to content

Commit d949996

Browse files
authored
Merge pull request #253 from openml/add/#193
Allow running a flow on a task
2 parents ea4c9be + faf5b26 commit d949996

14 files changed

Lines changed: 564 additions & 204 deletions

File tree

openml/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from . import tasks
2222
from . import runs
2323
from . import flows
24+
from . import setups
2425
from .runs import OpenMLRun
2526
from .tasks import OpenMLTask, OpenMLSplit
2627
from .flows import OpenMLFlow
@@ -66,4 +67,4 @@ def populate_cache(task_ids=None, dataset_ids=None, flow_ids=None,
6667

6768
__all__ = ['OpenMLDataset', 'OpenMLDataFeature', 'OpenMLRun',
6869
'OpenMLSplit', 'datasets', 'OpenMLTask', 'OpenMLFlow',
69-
'config', 'runs', 'flows', 'tasks']
70+
'config', 'runs', 'flows', 'tasks', 'setups']

openml/exceptions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
class PyOpenMLError(Exception):
    """Exception whose text stays readable through ``.message``.

    Other exception classes of this module (for example
    ``OpenMLServerError``) derive from this class.

    Parameters
    ----------
    message : str
        Description of the error. It is forwarded to ``Exception`` and
        additionally stored on the instance as ``message``.
    """

    def __init__(self, message):
        super(PyOpenMLError, self).__init__(message)
        # Python 3 exceptions do not provide ``.message`` on their own, so the
        # text is stored explicitly for code that reads ``e.message``.
        self.message = message
45

6+
57
class OpenMLServerError(PyOpenMLError):
68
"""class for when something is really wrong on the server
79
(result did not parse to dict), contains unparsed error."""
810

911
def __init__(self, message):
10-
message = "OpenML Server error: " + message
1112
super(OpenMLServerError, self).__init__(message)
1213

1314
#
@@ -18,7 +19,6 @@ class OpenMLServerException(OpenMLServerError):
1819
def __init__(self, code, message, additional=None):
1920
self.code = code
2021
self.additional = additional
21-
message = "OpenML Server exception: " + message
2222
super(OpenMLServerException, self).__init__(message)
2323

2424

openml/flows/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .flow import OpenMLFlow
1+
from .flow import OpenMLFlow, _copy_server_fields
22

33
from .sklearn_converter import sklearn_to_flow, flow_to_sklearn, _check_n_jobs
44
from .functions import get_flow, list_flows, flow_exists, assert_flows_equal

openml/flows/flow.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,21 @@ def _from_dict(cls, xml_dict):
322322
arguments['tags'] = tags
323323

324324
arguments['model'] = None
325-
return cls(**arguments)
325+
flow = cls(**arguments)
326+
327+
# try to parse to a model because not everything that can be
328+
# deserialized has to come from scikit-learn. If it can't be
329+
# serialized, but comes from scikit-learn this is worth an exception
330+
try:
331+
from .sklearn_converter import flow_to_sklearn
332+
model = flow_to_sklearn(flow)
333+
except Exception as e:
334+
if arguments['external_version'].startswith('sklearn'):
335+
raise e
336+
model = None
337+
flow.model = model
338+
339+
return flow
326340

327341
def publish(self):
328342
"""Publish flow to OpenML server.
@@ -332,32 +346,42 @@ def publish(self):
332346
self : OpenMLFlow
333347
334348
"""
349+
# Import at top not possible because of cyclic dependencies. In
350+
# particular, flow.py tries to import functions.py in order to call
351+
# get_flow(), while functions.py tries to import flow.py in order to
352+
# instantiate an OpenMLFlow.
353+
import openml.flows.functions
335354

336355
xml_description = self._to_xml()
337356

338357
file_elements = {'description': xml_description}
339358
return_value = _perform_api_call("flow/", file_elements=file_elements)
340-
self.flow_id = int(xmltodict.parse(return_value)['oml:upload_flow']['oml:id'])
359+
flow_id = int(xmltodict.parse(return_value)['oml:upload_flow']['oml:id'])
360+
flow = openml.flows.functions.get_flow(flow_id)
361+
_copy_server_fields(flow, self)
341362
try:
342-
_check_flow(self)
363+
openml.flows.functions.assert_flows_equal(self, flow, flow.upload_date)
343364
except ValueError as e:
344365
message = e.args[0]
345366
raise ValueError("Flow was not stored correctly on the server. "
346367
"New flow ID is %d. Please check manually and "
347368
"remove the flow if necessary! Error is:\n'%s'" %
348-
(self.flow_id, message))
369+
(flow_id, message))
349370
return self
350371

351372

373+
def _copy_server_fields(source_flow, target_flow):
374+
fields_added_by_the_server = ['flow_id', 'uploader', 'version',
375+
'upload_date']
376+
for field in fields_added_by_the_server:
377+
setattr(target_flow, field, getattr(source_flow, field))
378+
379+
for name, component in source_flow.components.items():
380+
assert name in target_flow.components
381+
_copy_server_fields(component, target_flow.components[name])
382+
383+
352384
def _add_if_nonempty(dic, key, value):
353385
if value is not None:
354386
dic[key] = value
355387

356-
357-
def _check_flow(flow):
358-
# Import is not possible at the top of the file as this would cause an
359-
# ImportError due to an import cycle.
360-
import openml.flows.functions
361-
362-
flow_copy = openml.flows.functions.get_flow(flow.flow_id)
363-
openml.flows.functions.assert_flows_equal(flow, flow_copy)

openml/flows/functions.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import dateutil.parser
2+
13
import xmltodict
24
import six
35

46
from openml._api_calls import _perform_api_call
5-
from . import OpenMLFlow, flow_to_sklearn
7+
from . import OpenMLFlow
68

79

810
def get_flow(flow_id):
@@ -24,9 +26,6 @@ def get_flow(flow_id):
2426
flow_dict = xmltodict.parse(flow_xml)
2527
flow = OpenMLFlow._from_dict(flow_dict)
2628

27-
if 'sklearn' in flow.external_version:
28-
flow.model = flow_to_sklearn(flow)
29-
3029
return flow
3130

3231

@@ -130,11 +129,41 @@ def _list_flows(api_call):
130129
return flows
131130

132131

133-
def assert_flows_equal(flow1, flow2):
132+
def _check_flow_for_server_id(flow):
133+
"""Check if the given flow and it's components have a flow_id."""
134+
135+
# Depth-first search to check if all components were uploaded to the
136+
# server before parsing the parameters
137+
stack = list()
138+
stack.append(flow)
139+
while len(stack) > 0:
140+
current = stack.pop()
141+
if current.flow_id is None:
142+
raise ValueError("Flow %s has no flow_id!" % current.name)
143+
else:
144+
for component in current.components.values():
145+
stack.append(component)
146+
147+
148+
def assert_flows_equal(flow1, flow2, ignore_parameters_on_older_children=None,
149+
ignore_parameters=False):
134150
"""Check equality of two flows.
135151
136152
Two flows are equal if their all keys which are not set by the server
137153
are equal, as well as all their parameters and components.
154+
155+
Parameters
156+
----------
157+
flow1 : OpenMLFlow
158+
159+
flow2 : OpenMLFlow
160+
161+
ignore_parameters_on_older_children : str
162+
If set to ``OpenMLFlow.upload_date``, ignores parameters in a child
163+
flow if it's upload date predates the upload date of the parent flow.
164+
165+
ignore_parameters : bool
166+
Whether to ignore parameter values when comparing flows.
138167
"""
139168
if not isinstance(flow1, OpenMLFlow):
140169
raise TypeError('Argument 1 must be of type OpenMLFlow, but is %s' %
@@ -144,8 +173,9 @@ def assert_flows_equal(flow1, flow2):
144173
raise TypeError('Argument 2 must be of type OpenMLFlow, but is %s' %
145174
type(flow2))
146175

147-
generated_by_the_server = ['flow_id', 'uploader', 'version',
148-
'upload_date', ]
176+
# TODO as they are actually now saved during publish, it might be good to
177+
# check for the equality of these as well.
178+
generated_by_the_server = ['flow_id', 'uploader', 'version', 'upload_date']
149179
ignored_by_python_API = ['binary_url', 'binary_format', 'binary_md5',
150180
'model']
151181

@@ -162,9 +192,22 @@ def assert_flows_equal(flow1, flow2):
162192
if not name in attr2:
163193
raise ValueError('Component %s only available in '
164194
'argument2, but not in argument1.' % name)
165-
assert_flows_equal(attr1[name], attr2[name])
195+
assert_flows_equal(attr1[name], attr2[name],
196+
ignore_parameters_on_older_children,
197+
ignore_parameters)
166198

167199
else:
200+
if key == 'parameters':
201+
if ignore_parameters_on_older_children:
202+
upload_date_current_flow = dateutil.parser.parse(
203+
flow1.upload_date)
204+
upload_date_parent_flow = dateutil.parser.parse(
205+
ignore_parameters_on_older_children)
206+
if upload_date_current_flow < upload_date_parent_flow:
207+
continue
208+
elif ignore_parameters:
209+
continue
210+
168211
if attr1 != attr2:
169212
raise ValueError("Flow %s: values for attribute '%s' differ: "
170213
"'%s' vs '%s'." %

openml/runs/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from .run import OpenMLRun
22
from .trace import OpenMLRunTrace, OpenMLTraceIteration
3-
from .functions import (run_task, get_run, list_runs, get_runs, get_run_trace,
4-
initialize_model_from_run, initialize_model_from_trace)
3+
from .functions import (run_model_on_task, run_flow_on_task, get_run, list_runs,
4+
get_runs, get_run_trace, initialize_model_from_run,
5+
initialize_model_from_trace)
56

6-
__all__ = ['OpenMLRun', 'run_task', 'get_run', 'list_runs', 'get_runs']
7+
__all__ = ['OpenMLRun', 'run_model_on_task', 'run_flow_on_task', 'get_run',
8+
'list_runs', 'get_runs']

openml/runs/functions.py

Lines changed: 64 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,18 @@
77
import warnings
88

99
import numpy as np
10-
import sklearn
10+
import sklearn.pipeline
1111
import six
1212
import xmltodict
1313

14+
import openml
1415
from ..exceptions import PyOpenMLError
1516
from .. import config
16-
from ..flows import sklearn_to_flow, get_flow, flow_exists, _check_n_jobs
17+
from ..flows import sklearn_to_flow, get_flow, flow_exists, _check_n_jobs, \
18+
_copy_server_fields
1719
from ..setups import setup_exists, initialize_model
1820
from ..exceptions import OpenMLCacheException, OpenMLServerException
19-
from .._api_calls import _perform_api_call, _file_id_to_url
21+
from .._api_calls import _perform_api_call
2022
from .run import OpenMLRun, _get_version_information
2123
from .trace import OpenMLRunTrace, OpenMLTraceIteration
2224

@@ -25,24 +27,42 @@
2527
# circular imports
2628

2729

28-
def run_task(task, model, avoid_duplicate_runs=True, flow_tags=None, seed=None):
29-
"""Performs a CV run on the dataset of the given task, using the split.
30+
def run_model_on_task(task, model, avoid_duplicate_runs=True, flow_tags=None,
                      seed=None):
    """Convert ``model`` to a flow and run that flow on ``task``.

    The model is turned into a flow with ``sklearn_to_flow``; everything else
    is delegated to ``run_flow_on_task``. See ``run_flow_on_task`` for the
    documentation of the remaining arguments and of the return value.
    """
    return run_flow_on_task(
        task=task,
        flow=sklearn_to_flow(model),
        avoid_duplicate_runs=avoid_duplicate_runs,
        flow_tags=flow_tags,
        seed=seed,
    )
39+
40+
41+
def run_flow_on_task(task, flow, avoid_duplicate_runs=True, flow_tags=None,
42+
seed=None):
43+
"""Run the model provided by the flow on the dataset defined by task.
44+
45+
Takes the flow and repeat information into account. In case a flow is not
46+
yet published, it is published after executing the run (requires
47+
internet connection).
3048
3149
Parameters
3250
----------
3351
task : OpenMLTask
3452
Task to perform.
3553
model : sklearn model
36-
a model which has a function fit(X,Y) and predict(X),
54+
A model which has a function fit(X,Y) and predict(X),
3755
all supervised estimators of scikit learn follow this definition of a model [1]
3856
[1](http://scikit-learn.org/stable/tutorial/statistical_inference/supervised_learning.html)
3957
avoid_duplicate_runs : bool
40-
if this flag is set to True, the run will throw an error if the
41-
setup/task combination is already present on the server.
58+
If this flag is set to True, the run will throw an error if the
59+
setup/task combination is already present on the server. Works only
60+
if the flow is already published on the server. This feature requires an
61+
internet connection.
4262
flow_tags : list(str)
43-
a list of tags that the flow should have at creation
63+
A list of tags that the flow should have at creation.
4464
seed: int
45-
the models that are not seeded will get this seed
65+
Models that are not seeded will get this seed.
4666
4767
Returns
4868
-------
@@ -51,23 +71,19 @@ def run_task(task, model, avoid_duplicate_runs=True, flow_tags=None, seed=None):
5171
"""
5272
if flow_tags is not None and not isinstance(flow_tags, list):
5373
raise ValueError("flow_tags should be list")
54-
# TODO move this into its onwn module. While it somehow belongs here, it
55-
# adds quite a lot of functionality which is better suited in other places!
56-
# TODO why doesn't this accept a flow as input? - this would make this more flexible!
57-
model = _get_seeded_model(model, seed)
58-
flow = sklearn_to_flow(model)
5974

60-
# returns flow id if the flow exists on the server, False otherwise
61-
flow_id = flow_exists(flow.name, flow.external_version)
75+
flow.model = _get_seeded_model(flow.model, seed=seed)
6276

6377
# skips the run if it already exists and the user opts for this in the config file.
6478
# also, if the flow is not present on the server, the check is not needed.
79+
flow_id = flow_exists(flow.name, flow.external_version)
6580
if avoid_duplicate_runs and flow_id:
66-
flow = get_flow(flow_id)
67-
setup_id = setup_exists(flow, model)
81+
flow_from_server = get_flow(flow_id)
82+
setup_id = setup_exists(flow_from_server)
6883
ids = _run_exists(task.task_id, setup_id)
6984
if ids:
7085
raise PyOpenMLError("Run already exists in server. Run id(s): %s" %str(ids))
86+
_copy_server_fields(flow_from_server, flow)
7187

7288
dataset = task.get_dataset()
7389

@@ -78,25 +94,43 @@ def run_task(task, model, avoid_duplicate_runs=True, flow_tags=None, seed=None):
7894

7995
run_environment = _get_version_information()
8096
tags = ['openml-python', run_environment[1]]
97+
8198
# execute the run
82-
run = OpenMLRun(task_id=task.task_id, flow_id=None, dataset_id=dataset.dataset_id, model=model, tags=tags)
83-
res = _run_task_get_arffcontent(model, task, class_labels)
84-
run.data_content, run.trace_content, run.trace_attributes, run.detailed_evaluations = res
99+
res = _run_task_get_arffcontent(flow.model, task, class_labels)
85100

86-
if flow_id == False:
87-
# means the flow did not exists. As we could run it, publish it now
88-
flow = flow.publish()
89-
else:
90-
# flow already existed, download it from server
91-
# TODO (neccessary? is this a post condition of this function)
92-
flow = get_flow(flow_id)
101+
if flow.flow_id is None:
102+
_publish_flow_if_necessary(flow)
103+
104+
run = OpenMLRun(task_id=task.task_id, flow_id=flow.flow_id,
105+
dataset_id=dataset.dataset_id, model=flow.model, tags=tags)
106+
run.parameter_settings = OpenMLRun._parse_parameters(flow)
107+
108+
run.data_content, run.trace_content, run.trace_attributes, run.detailed_evaluations = res
93109

94-
run.flow_id = flow.flow_id
95110
config.logger.info('Executed Task %d with Flow id: %d' % (task.task_id, run.flow_id))
96111

97112
return run
98113

99114

115+
def _publish_flow_if_necessary(flow):
116+
# try publishing the flow if one has to assume it doesn't exist yet. It
117+
# might fail because it already exists, then the flow is currently not
118+
# reused
119+
120+
try:
121+
flow.publish()
122+
except OpenMLServerException as e:
123+
if e.message == "flow already exists":
124+
flow_id = openml.flows.flow_exists(flow.name,
125+
flow.external_version)
126+
server_flow = get_flow(flow_id)
127+
openml.flows.flow._copy_server_fields(server_flow, flow)
128+
openml.flows.assert_flows_equal(flow, server_flow,
129+
ignore_parameters=True)
130+
else:
131+
raise e
132+
133+
100134
def get_run_trace(run_id):
101135
"""Get the optimization trace object for a given run id.
102136

0 commit comments

Comments
 (0)