Skip to content

Commit 4882cda

Browse files
committed
ADD WIP conversion from sklearn model to OpenMLFlow
1 parent c40edf1 commit 4882cda

4 files changed

Lines changed: 579 additions & 2 deletions

File tree

openml/exceptions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,10 @@ class OpenMLCacheException(PyOpenMLError):
1414
"""Dataset / task etc not found in cache"""
1515
def __init__(self, message):
1616
super(OpenMLCacheException, self).__init__(message)
17+
18+
19+
class OpenMLRestrictionViolated(PyOpenMLError):
20+
"""Flows for example have a maximum number of 128
21+
https://github.com/openml/OpenML/issues/283#issuecomment-216879769)"""
22+
def __init__(self, message):
23+
super(OpenMLRestrictionViolated, self).__init__(message)

openml/flows/flow.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,30 @@ class OpenMLFlow(object):
2929
3030
"""
3131
def __init__(self, model, flow_id=None, uploader=None,
32-
description='Flow generated by run_task', creator=None,
33-
contributor=None, tag=None):
32+
description=None, creator=None, components=None,
33+
parameters=None, contributor=None, tag=None):
3434
self.flow_id = flow_id
3535
self.upoader = uploader
3636
self.description = description
3737
self.creator = creator
3838
self.tag = tag
3939
self.model = model
40+
41+
# TODO update these - the sklearn transformation class should be able
42+
# to do this!
4043
self.source = "FIXME DEFINE PYTHON FLOW"
4144
self.name = (model.__module__ + "." +
4245
model.__class__.__name__)
4346
self.external_version = 'sklearn_' + sklearn.__version__
4447

48+
if components is None:
49+
components = []
50+
self.components = components
51+
if parameters is None:
52+
parameters = []
53+
self.parameters = parameters
54+
55+
4556
def _generate_flow_xml(self):
4657
"""Generate xml representation of self for upload to server.
4758

openml/flows/sklearn.py

