@@ -18,6 +18,17 @@ def KD_binary_classification_exp_object():
1818 return exp
1919
2020
21+ @pytest .fixture ()
22+ def KD_binary_vars_classification_exp_object (load_custom_vars_testing_dataset ):
23+ backend = 'sklearn'
24+ dataset = load_custom_vars_testing_dataset
25+ d = dice_ml .Data (dataframe = dataset , continuous_features = ['Numerical' ], outcome_name = 'Outcome' )
26+ ML_modelpath = helpers .get_custom_vars_dataset_modelpath_pipeline ()
27+ m = dice_ml .Model (model_path = ML_modelpath , backend = backend )
28+ exp = dice_ml .Dice (d , m , method = 'kdtree' )
29+ return exp
30+
31+
2132@pytest .fixture ()
2233def KD_multi_classification_exp_object ():
2334 backend = 'sklearn'
@@ -194,3 +205,44 @@ def test_KD_tree_counterfactual_explanations_output(self, desired_range, sample_
194205 def test_zero_cfs (self , desired_class , desired_range , sample_custom_query_4 , total_CFs ):
195206 self .exp_regr ._generate_counterfactuals (query_instance = sample_custom_query_4 , total_CFs = total_CFs ,
196207 desired_range = desired_range )
208+
209+
210+ class TestDiceKDBinaryVarsClassificationMethods :
211+ @pytest .fixture (autouse = True )
212+ def _initiate_exp_object (self , KD_binary_vars_classification_exp_object ):
213+ self .exp = KD_binary_vars_classification_exp_object # explainer object
214+ self .data_df_copy = self .exp .data_interface .data_df .copy ()
215+
216+ # When a query's feature value is not within the permitted range and the feature is not allowed to vary
217+ @pytest .mark .parametrize (("desired_range" , "desired_class" , "total_CFs" , "features_to_vary" , "permitted_range" ),
218+ [(None , 0 , 4 , ['Numerical' ], {'CategoricalNum' : ['1' , '2' ]})])
219+ def test_invalid_query_instance (self , desired_range , desired_class , sample_custom_vars_query_1 , total_CFs ,
220+ features_to_vary , permitted_range ):
221+ self .exp .dataset_with_predictions , self .exp .KD_tree , self .exp .predictions = \
222+ self .exp .build_KD_tree (self .data_df_copy , desired_range , desired_class , self .exp .predicted_outcome_name )
223+
224+ with pytest .raises (ValueError , match = "is outside the permitted range and isn't allowed to vary" ):
225+ self .exp ._generate_counterfactuals (query_instance = sample_custom_vars_query_1 , total_CFs = total_CFs ,
226+ features_to_vary = features_to_vary , permitted_range = permitted_range )
227+
228+ # Verifying the output of the KD tree
229+ @pytest .mark .parametrize (("desired_class" , "total_CFs" ), [(0 , 1 )])
230+ @pytest .mark .parametrize ('posthoc_sparsity_algorithm' , ['linear' , 'binary' , None ])
231+ def test_KD_tree_output (self , desired_class , sample_custom_vars_query_1 , total_CFs , posthoc_sparsity_algorithm ):
232+ self .exp ._generate_counterfactuals (query_instance = sample_custom_vars_query_1 , desired_class = desired_class ,
233+ total_CFs = total_CFs ,
234+ posthoc_sparsity_algorithm = posthoc_sparsity_algorithm )
235+ self .exp .final_cfs_df .Numerical = self .exp .final_cfs_df .Numerical .astype (int )
236+ expected_output = self .exp .data_interface .data_df
237+
238+ assert all (self .exp .final_cfs_df .Numerical == expected_output .Numerical [0 ])
239+ assert all (self .exp .final_cfs_df .Categorical == expected_output .Categorical [0 ])
240+
241+ # Verifying the output of the KD tree
242+ @pytest .mark .parametrize (("desired_class" , "total_CFs" ), [(0 , 1 )])
243+ def test_KD_tree_counterfactual_explanations_output (self , desired_class , sample_custom_vars_query_1 , total_CFs ):
244+ counterfactual_explanations = self .exp .generate_counterfactuals (
245+ query_instances = sample_custom_vars_query_1 , desired_class = desired_class ,
246+ total_CFs = total_CFs )
247+
248+ assert counterfactual_explanations is not None
0 commit comments