Skip to content

Commit 28c57ff

Browse files
committed
ADD class_name and custom_name fields
1 parent 68dbd0f commit 28c57ff

2 files changed

Lines changed: 19 additions & 13 deletions

File tree

openml/flows/flow.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ class OpenMLFlow(object):
5252
A list of dependencies necessary to run the flow. This field should
5353
contain all libraries the flow depends on. To allow reproducibility
5454
it should also specify the exact version numbers.
55+
class_name : str
56+
The development language name of the class which is described by this
57+
flow.
58+
custom_name : str
59+
Custom name of the flow given by the owner.
5560
binary_url : str, optional
5661
Url from which the binary can be downloaded. Added by the server.
5762
Ignored when uploaded manually. Will not be used by the python API
@@ -75,7 +80,8 @@ class OpenMLFlow(object):
7580

7681
def __init__(self, name, description, model, components, parameters,
7782
parameters_meta_info, external_version, tags, language,
78-
dependencies, binary_url=None, binary_format=None,
83+
dependencies, class_name=None, custom_name=None,
84+
binary_url=None, binary_format=None,
7985
binary_md5=None, uploader=None, upload_date=None,
8086
flow_id=None, version=None):
8187
self.name = name
@@ -93,6 +99,7 @@ def __init__(self, name, description, model, components, parameters,
9399
self.components = components
94100
self.parameters = parameters
95101
self.parameters_meta_info = parameters_meta_info
102+
self.class_name = class_name
96103

97104
keys_parameters = set(parameters.keys())
98105
keys_parameters_meta_info = set(parameters_meta_info.keys())
@@ -110,6 +117,7 @@ def __init__(self, name, description, model, components, parameters,
110117
self.external_version = external_version
111118
self.uploader = uploader
112119

120+
self.custom_name = custom_name
113121
self.tags = tags if tags is not None else []
114122
self.binary_url = binary_url
115123
self.binary_format = binary_format
@@ -164,9 +172,9 @@ def _to_dict(self):
164172
if getattr(self, required) is None:
165173
raise ValueError("self.{} is required but None".format(
166174
required))
167-
for attribute in ["uploader", "name", "version", "external_version",
168-
"description", "upload_date", "language",
169-
"dependencies"]:
175+
for attribute in ["uploader", "name", "custom_name", "class_name",
176+
"version", "external_version", "description",
177+
"upload_date", "language", "dependencies"]:
170178
_add_if_nonempty(flow_dict, 'oml:{}'.format(attribute),
171179
getattr(self, attribute))
172180

@@ -248,7 +256,7 @@ def _from_dict(cls, xml_dict):
248256
# non-mandatory parts in the xml file
249257
for key in ['uploader', 'description', 'upload_date', 'language',
250258
'dependencies', 'version', 'binary_url', 'binary_format',
251-
'binary_md5']:
259+
'binary_md5', 'class_name', 'custom_name']:
252260
arguments[key] = dic.get("oml:" + key)
253261

254262
# has to be converted to an int if present and cannot parsed in the

openml/flows/sklearn_converter.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def _serialize_model(model):
244244
# Create a flow name, which contains all components in brackets, for
245245
# example RandomizedSearchCV(Pipeline(StandardScaler,AdaBoostClassifier(DecisionTreeClassifier)),StandardScaler,AdaBoostClassifier(DecisionTreeClassifier))
246246
# TODO the name above is apparently wrong, I need to test and check this
247-
name = model.__module__ + "." + model.__class__.__name__
247+
class_name = model.__module__ + "." + model.__class__.__name__
248248

249249
# will be part of the name (in brackets)
250250
sub_components_names = ""
@@ -256,10 +256,13 @@ def _serialize_model(model):
256256

257257
if sub_components_names:
258258
# slice operation on string in order to get rid of leading comma
259-
name = '%s(%s)' % (name, sub_components_names[1:])
259+
name = '%s(%s)' % (class_name, sub_components_names[1:])
260+
else:
261+
name = class_name
260262

261263
external_version = _get_external_version_info()
262264
flow = OpenMLFlow(name=name,
265+
class_name=class_name,
263266
description='Automatically created sub-component.',
264267
model=model,
265268
components=sub_components,
@@ -276,12 +279,7 @@ def _serialize_model(model):
276279

277280
def _deserialize_model(flow, **kwargs):
278281

279-
model_name = flow.name
280-
# Remove everything after the first bracket, it is not necessary for
281-
# creating the current flow
282-
pos = model_name.find('(')
283-
if pos >= 0:
284-
model_name = model_name[:pos]
282+
model_name = flow.class_name
285283

286284
parameters = flow.parameters
287285
components = flow.components

0 commit comments

Comments
 (0)