Skip to content

Commit fc69958

Browse files
authored
Merge pull request #236 from interpretml/gaugup/RaiseExceptionWhenNoCfsFound
Raise user exception when no counterfactuals are computed for any query point
2 parents 5954ba1 + f10375f commit fc69958

3 files changed

Lines changed: 49 additions & 1 deletion

File tree

dice_ml/explainer_interfaces/explainer_base.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def generate_counterfactuals(self, query_instances, total_CFs,
8787
query_instances_list.append(query_instances[ix:(ix+1)])
8888
elif isinstance(query_instances, Iterable):
8989
query_instances_list = query_instances
90+
9091
for query_instance in tqdm(query_instances_list):
9192
self.data_interface.set_continuous_feature_indexes(query_instance)
9293
res = self._generate_counterfactuals(
@@ -101,6 +102,9 @@ def generate_counterfactuals(self, query_instances, total_CFs,
101102
verbose=verbose,
102103
**kwargs)
103104
cf_examples_arr.append(res)
105+
106+
self._check_any_counterfactuals_computed(cf_examples_arr=cf_examples_arr)
107+
104108
return CounterfactualExplanations(cf_examples_list=cf_examples_arr)
105109

106110
@abstractmethod
@@ -695,3 +699,15 @@ def round_to_precision(self):
695699
self.final_cfs_df[feature] = self.final_cfs_df[feature].astype(float).round(precisions[ix])
696700
if self.final_cfs_df_sparse is not None:
697701
self.final_cfs_df_sparse[feature] = self.final_cfs_df_sparse[feature].astype(float).round(precisions[ix])
702+
703+
def _check_any_counterfactuals_computed(self, cf_examples_arr):
704+
"""Check if any counterfactuals were generated for any query point."""
705+
no_cf_generated = True
706+
# Check if any counterfactuals were generated for any query point
707+
for cf_examples in cf_examples_arr:
708+
if cf_examples.final_cfs_df is not None and len(cf_examples.final_cfs_df) > 0:
709+
no_cf_generated = False
710+
break
711+
if no_cf_generated:
712+
raise UserConfigValidationException(
713+
"No counterfactuals found for any of the query points! Kindly check your configuration.")

docs/source/notebooks/DiCE_getting_started.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@
285285
"e2 = exp.generate_counterfactuals(x_test[0:1],\n",
286286
" total_CFs=2,\n",
287287
" desired_class=\"opposite\",\n",
288-
" features_to_vary=[\"age\", \"education\", \"race\"]\n",
288+
" features_to_vary=[\"education\", \"occupation\"]\n",
289289
" )\n",
290290
"e2.visualize_as_dataframe(show_only_changes=True)"
291291
]

tests/test_dice_interface/test_explainer_base.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import dice_ml
66
from dice_ml.utils.exception import UserConfigValidationException
7+
from dice_ml.diverse_counterfactuals import CounterfactualExamples
78
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
89

910

@@ -15,6 +16,37 @@ def _verify_feature_importance(self, feature_importance):
1516
for key in feature_importance:
1617
assert feature_importance[key] >= 0.0 and feature_importance[key] <= 1.0
1718

19+
def test_check_any_counterfactuals_computed(
20+
self, method,
21+
custom_public_data_interface,
22+
sklearn_binary_classification_model_interface
23+
):
24+
exp = dice_ml.Dice(
25+
custom_public_data_interface,
26+
sklearn_binary_classification_model_interface,
27+
method=method)
28+
29+
sample_custom_query = custom_public_data_interface.data_df[0:1]
30+
cf_example = CounterfactualExamples(
31+
data_interface=custom_public_data_interface,
32+
test_instance_df=sample_custom_query)
33+
cf_examples_arr = [cf_example]
34+
35+
with pytest.raises(
36+
UserConfigValidationException,
37+
match="No counterfactuals found for any of the query points! Kindly check your configuration."):
38+
exp._check_any_counterfactuals_computed(cf_examples_arr=cf_examples_arr)
39+
40+
cf_example_has_cf = CounterfactualExamples(
41+
data_interface=custom_public_data_interface,
42+
final_cfs_df=sample_custom_query,
43+
test_instance_df=sample_custom_query)
44+
cf_example_no_cf = CounterfactualExamples(
45+
data_interface=custom_public_data_interface,
46+
test_instance_df=sample_custom_query)
47+
cf_examples_arr = [cf_example_has_cf, cf_example_no_cf]
48+
exp._check_any_counterfactuals_computed(cf_examples_arr=cf_examples_arr)
49+
1850
@pytest.mark.parametrize("desired_class", [1])
1951
def test_zero_totalcfs(
2052
self, desired_class, method, sample_custom_query_1,

0 commit comments

Comments
 (0)