Skip to content

Commit 68dadc8

Browse files
committed
add more validations
Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>
1 parent e1d4e52 commit 68dadc8

1 file changed

Lines changed: 22 additions & 5 deletions

File tree

dice_ml/explainer_interfaces/explainer_base.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from sklearn.neighbors import KDTree
1111
from tqdm import tqdm
1212

13-
from dice_ml.constants import ModelTypes
13+
from dice_ml.constants import ModelTypes, _PostHocSparsityTypes
1414
from dice_ml.counterfactual_explanations import CounterfactualExplanations
1515
from dice_ml.utils.exception import UserConfigValidationException
1616

@@ -57,7 +57,6 @@ def _validate_counterfactual_configuration(
5757
raise UserConfigValidationException(
5858
"The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.")
5959

60-
from dice_ml.constants import _PostHocSparsityTypes
6160
if posthoc_sparsity_algorithm not in _PostHocSparsityTypes.ALL:
6261
raise UserConfigValidationException(
6362
'The posthoc_sparsity_algorithm should be {0} and not {1}'.format(
@@ -72,6 +71,16 @@ def _validate_counterfactual_configuration(
7271
raise UserConfigValidationException('The posthoc_sparsity_param should lie between {0} and {1}'.format(
7372
str(0.0), str(1.0)))
7473

74+
if self.model is not None and self.model.model_type == ModelTypes.Classifier:
75+
if desired_range is not None:
76+
raise UserConfigValidationException(
77+
'The desired_range parameter should not be set for classification task')
78+
79+
if self.model is not None and self.model.model_type == ModelTypes.Regressor:
80+
if desired_range is None:
81+
raise UserConfigValidationException(
82+
'The desired_range parameter should be set for regression task')
83+
7584
def generate_counterfactuals(self, query_instances, total_CFs,
7685
desired_class="opposite", desired_range=None,
7786
permitted_range=None, features_to_vary="all",
@@ -104,9 +113,17 @@ def generate_counterfactuals(self, query_instances, total_CFs,
104113
:returns: A CounterfactualExplanations object that contains the list of
105114
counterfactual examples per query_instance as one of its attributes.
106115
"""
107-
if total_CFs <= 0:
108-
raise UserConfigValidationException(
109-
"The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.")
116+
self._validate_counterfactual_configuration(
117+
query_instances=query_instances,
118+
total_CFs=total_CFs,
119+
desired_class=desired_class,
120+
desired_range=desired_range,
121+
permitted_range=permitted_range, features_to_vary=features_to_vary,
122+
stopping_threshold=stopping_threshold, posthoc_sparsity_param=posthoc_sparsity_param,
123+
posthoc_sparsity_algorithm=posthoc_sparsity_algorithm, verbose=verbose,
124+
kwargs=kwargs
125+
)
126+
110127
cf_examples_arr = []
111128
query_instances_list = []
112129
if isinstance(query_instances, pd.DataFrame):

0 commit comments

Comments
 (0)