|
10 | 10 | from sklearn.neighbors import KDTree |
11 | 11 | from tqdm import tqdm |
12 | 12 |
|
13 | | -from dice_ml.constants import ModelTypes |
| 13 | +from dice_ml.constants import ModelTypes, _PostHocSparsityTypes |
14 | 14 | from dice_ml.counterfactual_explanations import CounterfactualExplanations |
15 | 15 | from dice_ml.utils.exception import UserConfigValidationException |
16 | 16 |
|
@@ -46,6 +46,41 @@ def __init__(self, data_interface, model_interface=None): |
46 | 46 | # self.cont_precisions = \ |
47 | 47 | # [self.data_interface.get_decimal_precisions()[ix] for ix in self.encoded_continuous_feature_indexes] |
48 | 48 |
|
| 49 | + def _validate_counterfactual_configuration( |
| 50 | + self, query_instances, total_CFs, |
| 51 | + desired_class="opposite", desired_range=None, |
| 52 | + permitted_range=None, features_to_vary="all", |
| 53 | + stopping_threshold=0.5, posthoc_sparsity_param=0.1, |
| 54 | + posthoc_sparsity_algorithm="linear", verbose=False, **kwargs): |
| 55 | + |
| 56 | + if total_CFs <= 0: |
| 57 | + raise UserConfigValidationException( |
| 58 | + "The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.") |
| 59 | + |
| 60 | + if posthoc_sparsity_algorithm not in _PostHocSparsityTypes.ALL: |
| 61 | + raise UserConfigValidationException( |
| 62 | + 'The posthoc_sparsity_algorithm should be {0} and not {1}'.format( |
| 63 | + ' or '.join(_PostHocSparsityTypes.ALL), posthoc_sparsity_algorithm) |
| 64 | + ) |
| 65 | + |
| 66 | + if stopping_threshold < 0.0 or stopping_threshold > 1.0: |
| 67 | + raise UserConfigValidationException('The stopping_threshold should lie between {0} and {1}'.format( |
| 68 | + str(0.0), str(1.0))) |
| 69 | + |
| 70 | + if posthoc_sparsity_param is not None and (posthoc_sparsity_param < 0.0 or posthoc_sparsity_param > 1.0): |
| 71 | + raise UserConfigValidationException('The posthoc_sparsity_param should lie between {0} and {1}'.format( |
| 72 | + str(0.0), str(1.0))) |
| 73 | + |
| 74 | + if self.model is not None and self.model.model_type == ModelTypes.Classifier: |
| 75 | + if desired_range is not None: |
| 76 | + raise UserConfigValidationException( |
| 77 | + 'The desired_range parameter should not be set for classification task') |
| 78 | + |
| 79 | + if self.model is not None and self.model.model_type == ModelTypes.Regressor: |
| 80 | + if desired_range is None: |
| 81 | + raise UserConfigValidationException( |
| 82 | + 'The desired_range parameter should be set for regression task') |
| 83 | + |
49 | 84 | def generate_counterfactuals(self, query_instances, total_CFs, |
50 | 85 | desired_class="opposite", desired_range=None, |
51 | 86 | permitted_range=None, features_to_vary="all", |
@@ -89,9 +124,17 @@ def generate_counterfactuals(self, query_instances, total_CFs, |
89 | 124 | :returns: A CounterfactualExplanations object that contains the list of |
90 | 125 | counterfactual examples per query_instance as one of its attributes. |
91 | 126 | """ |
92 | | - if total_CFs <= 0: |
93 | | - raise UserConfigValidationException( |
94 | | - "The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.") |
| 127 | + self._validate_counterfactual_configuration( |
| 128 | + query_instances=query_instances, |
| 129 | + total_CFs=total_CFs, |
| 130 | + desired_class=desired_class, |
| 131 | + desired_range=desired_range, |
| 132 | + permitted_range=permitted_range, features_to_vary=features_to_vary, |
| 133 | + stopping_threshold=stopping_threshold, posthoc_sparsity_param=posthoc_sparsity_param, |
| 134 | + posthoc_sparsity_algorithm=posthoc_sparsity_algorithm, verbose=verbose, |
| 135 | + kwargs=kwargs |
| 136 | + ) |
| 137 | + |
95 | 138 | cf_examples_arr = [] |
96 | 139 | query_instances_list = [] |
97 | 140 | if isinstance(query_instances, pd.DataFrame): |
@@ -218,6 +261,16 @@ def local_feature_importance(self, query_instances, cf_examples_list=None, |
218 | 261 | the list of counterfactuals per input, local feature importances per |
219 | 262 | input, and the global feature importance summarized over all inputs. |
220 | 263 | """ |
| 264 | + self._validate_counterfactual_configuration( |
| 265 | + query_instances=query_instances, |
| 266 | + total_CFs=total_CFs, |
| 267 | + desired_class=desired_class, |
| 268 | + desired_range=desired_range, |
| 269 | + permitted_range=permitted_range, features_to_vary=features_to_vary, |
| 270 | + stopping_threshold=stopping_threshold, posthoc_sparsity_param=posthoc_sparsity_param, |
| 271 | + posthoc_sparsity_algorithm=posthoc_sparsity_algorithm, |
| 272 | + kwargs=kwargs |
| 273 | + ) |
221 | 274 | if cf_examples_list is not None: |
222 | 275 | if any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]): |
223 | 276 | raise UserConfigValidationException( |
@@ -267,6 +320,16 @@ def global_feature_importance(self, query_instances, cf_examples_list=None, |
267 | 320 | the list of counterfactuals per input, local feature importances per |
268 | 321 | input, and the global feature importance summarized over all inputs. |
269 | 322 | """ |
| 323 | + self._validate_counterfactual_configuration( |
| 324 | + query_instances=query_instances, |
| 325 | + total_CFs=total_CFs, |
| 326 | + desired_class=desired_class, |
| 327 | + desired_range=desired_range, |
| 328 | + permitted_range=permitted_range, features_to_vary=features_to_vary, |
| 329 | + stopping_threshold=stopping_threshold, posthoc_sparsity_param=posthoc_sparsity_param, |
| 330 | + posthoc_sparsity_algorithm=posthoc_sparsity_algorithm, |
| 331 | + kwargs=kwargs |
| 332 | + ) |
270 | 333 | if query_instances is not None and len(query_instances) < 10: |
271 | 334 | raise UserConfigValidationException( |
272 | 335 | "The number of query instances should be greater than or equal to 10 " |
@@ -323,6 +386,16 @@ def feature_importance(self, query_instances, cf_examples_list=None, |
323 | 386 | the list of counterfactuals per input, local feature importances per |
324 | 387 | input, and the global feature importance summarized over all inputs. |
325 | 388 | """ |
| 389 | + self._validate_counterfactual_configuration( |
| 390 | + query_instances=query_instances, |
| 391 | + total_CFs=total_CFs, |
| 392 | + desired_class=desired_class, |
| 393 | + desired_range=desired_range, |
| 394 | + permitted_range=permitted_range, features_to_vary=features_to_vary, |
| 395 | + stopping_threshold=stopping_threshold, posthoc_sparsity_param=posthoc_sparsity_param, |
| 396 | + posthoc_sparsity_algorithm=posthoc_sparsity_algorithm, |
| 397 | + kwargs=kwargs |
| 398 | + ) |
326 | 399 | if cf_examples_list is None: |
327 | 400 | cf_examples_list = self.generate_counterfactuals( |
328 | 401 | query_instances, total_CFs, |
|
0 commit comments