@@ -96,6 +96,109 @@ def test_global_feature_importance(
9696
9797 self ._verify_feature_importance (global_importance .summary_importance )
9898
99+ @pytest .mark .parametrize ("desired_class" , [1 ])
100+ def test_global_feature_importance_error_conditions_with_insufficient_query_points (
101+ self , desired_class , method ,
102+ sample_custom_query_1 ,
103+ custom_public_data_interface ,
104+ sklearn_binary_classification_model_interface ):
105+ exp = dice_ml .Dice (
106+ custom_public_data_interface ,
107+ sklearn_binary_classification_model_interface ,
108+ method = method )
109+
110+ cf_explanations = exp .generate_counterfactuals (
111+ query_instances = sample_custom_query_1 ,
112+ total_CFs = 15 ,
113+ desired_class = desired_class )
114+
115+ with pytest .raises (
116+ UserConfigValidationException ,
117+ match = "The number of points for which counterfactuals generated should be "
118+ "greater than or equal to 10 "
119+ "to compute global feature importance" ):
120+ exp .global_feature_importance (
121+ query_instances = None ,
122+ cf_examples_list = cf_explanations .cf_examples_list )
123+
124+ with pytest .raises (
125+ UserConfigValidationException ,
126+ match = "The number of query instances should be greater than or equal to 10 "
127+ "to compute global feature importance over all query points" ):
128+ exp .global_feature_importance (
129+ query_instances = sample_custom_query_1 ,
130+ total_CFs = 15 ,
131+ desired_class = desired_class )
132+
133+ @pytest .mark .parametrize ("desired_class" , [1 ])
134+ def test_global_feature_importance_error_conditions_with_insufficient_cfs_per_query_point (
135+ self , desired_class , method ,
136+ sample_custom_query_10 ,
137+ custom_public_data_interface ,
138+ sklearn_binary_classification_model_interface ):
139+ exp = dice_ml .Dice (
140+ custom_public_data_interface ,
141+ sklearn_binary_classification_model_interface ,
142+ method = method )
143+
144+ cf_explanations = exp .generate_counterfactuals (
145+ query_instances = sample_custom_query_10 ,
146+ total_CFs = 1 ,
147+ desired_class = desired_class )
148+
149+ with pytest .raises (
150+ UserConfigValidationException ,
151+ match = "The number of counterfactuals generated per query instance should be "
152+ "greater than or equal to 10 "
153+ "to compute global feature importance over all query points" ):
154+ exp .global_feature_importance (
155+ query_instances = None ,
156+ cf_examples_list = cf_explanations .cf_examples_list )
157+
158+ with pytest .raises (
159+ UserConfigValidationException ,
160+ match = "The number of counterfactuals requested per query instance should be greater "
161+ "than or equal to 10 "
162+ "to compute global feature importance over all query points" ):
163+ exp .global_feature_importance (
164+ query_instances = sample_custom_query_10 ,
165+ total_CFs = 1 ,
166+ desired_class = desired_class )
167+
168+ @pytest .mark .parametrize ("desired_class" , [1 ])
169+ def test_local_feature_importance_error_conditions_with_insufficient_cfs_per_query_point (
170+ self , desired_class , method ,
171+ sample_custom_query_1 ,
172+ custom_public_data_interface ,
173+ sklearn_binary_classification_model_interface ):
174+ exp = dice_ml .Dice (
175+ custom_public_data_interface ,
176+ sklearn_binary_classification_model_interface ,
177+ method = method )
178+
179+ cf_explanations = exp .generate_counterfactuals (
180+ query_instances = sample_custom_query_1 ,
181+ total_CFs = 1 ,
182+ desired_class = desired_class )
183+
184+ with pytest .raises (
185+ UserConfigValidationException ,
186+ match = "The number of counterfactuals generated per query instance should be "
187+ "greater than or equal to 10 to compute feature importance for all query points" ):
188+ exp .local_feature_importance (
189+ query_instances = None ,
190+ cf_examples_list = cf_explanations .cf_examples_list )
191+
192+ with pytest .raises (
193+ UserConfigValidationException ,
194+ match = "The number of counterfactuals requested per "
195+ "query instance should be greater than or equal to 10 "
196+ "to compute feature importance for all query points" ):
197+ exp .local_feature_importance (
198+ query_instances = sample_custom_query_1 ,
199+ total_CFs = 1 ,
200+ desired_class = desired_class )
201+
99202 # @pytest.mark.parametrize("desired_class, binary_classification_exp_object_out_of_order",
100203 # [(1, 'random'), (1, 'genetic'), (1, 'kdtree')],
101204 # indirect=['binary_classification_exp_object_out_of_order'])
0 commit comments