|
3 | 3 | import os |
4 | 4 | import sys |
5 | 5 | import unittest |
| 6 | +import warnings |
6 | 7 |
|
7 | 8 | if sys.version_info[0] >= 3: |
8 | 9 | from unittest import mock |
|
24 | 25 | import sklearn.preprocessing |
25 | 26 | import sklearn.tree |
26 | 27 |
|
| 28 | +import openml |
27 | 29 | from openml.flows import OpenMLFlow, sklearn_to_flow, flow_to_sklearn |
28 | | - |
29 | 30 | from openml.flows.functions import assert_flows_equal |
30 | 31 | from openml.flows.sklearn_converter import _format_external_version, \ |
31 | 32 | _check_dependencies, _check_n_jobs |
@@ -63,7 +64,8 @@ def test_serialize_model(self, check_dependencies_mock): |
63 | 64 |
|
64 | 65 | fixture_name = 'sklearn.tree.tree.DecisionTreeClassifier' |
65 | 66 | fixture_description = 'Automatically created scikit-learn flow.' |
66 | | - version_fixture = 'sklearn==%s\nnumpy>=1.6.1\nscipy>=0.9' % sklearn.__version__ |
| 67 | + version_fixture = 'sklearn==%s\nnumpy>=1.6.1\nscipy>=0.9\nopenml==%s' \ |
| 68 | + '' % (sklearn.__version__, openml.__version__) |
67 | 69 | fixture_parameters = \ |
68 | 70 | OrderedDict((('class_weight', 'null'), |
69 | 71 | ('criterion', '"entropy"'), |
@@ -524,10 +526,18 @@ def test_subflow_version_propagated(self): |
524 | 526 | _format_external_version('sklearn', sklearn.__version__), |
525 | 527 | _format_external_version('tests', '0.1'))) |
526 | 528 |
|
527 | | - def test_check_dependencies(self): |
| 529 | + @mock.patch('warnings.warn') |
| 530 | + def test_check_dependencies(self, warnings_mock): |
528 | 531 | dependencies = ['sklearn==0.1', 'sklearn>=99.99.99', 'sklearn>99.99.99'] |
529 | 532 | for dependency in dependencies: |
530 | 533 | self.assertRaises(ValueError, _check_dependencies, dependency) |
| 534 | + dependency = 'openml==0.0.12345' |
| 535 | + _check_dependencies(dependency) |
| 536 | + self.assertEqual(warnings_mock.call_count, 1) |
| 537 | + self.assertEqual(warnings_mock.call_args[0][0], |
| 538 | + 'De-serializing a flow which was created with ' |
| 539 | + 'openml==%s, this is openml==%s.' % ( |
| 540 | + openml.__version__, '0.0.12345')) |
531 | 541 |
|
532 | 542 | def test_illegal_parameter_names(self): |
533 | 543 | # illegal name: estimators |
|
0 commit comments