Lines changed: 346 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,346 @@
1+
from collections import OrderedDict
2+
import importlib
3+
import inspect
4+
import six
5+
import warnings
6+
7+
import numpy as np
8+
import scipy.stats.distributions
9+
import sklearn.base
10+
import sklearn.model_selection
11+
# Necessary to have signature available in python 2.7
12+
from sklearn.utils.fixes import signature
13+
14+
from .flow import OpenMLFlow
15+
from ..exceptions import OpenMLRestrictionViolated
16+
17+
MAXIMAL_FLOW_LENGTH = 1024
18+
19+
20+
def serialize_object(o):
21+
if _is_estimator(o) or _is_transformer(o):
22+
rval = serialize_model(o)
23+
elif isinstance(o, (list, tuple)):
24+
rval = [serialize_object(element) for element in o]
25+
if isinstance(o, tuple):
26+
rval = tuple(rval)
27+
elif o is None:
28+
rval = None
29+
elif isinstance(o, six.string_types):
30+
rval = o
31+
elif isinstance(o, int):
32+
rval = o
33+
elif isinstance(o, float):
34+
rval = o
35+
elif isinstance(o, dict):
36+
rval = {}
37+
for key, value in o.items():
38+
if not isinstance(key, six.string_types):
39+
raise TypeError('Can only use string as keys, you passed '
40+
'type %s for value %s.' % (type(key), str(key)))
41+
key = serialize_object(key)
42+
value = serialize_object(value)
43+
rval[key] = value
44+
elif isinstance(o, type):
45+
rval = serialize_type(o)
46+
elif isinstance(o, scipy.stats.distributions.rv_frozen):
47+
rval = serialize_rv_frozen(o)
48+
# This only works for user-defined functions (and not even partial).
49+
# I think this exactly we want here as there shouldn't be any built-in or
50+
# functool.partials in a pipeline
51+
elif inspect.isfunction(o):
52+
rval = serialize_function(o)
53+
elif _is_cross_validator(o):
54+
rval = serialize_cross_validator(o)
55+
else:
56+
raise TypeError(o)
57+
58+
assert o is None or rval is not None
59+
60+
return rval
61+
62+
63+
# TODO maybe remove those functions and put the check to the long
64+
# if-constructs above?
65+
def _is_estimator(o):
66+
# TODO @amueller should one rather check whether this is a subclass of
67+
# BaseEstimator?
68+
#return (hasattr(o, 'fit') and hasattr(o, 'predict') and
69+
# hasattr(o, 'get_params') and hasattr(o, 'set_params'))
70+
return isinstance(o, sklearn.base.BaseEstimator)
71+
72+
73+
def _is_transformer(o):
74+
# TODO @amueller should one rather check whether this is a subclass of
75+
# BaseTransformer?
76+
return (hasattr(o, 'fit') and hasattr(o, 'transform') and
77+
hasattr(o, 'get_params') and hasattr(o, 'set_params'))
78+
79+
80+
def _is_cross_validator(o):
81+
return isinstance(o, sklearn.model_selection.BaseCrossValidator)
82+
83+
84+
def deserialize_object(o):
85+
if isinstance(o, dict):
86+
if 'oml:name' in o and 'oml:description' in o:
87+
rval = deserialize_model(o)
88+
elif 'oml:serialized_object' in o:
89+
serialized_type = o['oml:serialized_object']
90+
value = o['value']
91+
if serialized_type == 'type':
92+
rval = deserialize_type(value)
93+
elif serialized_type == 'rv_frozen':
94+
rval = deserialize_rv_frozen(value)
95+
elif serialized_type == 'function':
96+
rval = deserialize_function(value)
97+
else:
98+
raise ValueError('Cannot deserialize %s' % serialized_type)
99+
else:
100+
rval = {deserialize_object(key): deserialize_object(value)
101+
for key, value in o.items()}
102+
elif isinstance(o, (list, tuple)):
103+
rval = [deserialize_object(element) for element in o]
104+
if isinstance(o, tuple):
105+
rval = tuple(rval)
106+
elif o is None:
107+
rval = None
108+
elif isinstance(o, six.string_types):
109+
rval = o
110+
elif isinstance(o, int):
111+
rval = o
112+
elif isinstance(o, float):
113+
rval = o
114+
elif isinstance(o, OpenMLFlow):
115+
rval = o.model
116+
else:
117+
raise TypeError(o)
118+
assert o is None or rval is not None
119+
120+
return rval
121+
122+
123+
def serialize_model(model):
124+
sub_components = []
125+
parameters = []
126+
127+
model_parameters = model.get_params()
128+
129+
for k, v in sorted(model_parameters.items(), key=lambda t: t[0]):
130+
rval = serialize_object(v)
131+
132+
if isinstance(rval, (list, tuple)):
133+
# Steps in a pipeline or feature union
134+
for identifier, sub_component in rval:
135+
sub_component = OrderedDict((('oml:identifier', 'step__' + identifier),
136+
('oml:flow', sub_component)))
137+
sub_components.append(sub_component)
138+
param_dict = OrderedDict()
139+
param_dict['oml:name'] = k
140+
param_dict['oml:default_value'] = rval
141+
parameters.append(param_dict)
142+
elif isinstance(rval, OpenMLFlow):
143+
# Since serialize_object can return a Flow, we need to check
144+
# whether that flow represents a hyperparameter value (or is a
145+
# flow which was created because of a pipeline or e feature union)
146+
model_parameters = signature(model.__init__)
147+
if k not in model_parameters.parameters:
148+
continue
149+
150+
# A subcomponent, for example the base model in AdaBoostClassifier
151+
identifier = rval.name
152+
sub_component = OrderedDict((('oml:identifier', identifier),
153+
('oml:flow', rval)))
154+
sub_components.append(sub_component)
155+
param_dict = OrderedDict()
156+
param_dict['oml:name'] = k
157+
param_dict['oml:default_value'] = rval
158+
parameters.append(param_dict)
159+
else:
160+
# Since Pipeline and FeatureUnion also return estimators and
161+
# transformers in the 'steps' list with get_params(), we must
162+
# add them as a component, but not as a parameter of the
163+
# flow. The next if makes sure that we only add parameters
164+
# for the first case.
165+
model_parameters = signature(model.__init__)
166+
if k not in model_parameters.parameters:
167+
continue
168+
169+
# a regular hyperparameter
170+
param_dict = OrderedDict()
171+
param_dict['oml:name'] = k
172+
if not (hasattr(rval, '__len__') and len(rval) == 0):
173+
param_dict['oml:default_value'] = rval
174+
parameters.append(param_dict)
175+
176+
name = model.__module__ + "." + model.__class__.__name__
177+
sub_components_names = ",".join(
178+
[sub_component['oml:flow'].name
179+
for sub_component in sub_components])
180+
if sub_components_names:
181+
name = '%s(%s)' % (name, sub_components_names)
182+
if len(name) > MAXIMAL_FLOW_LENGTH:
183+
raise OpenMLRestrictionViolated('Flow name must not be longer '
184+
+ 'than %d characters!' % MAXIMAL_FLOW_LENGTH)
185+
186+
flow = OpenMLFlow(model=model, description='Automatically created '
187+
'sub-component.',
188+
parameters=parameters, components=sub_components)
189+
# TODO add name to the constructor
190+
flow.name = name
191+
192+
193+
return flow
194+
195+
196+
def deserialize_model(flow):
197+
# TODO remove potential test sentinel during testing!
198+
model_name = flow.name
199+
# Remove everything after the first bracket
200+
pos = model_name.find('(')
201+
if pos >= 0:
202+
model_name = model_name[:pos]
203+
204+
parameters = flow.parameters
205+
parameter_dict = {}
206+
207+
for parameter in parameters:
208+
name = parameter['oml:name']
209+
value = parameter.get('oml:default_value', None)
210+
211+
rval = deserialize_object(value)
212+
parameter_dict[name] = rval
213+
214+
module_name = model_name.rsplit('.', 1)
215+
try:
216+
model_class = getattr(importlib.import_module(module_name[0]),
217+
module_name[1])
218+
except:
219+
warnings.warn('Cannot create model %s for flow.' % model_name)
220+
return None
221+
222+
return model_class(**parameter_dict)
223+
224+
225+
def serialize_type(o):
226+
mapping = {float: 'float',
227+
np.float: 'np.float',
228+
np.float32: 'np.float32',
229+
np.float64: 'np.float64',
230+
int: 'int',
231+
np.int: 'np.int',
232+
np.int32: 'np.int32',
233+
np.int64: 'np.int64'}
234+
return {'oml:serialized_object': 'type',
235+
'value': mapping[o]}
236+
237+
238+
def deserialize_type(o):
239+
mapping = {'float': float,
240+
'np.float': np.float,
241+
'np.float32': np.float32,
242+
'np.float64': np.float64,
243+
'int': int,
244+
'np.int': np.int,
245+
'np.int32': np.int32,
246+
'np.int64': np.int64}
247+
return mapping[o]
248+
249+
250+
def serialize_rv_frozen(o):
251+
args = o.args
252+
kwds = o.kwds
253+
a = o.a
254+
b = o.b
255+
dist = o.dist.__class__.__module__ + '.' + o.dist.__class__.__name__
256+
return {'oml:serialized_object': 'rv_frozen',
257+
'value': {'dist': dist, 'a': a, 'b': b, 'args': args, 'kwds': kwds}}
258+
259+
260+
def deserialize_rv_frozen(o):
261+
args = o['args']
262+
kwds = o['kwds']
263+
a = o['a']
264+
b = o['b']
265+
dist_name = o['dist']
266+
267+
module_name = dist_name.rsplit('.', 1)
268+
try:
269+
model_class = getattr(importlib.import_module(module_name[0]),
270+
module_name[1])
271+
except:
272+
warnings.warn('Cannot create model %s for flow.' % dist_name)
273+
return None
274+
275+
dist = scipy.stats.distributions.rv_frozen(model_class(), *args, **kwds)
276+
dist.a = a
277+
dist.b = b
278+
279+
return dist
280+
281+
282+
def serialize_function(o):
283+
name = o.__module__ + '.' + o.__name__
284+
return {'oml:serialized_object': 'function',
285+
'value': name}
286+
287+
288+
def deserialize_function(name):
289+
module_name = name.rsplit('.', 1)
290+
try:
291+
model_class = getattr(importlib.import_module(module_name[0]),
292+
module_name[1])
293+
except Exception as e:
294+
warnings.warn('Cannot load function %s due to %s.' % (name, e))
295+
return None
296+
return model_class
297+
298+
299+
# This produces a flow, thus it does not need a deserialize. It cannot be fed
300+
# to serialize_model() because cross-validators do not have get_params().
301+
def serialize_cross_validator(o):
302+
parameters = []
303+
304+
# XXX this is copied from sklearn.model_selection._split
305+
cls = o.__class__
306+
init = getattr(cls.__init__, 'deprecated_original', cls.__init__)
307+
# Ignore varargs, kw and default values and pop self
308+
init_signature = signature(init)
309+
# Consider the constructor parameters excluding 'self'
310+
if init is object.__init__:
311+
args = []
312+
else:
313+
args = sorted([p.name for p in init_signature.parameters.values()
314+
if p.name != 'self' and p.kind != p.VAR_KEYWORD])
315+
316+
for key in args:
317+
# We need deprecation warnings to always be on in order to
318+
# catch deprecated param values.
319+
# This is set in utils/__init__.py but it gets overwritten
320+
# when running under python3 somehow.
321+
warnings.simplefilter("always", DeprecationWarning)
322+
try:
323+
with warnings.catch_warnings(record=True) as w:
324+
value = getattr(o, key, None)
325+
if len(w) and w[0].category == DeprecationWarning:
326+
# if the parameter is deprecated, don't show it
327+
continue
328+
finally:
329+
warnings.filters.pop(0)
330+
331+
param_dict = OrderedDict()
332+
param_dict['oml:name'] = key
333+
if not (hasattr(value, '__len__') and len(value) == 0):
334+
param_dict['oml:default_value'] = value
335+
parameters.append(param_dict)
336+
337+
# Create a flow
338+
name = o.__module__ + "." + o.__class__.__name__
339+
340+
flow = OpenMLFlow(model=o, description='Automatically created '
341+
'sub-component.',
342+
parameters=parameters, components=[])
343+
# TODO add name to the constructor
344+
flow.name = name
345+
346+
return flow

0 commit comments

Comments
 (0)