Skip to content

Commit 5954ba1

Browse files
authored
Merge pull request #234 from interpretml/gaugup/ImproveFeatureImpErrorMesssages
Improve error messages in feature importance functions
2 parents 27f735c + 0ddb13b commit 5954ba1

2 files changed

Lines changed: 122 additions & 8 deletions

File tree

dice_ml/explainer_interfaces/explainer_base.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,12 @@ def local_feature_importance(self, query_instances, cf_examples_list=None,
206206
if any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]):
207207
raise UserConfigValidationException(
208208
"The number of counterfactuals generated per query instance should be "
209-
"greater than or equal to 10")
209+
"greater than or equal to 10 to compute feature importance for all query points")
210210
elif total_CFs < 10:
211-
raise UserConfigValidationException("The number of counterfactuals generated per "
212-
"query instance should be greater than or equal to 10")
211+
raise UserConfigValidationException(
212+
"The number of counterfactuals requested per "
213+
"query instance should be greater than or equal to 10 "
214+
"to compute feature importance for all query points")
213215
importances = self.feature_importance(
214216
query_instances,
215217
cf_examples_list=cf_examples_list,
@@ -250,16 +252,25 @@ def global_feature_importance(self, query_instances, cf_examples_list=None,
250252
input, and the global feature importance summarized over all inputs.
251253
"""
252254
if query_instances is not None and len(query_instances) < 10:
253-
raise UserConfigValidationException("The number of query instances should be greater than or equal to 10")
255+
raise UserConfigValidationException(
256+
"The number of query instances should be greater than or equal to 10 "
257+
"to compute global feature importance over all query points")
254258
if cf_examples_list is not None:
255-
if any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]):
259+
if len(cf_examples_list) < 10:
260+
raise UserConfigValidationException(
261+
"The number of points for which counterfactuals generated should be "
262+
"greater than or equal to 10 "
263+
"to compute global feature importance")
264+
elif any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]):
256265
raise UserConfigValidationException(
257266
"The number of counterfactuals generated per query instance should be "
258-
"greater than or equal to 10")
267+
"greater than or equal to 10 "
268+
"to compute global feature importance over all query points")
259269
elif total_CFs < 10:
260270
raise UserConfigValidationException(
261-
"The number of counterfactuals generated per query instance should be greater "
262-
"than or equal to 10")
271+
"The number of counterfactuals requested per query instance should be greater "
272+
"than or equal to 10 "
273+
"to compute global feature importance over all query points")
263274
importances = self.feature_importance(
264275
query_instances,
265276
cf_examples_list=cf_examples_list,

tests/test_dice_interface/test_explainer_base.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,109 @@ def test_global_feature_importance(
9696

9797
self._verify_feature_importance(global_importance.summary_importance)
9898

99+
@pytest.mark.parametrize("desired_class", [1])
100+
def test_global_feature_importance_error_conditions_with_insufficient_query_points(
101+
self, desired_class, method,
102+
sample_custom_query_1,
103+
custom_public_data_interface,
104+
sklearn_binary_classification_model_interface):
105+
exp = dice_ml.Dice(
106+
custom_public_data_interface,
107+
sklearn_binary_classification_model_interface,
108+
method=method)
109+
110+
cf_explanations = exp.generate_counterfactuals(
111+
query_instances=sample_custom_query_1,
112+
total_CFs=15,
113+
desired_class=desired_class)
114+
115+
with pytest.raises(
116+
UserConfigValidationException,
117+
match="The number of points for which counterfactuals generated should be "
118+
"greater than or equal to 10 "
119+
"to compute global feature importance"):
120+
exp.global_feature_importance(
121+
query_instances=None,
122+
cf_examples_list=cf_explanations.cf_examples_list)
123+
124+
with pytest.raises(
125+
UserConfigValidationException,
126+
match="The number of query instances should be greater than or equal to 10 "
127+
"to compute global feature importance over all query points"):
128+
exp.global_feature_importance(
129+
query_instances=sample_custom_query_1,
130+
total_CFs=15,
131+
desired_class=desired_class)
132+
133+
@pytest.mark.parametrize("desired_class", [1])
134+
def test_global_feature_importance_error_conditions_with_insufficient_cfs_per_query_point(
135+
self, desired_class, method,
136+
sample_custom_query_10,
137+
custom_public_data_interface,
138+
sklearn_binary_classification_model_interface):
139+
exp = dice_ml.Dice(
140+
custom_public_data_interface,
141+
sklearn_binary_classification_model_interface,
142+
method=method)
143+
144+
cf_explanations = exp.generate_counterfactuals(
145+
query_instances=sample_custom_query_10,
146+
total_CFs=1,
147+
desired_class=desired_class)
148+
149+
with pytest.raises(
150+
UserConfigValidationException,
151+
match="The number of counterfactuals generated per query instance should be "
152+
"greater than or equal to 10 "
153+
"to compute global feature importance over all query points"):
154+
exp.global_feature_importance(
155+
query_instances=None,
156+
cf_examples_list=cf_explanations.cf_examples_list)
157+
158+
with pytest.raises(
159+
UserConfigValidationException,
160+
match="The number of counterfactuals requested per query instance should be greater "
161+
"than or equal to 10 "
162+
"to compute global feature importance over all query points"):
163+
exp.global_feature_importance(
164+
query_instances=sample_custom_query_10,
165+
total_CFs=1,
166+
desired_class=desired_class)
167+
168+
@pytest.mark.parametrize("desired_class", [1])
169+
def test_local_feature_importance_error_conditions_with_insufficient_cfs_per_query_point(
170+
self, desired_class, method,
171+
sample_custom_query_1,
172+
custom_public_data_interface,
173+
sklearn_binary_classification_model_interface):
174+
exp = dice_ml.Dice(
175+
custom_public_data_interface,
176+
sklearn_binary_classification_model_interface,
177+
method=method)
178+
179+
cf_explanations = exp.generate_counterfactuals(
180+
query_instances=sample_custom_query_1,
181+
total_CFs=1,
182+
desired_class=desired_class)
183+
184+
with pytest.raises(
185+
UserConfigValidationException,
186+
match="The number of counterfactuals generated per query instance should be "
187+
"greater than or equal to 10 to compute feature importance for all query points"):
188+
exp.local_feature_importance(
189+
query_instances=None,
190+
cf_examples_list=cf_explanations.cf_examples_list)
191+
192+
with pytest.raises(
193+
UserConfigValidationException,
194+
match="The number of counterfactuals requested per "
195+
"query instance should be greater than or equal to 10 "
196+
"to compute feature importance for all query points"):
197+
exp.local_feature_importance(
198+
query_instances=sample_custom_query_1,
199+
total_CFs=1,
200+
desired_class=desired_class)
201+
99202
# @pytest.mark.parametrize("desired_class, binary_classification_exp_object_out_of_order",
100203
# [(1, 'random'), (1, 'genetic'), (1, 'kdtree')],
101204
# indirect=['binary_classification_exp_object_out_of_order'])

0 commit comments

Comments
 (0)