Skip to content

Commit ccf68a8

Browse files
authored
Merge pull request #262 from interpretml/gaugup/AddValidationsToGenerateCounterfactuals
Add validations for input parameters in generate_counterfactuals()
2 parents 8c92630 + da3ad79 commit ccf68a8

3 files changed

Lines changed: 261 additions & 124 deletions

File tree

dice_ml/constants.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,10 @@ class _SchemaVersions:
2929
CURRENT_VERSION = V2
3030

3131
ALL_VERSIONS = [V1, V2]
32+
33+
34+
class _PostHocSparsityTypes:
35+
LINEAR = 'linear'
36+
BINARY = 'binary'
37+
38+
ALL = [LINEAR, BINARY]

dice_ml/explainer_interfaces/explainer_base.py

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from sklearn.neighbors import KDTree
1111
from tqdm import tqdm
1212

13-
from dice_ml.constants import ModelTypes
13+
from dice_ml.constants import ModelTypes, _PostHocSparsityTypes
1414
from dice_ml.counterfactual_explanations import CounterfactualExplanations
1515
from dice_ml.utils.exception import UserConfigValidationException
1616

@@ -46,6 +46,41 @@ def __init__(self, data_interface, model_interface=None):
4646
# self.cont_precisions = \
4747
# [self.data_interface.get_decimal_precisions()[ix] for ix in self.encoded_continuous_feature_indexes]
4848

49+
def _validate_counterfactual_configuration(
        self, query_instances, total_CFs,
        desired_class="opposite", desired_range=None,
        permitted_range=None, features_to_vary="all",
        stopping_threshold=0.5, posthoc_sparsity_param=0.1,
        posthoc_sparsity_algorithm="linear", verbose=False, **kwargs):
    """Validate the user-supplied counterfactual generation settings.

    The signature mirrors the arguments of generate_counterfactuals() so
    callers can forward them unchanged. Only the parameters listed below
    are inspected; query_instances, desired_class, permitted_range,
    features_to_vary, verbose and any extra keyword arguments are
    accepted but not checked here.

    :param total_CFs: Number of counterfactuals per query instance; must
                      be greater than zero.
    :param desired_range: Must be None when self.model is a classifier
                          and must be set when self.model is a regressor.
    :param stopping_threshold: Must lie in the closed interval [0.0, 1.0].
    :param posthoc_sparsity_param: Either None (check skipped) or a value
                                   in the closed interval [0.0, 1.0].
    :param posthoc_sparsity_algorithm: Must be one of the names in
                                       _PostHocSparsityTypes.ALL.

    :raises UserConfigValidationException: If any of the above checks fails.
    """
    if total_CFs <= 0:
        raise UserConfigValidationException(
            "The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.")

    if posthoc_sparsity_algorithm not in _PostHocSparsityTypes.ALL:
        raise UserConfigValidationException(
            'The posthoc_sparsity_algorithm should be {0} and not {1}'.format(
                ' or '.join(_PostHocSparsityTypes.ALL), posthoc_sparsity_algorithm)
            )

    # The range checks are written as negated chained comparisons so that
    # NaN is rejected as well: `x < 0.0 or x > 1.0` is False for NaN and
    # would silently let it through.
    if not 0.0 <= stopping_threshold <= 1.0:
        raise UserConfigValidationException('The stopping_threshold should lie between {0} and {1}'.format(
            str(0.0), str(1.0)))

    if posthoc_sparsity_param is not None and not 0.0 <= posthoc_sparsity_param <= 1.0:
        raise UserConfigValidationException('The posthoc_sparsity_param should lie between {0} and {1}'.format(
            str(0.0), str(1.0)))

    # The task-specific checks need a model; without one they are skipped.
    if self.model is not None:
        model_type = self.model.model_type

        if model_type == ModelTypes.Classifier and desired_range is not None:
            raise UserConfigValidationException(
                'The desired_range parameter should not be set for classification task')

        if model_type == ModelTypes.Regressor and desired_range is None:
            raise UserConfigValidationException(
                'The desired_range parameter should be set for regression task')
