|
1 | 1 | from collections import OrderedDict |
2 | 2 | import xmltodict |
3 | | -import sklearn |
4 | 3 |
|
5 | 4 | from .._api_calls import _perform_api_call |
6 | 5 | from .functions import _check_flow_exists |
@@ -28,30 +27,31 @@ class OpenMLFlow(object): |
28 | 27 |
|
29 | 28 |
|
30 | 29 | """ |
31 | | - def __init__(self, model, flow_id=None, uploader=None, |
32 | | - description=None, creator=None, components=None, |
33 | | - parameters=None, contributor=None, tag=None): |
34 | | - self.flow_id = flow_id |
35 | | - self.upoader = uploader |
| 30 | + def __init__(self, name, description=None, model=None, components=None, |
| 31 | + parameters=None, external_version=None, creator=None, |
| 32 | + uploader=None, tag=None, flow_id=None): |
| 33 | + self.name = name |
36 | 34 | self.description = description |
37 | | - self.creator = creator |
38 | | - self.tag = tag |
39 | 35 | self.model = model |
40 | 36 |
|
41 | | - # TODO update these - the sklearn transformation class should be able |
42 | | - # to do this! |
43 | | - self.source = "FIXME DEFINE PYTHON FLOW" |
44 | | - self.name = (model.__module__ + "." + |
45 | | - model.__class__.__name__) |
46 | | - self.external_version = 'sklearn_' + sklearn.__version__ |
47 | | - |
48 | 37 | if components is None: |
49 | | - components = [] |
| 38 | + components = OrderedDict() |
| 39 | + elif not isinstance(components, OrderedDict): |
| 40 | + raise TypeError('Components must be of type OrderedDict, but are %s.' % |
| 41 | + type(components)) |
50 | 42 | self.components = components |
51 | 43 | if parameters is None: |
52 | | - parameters = [] |
| 44 | + parameters = OrderedDict() |
| 45 | + elif not isinstance(parameters, OrderedDict): |
| 46 | + raise TypeError('Parameters must be of type OrderedDict, but are %s.' % |
| 47 | + type(parameters)) |
53 | 48 | self.parameters = parameters |
54 | 49 |
|
| 50 | + self.external_version = external_version |
| 51 | + self.creator = creator |
| 52 | + self.upoader = uploader |
| 53 | + self.tag = tag |
| 54 | + self.flow_id = flow_id |
55 | 55 |
|
56 | 56 | def _generate_flow_xml(self): |
57 | 57 | """Generate xml representation of self for upload to server. |
@@ -133,3 +133,12 @@ def _get_name(self): |
133 | 133 | return self.name |
134 | 134 |
|
135 | 135 |
|
| 136 | +def create_flow_from_model(model, converter, description=None): |
| 137 | + flow = converter.serialize_object(model) |
| 138 | + if not isinstance(flow, OpenMLFlow): |
| 139 | + raise ValueError('Converter %s did return %s, not OpenMLFlow!' % |
| 140 | + (str(converter), type(flow))) |
| 141 | + if description is not None: |
| 142 | + flow.description = description |
| 143 | + |
| 144 | + return flow |
0 commit comments