44
55import dice_ml
66from dice_ml .utils .exception import UserConfigValidationException
7+ from dice_ml .diverse_counterfactuals import CounterfactualExamples
78from dice_ml .explainer_interfaces .explainer_base import ExplainerBase
89
910
@@ -15,6 +16,37 @@ def _verify_feature_importance(self, feature_importance):
1516 for key in feature_importance :
1617 assert feature_importance [key ] >= 0.0 and feature_importance [key ] <= 1.0
1718
19+ def test_check_any_counterfactuals_computed (
20+ self , method ,
21+ custom_public_data_interface ,
22+ sklearn_binary_classification_model_interface
23+ ):
24+ exp = dice_ml .Dice (
25+ custom_public_data_interface ,
26+ sklearn_binary_classification_model_interface ,
27+ method = method )
28+
29+ sample_custom_query = custom_public_data_interface .data_df [0 :1 ]
30+ cf_example = CounterfactualExamples (
31+ data_interface = custom_public_data_interface ,
32+ test_instance_df = sample_custom_query )
33+ cf_examples_arr = [cf_example ]
34+
35+ with pytest .raises (
36+ UserConfigValidationException ,
37+ match = "No counterfactuals found for any of the query points! Kindly check your configuration." ):
38+ exp ._check_any_counterfactuals_computed (cf_examples_arr = cf_examples_arr )
39+
40+ cf_example_has_cf = CounterfactualExamples (
41+ data_interface = custom_public_data_interface ,
42+ final_cfs_df = sample_custom_query ,
43+ test_instance_df = sample_custom_query )
44+ cf_example_no_cf = CounterfactualExamples (
45+ data_interface = custom_public_data_interface ,
46+ test_instance_df = sample_custom_query )
47+ cf_examples_arr = [cf_example_has_cf , cf_example_no_cf ]
48+ exp ._check_any_counterfactuals_computed (cf_examples_arr = cf_examples_arr )
49+
1850 @pytest .mark .parametrize ("desired_class" , [1 ])
1951 def test_zero_totalcfs (
2052 self , desired_class , method , sample_custom_query_1 ,
0 commit comments