@@ -47,23 +47,6 @@ def test_check_any_counterfactuals_computed(
4747 cf_examples_arr = [cf_example_has_cf , cf_example_no_cf ]
4848 exp ._check_any_counterfactuals_computed (cf_examples_arr = cf_examples_arr )
4949
50- @pytest .mark .parametrize ("desired_class" , [1 ])
51- def test_zero_totalcfs (
52- self , desired_class , method , sample_custom_query_1 ,
53- custom_public_data_interface ,
54- sklearn_binary_classification_model_interface
55- ):
56- exp = dice_ml .Dice (
57- custom_public_data_interface ,
58- sklearn_binary_classification_model_interface ,
59- method = method )
60-
61- with pytest .raises (UserConfigValidationException ):
62- exp .generate_counterfactuals (
63- query_instances = [sample_custom_query_1 ],
64- total_CFs = 0 ,
65- desired_class = desired_class )
66-
6750 @pytest .mark .parametrize ("desired_class" , [1 ])
6851 def test_local_feature_importance (
6952 self , desired_class , method ,
@@ -128,109 +111,6 @@ def test_global_feature_importance(
128111
129112 self ._verify_feature_importance (global_importance .summary_importance )
130113
131- @pytest .mark .parametrize ("desired_class" , [1 ])
132- def test_global_feature_importance_error_conditions_with_insufficient_query_points (
133- self , desired_class , method ,
134- sample_custom_query_1 ,
135- custom_public_data_interface ,
136- sklearn_binary_classification_model_interface ):
137- exp = dice_ml .Dice (
138- custom_public_data_interface ,
139- sklearn_binary_classification_model_interface ,
140- method = method )
141-
142- cf_explanations = exp .generate_counterfactuals (
143- query_instances = sample_custom_query_1 ,
144- total_CFs = 15 ,
145- desired_class = desired_class )
146-
147- with pytest .raises (
148- UserConfigValidationException ,
149- match = "The number of points for which counterfactuals generated should be "
150- "greater than or equal to 10 "
151- "to compute global feature importance" ):
152- exp .global_feature_importance (
153- query_instances = None ,
154- cf_examples_list = cf_explanations .cf_examples_list )
155-
156- with pytest .raises (
157- UserConfigValidationException ,
158- match = "The number of query instances should be greater than or equal to 10 "
159- "to compute global feature importance over all query points" ):
160- exp .global_feature_importance (
161- query_instances = sample_custom_query_1 ,
162- total_CFs = 15 ,
163- desired_class = desired_class )
164-
165- @pytest .mark .parametrize ("desired_class" , [1 ])
166- def test_global_feature_importance_error_conditions_with_insufficient_cfs_per_query_point (
167- self , desired_class , method ,
168- sample_custom_query_10 ,
169- custom_public_data_interface ,
170- sklearn_binary_classification_model_interface ):
171- exp = dice_ml .Dice (
172- custom_public_data_interface ,
173- sklearn_binary_classification_model_interface ,
174- method = method )
175-
176- cf_explanations = exp .generate_counterfactuals (
177- query_instances = sample_custom_query_10 ,
178- total_CFs = 1 ,
179- desired_class = desired_class )
180-
181- with pytest .raises (
182- UserConfigValidationException ,
183- match = "The number of counterfactuals generated per query instance should be "
184- "greater than or equal to 10 "
185- "to compute global feature importance over all query points" ):
186- exp .global_feature_importance (
187- query_instances = None ,
188- cf_examples_list = cf_explanations .cf_examples_list )
189-
190- with pytest .raises (
191- UserConfigValidationException ,
192- match = "The number of counterfactuals requested per query instance should be greater "
193- "than or equal to 10 "
194- "to compute global feature importance over all query points" ):
195- exp .global_feature_importance (
196- query_instances = sample_custom_query_10 ,
197- total_CFs = 1 ,
198- desired_class = desired_class )
199-
200- @pytest .mark .parametrize ("desired_class" , [1 ])
201- def test_local_feature_importance_error_conditions_with_insufficient_cfs_per_query_point (
202- self , desired_class , method ,
203- sample_custom_query_1 ,
204- custom_public_data_interface ,
205- sklearn_binary_classification_model_interface ):
206- exp = dice_ml .Dice (
207- custom_public_data_interface ,
208- sklearn_binary_classification_model_interface ,
209- method = method )
210-
211- cf_explanations = exp .generate_counterfactuals (
212- query_instances = sample_custom_query_1 ,
213- total_CFs = 1 ,
214- desired_class = desired_class )
215-
216- with pytest .raises (
217- UserConfigValidationException ,
218- match = "The number of counterfactuals generated per query instance should be "
219- "greater than or equal to 10 to compute feature importance for all query points" ):
220- exp .local_feature_importance (
221- query_instances = None ,
222- cf_examples_list = cf_explanations .cf_examples_list )
223-
224- with pytest .raises (
225- UserConfigValidationException ,
226- match = "The number of counterfactuals requested per "
227- "query instance should be greater than or equal to 10 "
228- "to compute feature importance for all query points" ):
229- exp .local_feature_importance (
230- query_instances = sample_custom_query_1 ,
231- total_CFs = 1 ,
232- desired_class = desired_class )
233-
234114 # @pytest.mark.parametrize("desired_class, binary_classification_exp_object_out_of_order",
235115 # [(1, 'random'), (1, 'genetic'), (1, 'kdtree')],
236116 # indirect=['binary_classification_exp_object_out_of_order'])
@@ -545,3 +425,180 @@ class TestExplainerBase:
545425 def test_instantiating_explainer_base (self , public_data_object ):
546426 with pytest .raises (TypeError ):
547427 ExplainerBase (data_interface = public_data_object )
428+
429+
430+ @pytest .mark .parametrize ("method" , ['random' , 'genetic' , 'kdtree' ])
431+ class TestExplainerBaseUserConfigValidations :
432+
433+ @pytest .mark .parametrize ('explainer_function' ,
434+ ['generate_counterfactuals' , 'local_feature_importance' ,
435+ 'feature_importance' , 'global_feature_importance' ])
436+ def test_generate_counterfactuals_user_config_validations (
437+ self , method , sample_custom_query_2 ,
438+ custom_public_data_interface ,
439+ sklearn_binary_classification_model_interface ,
440+ explainer_function ):
441+ exp = dice_ml .Dice (
442+ custom_public_data_interface ,
443+ sklearn_binary_classification_model_interface ,
444+ method = method )
445+
446+ explainer_function = getattr (exp , explainer_function )
447+ with pytest .raises (
448+ UserConfigValidationException ,
449+ match = r"The number of counterfactuals generated per query instance \(total_CFs\) "
450+ "should be a positive integer." ):
451+ explainer_function (query_instances = sample_custom_query_2 ,
452+ total_CFs = - 10 , desired_class = 'opposite' )
453+
454+ with pytest .raises (
455+ UserConfigValidationException ,
456+ match = r"The number of counterfactuals generated per query instance \(total_CFs\) "
457+ "should be a positive integer." ):
458+ explainer_function (
459+ query_instances = sample_custom_query_2 ,
460+ total_CFs = 0 ,
461+ desired_class = "opposite" )
462+
463+ with pytest .raises (
464+ UserConfigValidationException ,
465+ match = r"The posthoc_sparsity_algorithm should be linear or binary and not random" ):
466+ explainer_function (query_instances = sample_custom_query_2 ,
467+ total_CFs = 10 ,
468+ posthoc_sparsity_algorithm = 'random' )
469+
470+ with pytest .raises (
471+ UserConfigValidationException ,
472+ match = r"The posthoc_sparsity_algorithm should be linear or binary and not random" ):
473+ explainer_function (query_instances = sample_custom_query_2 ,
474+ total_CFs = 10 ,
475+ posthoc_sparsity_algorithm = 'random' )
476+
477+ with pytest .raises (
478+ UserConfigValidationException ,
479+ match = r'The stopping_threshold should lie between 0.0 and 1.0' ):
480+ explainer_function (query_instances = sample_custom_query_2 ,
481+ total_CFs = 10 ,
482+ stopping_threshold = - 10.0 )
483+
484+ with pytest .raises (
485+ UserConfigValidationException ,
486+ match = r'The posthoc_sparsity_param should lie between 0.0 and 1.0' ):
487+ explainer_function (query_instances = sample_custom_query_2 ,
488+ total_CFs = 10 ,
489+ posthoc_sparsity_param = - 10.0 )
490+
491+ with pytest .raises (
492+ UserConfigValidationException ,
493+ match = r'The desired_range parameter should not be set for classification task' ):
494+ explainer_function (query_instances = sample_custom_query_2 ,
495+ total_CFs = 10 , desired_range = [0 , 10 ])
496+
497+ @pytest .mark .parametrize ('explainer_function' ,
498+ ['generate_counterfactuals' , 'local_feature_importance' ,
499+ 'feature_importance' , 'global_feature_importance' ])
500+ def test_generate_counterfactuals_user_config_validations_regression (
501+ self , regression_exp_object , sample_custom_query_1 ,
502+ method , explainer_function ):
503+ explainer_function = getattr (regression_exp_object , explainer_function )
504+ with pytest .raises (
505+ UserConfigValidationException ,
506+ match = r'The desired_range parameter should be set for regression task' ):
507+ explainer_function (query_instances = sample_custom_query_1 ,
508+ total_CFs = 10 )
509+
510+ def test_global_feature_importance_error_conditions_with_insufficient_query_points (
511+ self , method ,
512+ sample_custom_query_1 ,
513+ custom_public_data_interface ,
514+ sklearn_binary_classification_model_interface ):
515+ exp = dice_ml .Dice (
516+ custom_public_data_interface ,
517+ sklearn_binary_classification_model_interface ,
518+ method = method )
519+
520+ cf_explanations = exp .generate_counterfactuals (
521+ query_instances = sample_custom_query_1 ,
522+ total_CFs = 15 )
523+
524+ with pytest .raises (
525+ UserConfigValidationException ,
526+ match = "The number of points for which counterfactuals generated should be "
527+ "greater than or equal to 10 "
528+ "to compute global feature importance" ):
529+ exp .global_feature_importance (
530+ query_instances = None ,
531+ cf_examples_list = cf_explanations .cf_examples_list )
532+
533+ with pytest .raises (
534+ UserConfigValidationException ,
535+ match = "The number of query instances should be greater than or equal to 10 "
536+ "to compute global feature importance over all query points" ):
537+ exp .global_feature_importance (
538+ query_instances = sample_custom_query_1 ,
539+ total_CFs = 15 )
540+
541+ def test_global_feature_importance_error_conditions_with_insufficient_cfs_per_query_point (
542+ self , method ,
543+ sample_custom_query_10 ,
544+ custom_public_data_interface ,
545+ sklearn_binary_classification_model_interface ):
546+ exp = dice_ml .Dice (
547+ custom_public_data_interface ,
548+ sklearn_binary_classification_model_interface ,
549+ method = method )
550+
551+ cf_explanations = exp .generate_counterfactuals (
552+ query_instances = sample_custom_query_10 ,
553+ total_CFs = 1 )
554+
555+ with pytest .raises (
556+ UserConfigValidationException ,
557+ match = "The number of counterfactuals generated per query instance should be "
558+ "greater than or equal to 10 "
559+ "to compute global feature importance over all query points" ):
560+ exp .global_feature_importance (
561+ query_instances = None ,
562+ cf_examples_list = cf_explanations .cf_examples_list )
563+
564+ with pytest .raises (
565+ UserConfigValidationException ,
566+ match = "The number of counterfactuals requested per query instance should be greater "
567+ "than or equal to 10 "
568+ "to compute global feature importance over all query points" ):
569+ exp .global_feature_importance (
570+ query_instances = sample_custom_query_10 ,
571+ total_CFs = 1 )
572+
573+ def test_local_feature_importance_error_conditions_with_insufficient_cfs_per_query_point (
574+ self , method ,
575+ sample_custom_query_1 ,
576+ custom_public_data_interface ,
577+ sklearn_binary_classification_model_interface ):
578+ exp = dice_ml .Dice (
579+ custom_public_data_interface ,
580+ sklearn_binary_classification_model_interface ,
581+ method = method )
582+
583+ cf_explanations = exp .generate_counterfactuals (
584+ query_instances = sample_custom_query_1 ,
585+ total_CFs = 1 )
586+
587+ with pytest .raises (
588+ UserConfigValidationException ,
589+ match = "The number of counterfactuals generated per query instance should be "
590+ "greater than or equal to 10 to compute feature importance for all query points" ):
591+ exp .local_feature_importance (
592+ query_instances = None ,
593+ cf_examples_list = cf_explanations .cf_examples_list )
594+
595+ with pytest .raises (
596+ UserConfigValidationException ,
597+ match = "The number of counterfactuals requested per "
598+ "query instance should be greater than or equal to 10 "
599+ "to compute feature importance for all query points" ):
600+ exp .local_feature_importance (
601+ query_instances = sample_custom_query_1 ,
602+ total_CFs = 1 )
603+
604+ # class TestExplainerBaseDataValidations:
0 commit comments