2525import sklearn .tree
2626
2727from openml .flows import OpenMLFlow , sklearn_to_flow , flow_to_sklearn
28- from openml .flows .sklearn_converter import _format_external_version , _check_dependencies
28+ from openml .flows .sklearn_converter import _format_external_version , \
29+ _check_dependencies , _check_n_jobs
2930from openml .exceptions import PyOpenMLError
3031
3132this_directory = os .path .dirname (os .path .abspath (__file__ ))
@@ -555,3 +556,36 @@ def test_illegal_parameter_names_featureunion(self):
555556 ('OneHotEncoder' , sklearn .preprocessing .OneHotEncoder (sparse = False , handle_unknown = 'ignore' ))
556557 ]
557558 self .assertRaises (ValueError , sklearn .pipeline .FeatureUnion , transformer_list = transformer_list )
559+
560+ def test_paralizable_check (self ):
561+ # using this model should pass the test (if param distribution is legal)
562+ singlecore_bagging = sklearn .ensemble .BaggingClassifier ()
563+ # using this model should return false (if param distribution is legal)
564+ multicore_bagging = sklearn .ensemble .BaggingClassifier (n_jobs = 5 )
565+ # using this param distribution should raise an exception
566+ illegal_param_dist = {"base__n_jobs" : [- 1 , 0 , 1 ] }
567+ # using this param distribution should not raise an exception
568+ legal_param_dist = {"base__max_depth" : [2 , 3 , 4 ]}
569+
570+ legal_models = [
571+ sklearn .ensemble .RandomForestClassifier (),
572+ sklearn .ensemble .RandomForestClassifier (n_jobs = 5 ),
573+ sklearn .ensemble .RandomForestClassifier (n_jobs = - 1 ),
574+ sklearn .pipeline .Pipeline (steps = [('bag' , sklearn .ensemble .BaggingClassifier (n_jobs = 1 ))]),
575+ sklearn .pipeline .Pipeline (steps = [('bag' , sklearn .ensemble .BaggingClassifier (n_jobs = 5 ))]),
576+ sklearn .pipeline .Pipeline (steps = [('bag' , sklearn .ensemble .BaggingClassifier (n_jobs = - 1 ))]),
577+ sklearn .model_selection .GridSearchCV (singlecore_bagging , legal_param_dist ),
578+ sklearn .model_selection .GridSearchCV (multicore_bagging , legal_param_dist )
579+ ]
580+ illegal_models = [
581+ sklearn .model_selection .GridSearchCV (singlecore_bagging , illegal_param_dist ),
582+ sklearn .model_selection .GridSearchCV (multicore_bagging , illegal_param_dist )
583+ ]
584+
585+ answers = [True , False , False , True , False , False , True , False ]
586+
587+ for i in range (len (legal_models )):
588+ self .assertTrue (_check_n_jobs (legal_models [i ]) == answers [i ])
589+
590+ for i in range (len (illegal_models )):
591+ self .assertRaises (PyOpenMLError , _check_n_jobs , illegal_models [i ])
0 commit comments