Skip to content

Commit e42ca51

Browse files
committed
FIX issue #77
1 parent 9518a13 commit e42ca51

6 files changed

Lines changed: 164 additions & 17 deletions

File tree

openml/flows/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .flow import OpenMLFlow
22
from .sklearn_converter import sklearn_to_flow, flow_to_sklearn
3-
from .functions import get_flow
3+
from .functions import get_flow, list_flows
44

5-
__all__ = ['OpenMLFlow', 'create_flow_from_model', 'get_flow',
5+
__all__ = ['OpenMLFlow', 'create_flow_from_model', 'get_flow', 'list_flows',
66
'sklearn_to_flow', 'flow_to_sklearn']

openml/flows/functions.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def get_flow(flow_id):
1212
flow_id : int
1313
The OpenML flow id.
1414
"""
15+
# TODO add caching here!
1516
try:
1617
flow_id = int(flow_id)
1718
except:
@@ -25,4 +26,69 @@ def get_flow(flow_id):
2526
if 'sklearn' in flow.external_version:
2627
flow.model = flow_to_sklearn(flow)
2728

28-
return flow
29+
return flow
30+
31+
32+
def list_flows(offset=None, size=None, tag=None):
    """Return a listing of all flows registered on OpenML.

    Builds the ``flow/list`` REST path, appending optional
    ``offset``/``limit``/``tag`` segments, and delegates parsing of the
    server answer to ``_list_datasets``.

    Parameters
    ----------
    offset : int, optional
        the number of flows to skip, starting from the first
    size : int, optional
        the maximum number of flows to return
    tag : str, optional
        the tag to include

    Returns
    -------
    flows : dict
        A mapping from flow_id to a dict giving a brief overview of the
        respective flow.

        Every flow is represented by a dictionary containing
        the following information:
        - flow id
        - full name
        - name
        - version
        - external version
        - uploader
    """
    segments = ["flow/list"]
    if offset is not None:
        segments.append("offset/%d" % int(offset))
    if size is not None:
        segments.append("limit/%d" % int(size))
    if tag is not None:
        segments.append("tag/%s" % tag)
    return _list_datasets("/".join(segments))
70+
71+
72+
def _list_datasets(api_call):
    """Perform a listing API call and parse the XML answer into a dict.

    NOTE: the name is a leftover from the dataset module this code was
    copied from -- this helper actually parses *flows*. It is kept as-is
    because ``list_flows`` calls it by this name.

    Parameters
    ----------
    api_call : str
        REST path such as ``flow/list`` with optional filter segments.

    Returns
    -------
    flows : dict
        Mapping from flow id (int) to a dict with keys ``id``,
        ``full_name``, ``name``, ``version``, ``external_version`` and
        ``uploader``.

    Raises
    ------
    ValueError
        If the XML answer does not have the expected structure.
    """
    # TODO add proper error handling here!
    return_code, xml_string = _perform_api_call(api_call)
    flows_dict = xmltodict.parse(xml_string)

    # Minimalistic check if the XML is useful. Use explicit raises instead
    # of assert so the checks are not stripped under `python -O`.
    try:
        flow_list = flows_dict['oml:flows']['oml:flow']
    except KeyError:
        raise ValueError('XML answer does not contain "oml:flows"/"oml:flow"')
    if flows_dict['oml:flows'].get('@xmlns:oml') != 'http://openml.org/openml':
        raise ValueError('XML answer has wrong namespace: %s'
                         % flows_dict['oml:flows'].get('@xmlns:oml'))
    # xmltodict returns a single dict (not a list) when exactly one flow
    # matches the query; normalize so the loop below works in both cases.
    if not isinstance(flow_list, list):
        flow_list = [flow_list]

    flows = dict()
    for flow_ in flow_list:
        fid = int(flow_['oml:id'])
        flows[fid] = {'id': fid,
                      'full_name': flow_['oml:full_name'],
                      'name': flow_['oml:name'],
                      'version': flow_['oml:version'],
                      # NOTE(review): some flows on openml.org have an empty
                      # external version; .get keeps those as None.
                      'external_version': flow_.get('oml:external_version'),
                      'uploader': flow_['oml:uploader']}

    return flows

openml/flows/sklearn_converter.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
"""Convert scikit-learn estimators into an OpenMLFlows and vice versa."""
22

33
from collections import OrderedDict
4+
from distutils.version import LooseVersion
45
import importlib
56
import inspect
67
import json
78
import json.decoder
9+
import re
810
import six
911
import warnings
1012
import sys
@@ -25,6 +27,10 @@
2527
JSONDecodeError = ValueError
2628

2729

30+
DEPENDENCIES_PATTERN = re.compile(
31+
'^(?P<name>[\w\-]+)((?P<operation>==|>=|>)(?P<version>(\d+\.)?(\d+\.)?(\d+)))?$')
32+
33+
2834
def sklearn_to_flow(o):
2935

3036
if _is_estimator(o):
@@ -174,6 +180,10 @@ def _serialize_model(model):
174180
# Get the external versions of all sub-components
175181
external_version = _get_external_version_string(model, sub_components)
176182

183+
dependencies = [_format_external_version('sklearn', sklearn.__version__),
184+
'numpy>=1.6.1', 'scipy>=0.9']
185+
dependencies = '\n'.join(dependencies)
186+
177187
flow = OpenMLFlow(name=name,
178188
class_name=class_name,
179189
description='Automatically created sub-component.',
@@ -185,7 +195,7 @@ def _serialize_model(model):
185195
tags=[],
186196
language='English',
187197
# TODO fill in dependencies!
188-
dependencies=None)
198+
dependencies=dependencies)
189199

190200
return flow
191201

@@ -317,6 +327,7 @@ def _extract_information_from_model(model):
317327
def _deserialize_model(flow, **kwargs):
318328

319329
model_name = flow.class_name
330+
_check_dependencies(flow.dependencies)
320331

321332
parameters = flow.parameters
322333
components = flow.components
@@ -352,6 +363,33 @@ def _deserialize_model(flow, **kwargs):
352363
return model_class(**parameter_dict)
353364

354365

366+
def _check_dependencies(dependencies):
367+
dependencies = dependencies.split('\n')
368+
for dependency_string in dependencies:
369+
match = DEPENDENCIES_PATTERN.match(dependency_string)
370+
dependency_name = match.group('name')
371+
operation = match.group('operation')
372+
version = match.group('version')
373+
374+
module = importlib.import_module(dependency_name)
375+
required_version = LooseVersion(version)
376+
installed_version = LooseVersion(module.__version__)
377+
378+
if operation == '==':
379+
check = required_version == installed_version
380+
elif operation == '>':
381+
check = installed_version > required_version
382+
elif operation == '>=':
383+
check = installed_version > required_version or \
384+
installed_version == required_version
385+
else:
386+
raise NotImplementedError(
387+
'operation \'%s\' is not supported' % operation)
388+
if not check:
389+
raise ValueError('Trying to deserialize a model with dependency '
390+
'%s not satisfied.' % dependency_string)
391+
392+
355393
def serialize_type(o):
356394
mapping = {float: 'float',
357395
np.float: 'np.float',

tests/test_flows/test_flow.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,6 @@ def get_sentinel():
6969

7070
class TestFlow(TestBase):
7171

72-
@unittest.skip('The method which is tested by this function doesnt exist')
73-
def test_download_flow_list(self):
74-
def check_flow(flow):
75-
self.assertIsInstance(flow, dict)
76-
self.assertEqual(len(flow), 6)
77-
78-
flows = openml.flows.get_flow_list()
79-
self.assertGreaterEqual(len(flows), 1448)
80-
for flow in flows:
81-
check_flow(flow)
82-
8372
def test_get_flow(self):
8473
# We need to use the production server here because 4024 is not the test
8574
# server
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import unittest
2+
3+
import openml
4+
from openml.util import is_string
5+
6+
7+
class TestFlowFunctions(unittest.TestCase):
    """Tests for openml.flows.list_flows against the live server."""

    def _check_flow(self, flow):
        # A listing entry is a plain dict with exactly these six keys.
        self.assertEqual(type(flow), dict)
        self.assertEqual(len(flow), 6)
        self.assertIsInstance(flow['id'], int)
        self.assertTrue(is_string(flow['name']))
        self.assertTrue(is_string(flow['full_name']))
        self.assertTrue(is_string(flow['version']))
        # There are some runs on openml.org that can have an empty external
        # version
        self.assertTrue(is_string(flow['external_version']) or
                        flow['external_version'] is None)

    def test_list_flows(self):
        # We can only perform a smoke test here because we test on dynamic
        # data from the internet...
        flows = openml.flows.list_flows()
        # 3000 as a lower bound on the number of flows on openml.org
        self.assertGreaterEqual(len(flows), 3000)
        for fid in flows:
            self._check_flow(flows[fid])

    def test_list_flows_by_tag(self):
        flows = openml.flows.list_flows(tag='weka')
        self.assertGreaterEqual(len(flows), 5)
        for fid in flows:
            self._check_flow(flows[fid])

    def test_list_flows_paginate(self):
        size = 10
        maximum = 100
        for offset in range(0, maximum, size):
            flows = openml.flows.list_flows(offset=offset, size=size)
            # A page may be short at the end, but never longer than `size`.
            self.assertGreaterEqual(size, len(flows))
            for fid in flows:
                self._check_flow(flows[fid])

tests/test_flows/test_sklearn.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import sys
55
import unittest
6+
import unittest.mock
67

78
import numpy as np
89
import scipy.optimize
@@ -19,8 +20,9 @@
1920
import sklearn.preprocessing
2021
import sklearn.tree
2122

23+
import openml
2224
from openml.flows import OpenMLFlow, sklearn_to_flow, flow_to_sklearn
23-
from openml.flows.sklearn_converter import _format_external_version
25+
from openml.flows.sklearn_converter import _format_external_version, _check_dependencies
2426

2527
this_directory = os.path.dirname(os.path.abspath(__file__))
2628
sys.path.append(this_directory)
@@ -47,13 +49,15 @@ def setUp(self):
4749
self.X = iris.data
4850
self.y = iris.target
4951

50-
def test_serialize_model(self):
52+
@unittest.mock.patch('openml.flows.sklearn_converter._check_dependencies')
53+
def test_serialize_model(self, check_dependencies_mock):
5154
model = sklearn.tree.DecisionTreeClassifier(criterion='entropy',
5255
max_features='auto',
5356
max_leaf_nodes=2000)
5457

5558
fixture_name = 'sklearn.tree.tree.DecisionTreeClassifier'
5659
fixture_description = 'Automatically created sub-component.'
60+
version_fixture = 'sklearn==%s\nnumpy>=1.6.1\nscipy>=0.9' % sklearn.__version__
5761
fixture_parameters = \
5862
OrderedDict((('class_weight', 'null'),
5963
('criterion', '"entropy"'),
@@ -74,6 +78,7 @@ def test_serialize_model(self):
7478
self.assertEqual(serialization.class_name, fixture_name)
7579
self.assertEqual(serialization.description, fixture_description)
7680
self.assertEqual(serialization.parameters, fixture_parameters)
81+
self.assertEqual(serialization.dependencies, version_fixture)
7782

7883
new_model = flow_to_sklearn(serialization)
7984

@@ -83,6 +88,8 @@ def test_serialize_model(self):
8388
self.assertEqual(new_model.get_params(), model.get_params())
8489
new_model.fit(self.X, self.y)
8590

91+
self.assertEqual(check_dependencies_mock.call_count, 1)
92+
8693
def test_serialize_model_with_subcomponent(self):
8794
model = sklearn.ensemble.AdaBoostClassifier(
8895
n_estimators=100, base_estimator=sklearn.tree.DecisionTreeClassifier())
@@ -508,3 +515,8 @@ def test_subflow_version_propagated(self):
508515
self.assertEqual(flow.external_version, '%s,%s' % (
509516
_format_external_version('sklearn', sklearn.__version__),
510517
_format_external_version('tests', '0.1')))
518+
519+
def test_check_dependencies(self):
520+
dependencies = ['sklearn==0.1', 'sklearn>=99.99.99', 'sklearn>99.99.99']
521+
for dependency in dependencies:
522+
self.assertRaises(ValueError, _check_dependencies, dependency)

0 commit comments

Comments
 (0)