83+
4984
def generate_counterfactuals(self, query_instances, total_CFs,
5085
desired_class="opposite", desired_range=None,
5186
permitted_range=None, features_to_vary="all",
@@ -89,9 +124,17 @@ def generate_counterfactuals(self, query_instances, total_CFs,
89124
:returns: A CounterfactualExplanations object that contains the list of
90125
counterfactual examples per query_instance as one of its attributes.
91126
"""
92-
if total_CFs <= 0:
93-
raise UserConfigValidationException(
94-
"The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.")
127+
self._validate_counterfactual_configuration(
128+
query_instances=query_instances,
129+
total_CFs=total_CFs,
130+
desired_class=desired_class,
131+
desired_range=desired_range,
132+
permitted_range=permitted_range, features_to_vary=features_to_vary,
133+
stopping_threshold=stopping_threshold, posthoc_sparsity_param=posthoc_sparsity_param,
134+
posthoc_sparsity_algorithm=posthoc_sparsity_algorithm, verbose=verbose,
135+
kwargs=kwargs
136+
)
137+
95138
cf_examples_arr = []
96139
query_instances_list = []
97140
if isinstance(query_instances, pd.DataFrame):
@@ -218,6 +261,16 @@ def local_feature_importance(self, query_instances, cf_examples_list=None,
218261
the list of counterfactuals per input, local feature importances per
219262
input, and the global feature importance summarized over all inputs.
220263
"""
264+
self._validate_counterfactual_configuration(
265+
query_instances=query_instances,
266+
total_CFs=total_CFs,
267+
desired_class=desired_class,
268+
desired_range=desired_range,
269+
permitted_range=permitted_range, features_to_vary=features_to_vary,
270+
stopping_threshold=stopping_threshold, posthoc_sparsity_param=posthoc_sparsity_param,
271+
posthoc_sparsity_algorithm=posthoc_sparsity_algorithm,
272+
kwargs=kwargs
273+
)
221274
if cf_examples_list is not None:
222275
if any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]):
223276
raise UserConfigValidationException(
@@ -267,6 +320,16 @@ def global_feature_importance(self, query_instances, cf_examples_list=None,
267320
the list of counterfactuals per input, local feature importances per
268321
input, and the global feature importance summarized over all inputs.
269322
"""
323+
self._validate_counterfactual_configuration(
324+
query_instances=query_instances,
325+
total_CFs=total_CFs,
326+
desired_class=desired_class,
327+
desired_range=desired_range,
328+
permitted_range=permitted_range, features_to_vary=features_to_vary,
329+
stopping_threshold=stopping_threshold, posthoc_sparsity_param=posthoc_sparsity_param,
330+
posthoc_sparsity_algorithm=posthoc_sparsity_algorithm,
331+
kwargs=kwargs
332+
)
270333
if query_instances is not None and len(query_instances) < 10:
271334
raise UserConfigValidationException(
272335
"The number of query instances should be greater than or equal to 10 "
@@ -323,6 +386,16 @@ def feature_importance(self, query_instances, cf_examples_list=None,
323386
the list of counterfactuals per input, local feature importances per
324387
input, and the global feature importance summarized over all inputs.
325388
"""
389+
self._validate_counterfactual_configuration(
390+
query_instances=query_instances,
391+
total_CFs=total_CFs,
392+
desired_class=desired_class,
393+
desired_range=desired_range,
394+
permitted_range=permitted_range, features_to_vary=features_to_vary,
395+
stopping_threshold=stopping_threshold, posthoc_sparsity_param=posthoc_sparsity_param,
396+
posthoc_sparsity_algorithm=posthoc_sparsity_algorithm,
397+
kwargs=kwargs
398+
)
326399
if cf_examples_list is None:
327400
cf_examples_list = self.generate_counterfactuals(
328401
query_instances, total_CFs,

0 commit comments

Comments
 (0)