33import unittest
44
55import numpy as np
6+ import sklearn .base
67import sklearn .datasets
78import scipy .stats
89import sklearn .decomposition
1819from openml .flows import OpenMLFlow
1920
2021
class Model(sklearn.base.BaseEstimator):
    """Minimal estimator used as a fixture by the serialization tests.

    The constructor stores its three arguments unchanged under attributes
    of the same name; no validation or type conversion is performed.
    """

    def __init__(self, boolean, integer, floating_point_value):
        # Keep every value exactly as passed in, so a test can compare the
        # parameters before and after a serialize/deserialize round trip.
        self.boolean = boolean
        self.integer = integer
        self.floating_point_value = floating_point_value
27+
28+
2129class TestSklearn (unittest .TestCase ):
2230
2331 def setUp (self ):
@@ -34,18 +42,18 @@ def test_serialize_model(self):
3442 fixture_name = 'sklearn.tree.tree.DecisionTreeClassifier'
3543 fixture_description = 'Automatically created sub-component.'
3644 fixture_parameters = \
37- OrderedDict ((('class_weight' , None ),
38- ('criterion' , 'entropy' ),
39- ('max_depth' , None ),
40- ('max_features' , 'auto' ),
45+ OrderedDict ((('class_weight' , 'null' ),
46+ ('criterion' , '" entropy" ' ),
47+ ('max_depth' , 'null' ),
48+ ('max_features' , '" auto" ' ),
4149 ('max_leaf_nodes' , '2000' ),
4250 ('min_impurity_split' , '1e-07' ),
4351 ('min_samples_leaf' , '1' ),
4452 ('min_samples_split' , '2' ),
4553 ('min_weight_fraction_leaf' , '0.0' ),
4654 ('presort' , 'false' ),
47- ('random_state' , None ),
48- ('splitter' , 'best' )))
55+ ('random_state' , 'null' ),
56+ ('splitter' , '" best" ' )))
4957
5058 serialization = self .converter .serialize_object (model )
5159
@@ -73,7 +81,7 @@ def test_serialize_model_with_subcomponent(self):
7381
7482 self .assertEqual (serialization .name , fixture_name )
7583 self .assertEqual (serialization .description , fixture_description )
76- self .assertEqual (serialization .parameters ['algorithm' ], 'SAMME.R' )
84+ self .assertEqual (serialization .parameters ['algorithm' ], '" SAMME.R" ' )
7785 self .assertIsInstance (serialization .parameters ['base_estimator' ], str )
7886 self .assertEqual (serialization .parameters ['learning_rate' ], '1.0' )
7987 self .assertEqual (serialization .parameters ['n_estimators' ], '100' )
@@ -294,11 +302,14 @@ def test_serialize_resampling(self):
294302 self .assertIsNot (deserialized , kfold )
295303
296304 def test_hypothetical_parameter_values (self ):
297- values = ['true' , '1' , '0.1' ]
298- for value in values :
299- serialized = self .converter .serialize_object (value )
300- deserialized = self .converter .deserialize_object (value )
301- self .assertEqual (deserialized , value )
305+ # Can only be checked inside a model
306+
307+ model = Model ('true' , '1' , '0.1' )
308+
309+ serialized = self .converter .serialize_object (model )
310+ deserialized = self .converter .deserialize_object (serialized )
311+ self .assertEqual (deserialized .get_params (), model .get_params ())
312+ self .assertIsNot (deserialized , model )
302313
303314
304315
0 commit comments