2525import sklearn .tree
2626
2727from openml .flows import OpenMLFlow , sklearn_to_flow , flow_to_sklearn
28+
2829from openml .flows .functions import assert_flows_equal
29- from openml .flows .sklearn_converter import _format_external_version , _check_dependencies
30+ from openml .flows .sklearn_converter import _format_external_version , \
31+ _check_dependencies , _check_n_jobs
3032from openml .exceptions import PyOpenMLError
3133
3234this_directory = os .path .dirname (os .path .abspath (__file__ ))
@@ -555,3 +557,36 @@ def test_illegal_parameter_names_featureunion(self):
555557 ('OneHotEncoder' , sklearn .preprocessing .OneHotEncoder (sparse = False , handle_unknown = 'ignore' ))
556558 ]
557559 self .assertRaises (ValueError , sklearn .pipeline .FeatureUnion , transformer_list = transformer_list )
560+
561+ def test_paralizable_check (self ):
562+ # using this model should pass the test (if param distribution is legal)
563+ singlecore_bagging = sklearn .ensemble .BaggingClassifier ()
564+ # using this model should return false (if param distribution is legal)
565+ multicore_bagging = sklearn .ensemble .BaggingClassifier (n_jobs = 5 )
566+ # using this param distribution should raise an exception
567+ illegal_param_dist = {"base__n_jobs" : [- 1 , 0 , 1 ] }
568+ # using this param distribution should not raise an exception
569+ legal_param_dist = {"base__max_depth" : [2 , 3 , 4 ]}
570+
571+ legal_models = [
572+ sklearn .ensemble .RandomForestClassifier (),
573+ sklearn .ensemble .RandomForestClassifier (n_jobs = 5 ),
574+ sklearn .ensemble .RandomForestClassifier (n_jobs = - 1 ),
575+ sklearn .pipeline .Pipeline (steps = [('bag' , sklearn .ensemble .BaggingClassifier (n_jobs = 1 ))]),
576+ sklearn .pipeline .Pipeline (steps = [('bag' , sklearn .ensemble .BaggingClassifier (n_jobs = 5 ))]),
577+ sklearn .pipeline .Pipeline (steps = [('bag' , sklearn .ensemble .BaggingClassifier (n_jobs = - 1 ))]),
578+ sklearn .model_selection .GridSearchCV (singlecore_bagging , legal_param_dist ),
579+ sklearn .model_selection .GridSearchCV (multicore_bagging , legal_param_dist )
580+ ]
581+ illegal_models = [
582+ sklearn .model_selection .GridSearchCV (singlecore_bagging , illegal_param_dist ),
583+ sklearn .model_selection .GridSearchCV (multicore_bagging , illegal_param_dist )
584+ ]
585+
586+ answers = [True , False , False , True , False , False , True , False ]
587+
588+ for i in range (len (legal_models )):
589+ self .assertTrue (_check_n_jobs (legal_models [i ]) == answers [i ])
590+
591+ for i in range (len (illegal_models )):
592+ self .assertRaises (PyOpenMLError , _check_n_jobs , illegal_models [i ])
0 commit comments