1010from sklearn .neighbors import KDTree
1111from tqdm import tqdm
1212
13- from dice_ml .constants import ModelTypes
13+ from dice_ml .constants import ModelTypes , _PostHocSparsityTypes
1414from dice_ml .counterfactual_explanations import CounterfactualExplanations
1515from dice_ml .utils .exception import UserConfigValidationException
1616
@@ -57,7 +57,6 @@ def _validate_counterfactual_configuration(
5757 raise UserConfigValidationException (
5858 "The number of counterfactuals generated per query instance (total_CFs) should be a positive integer." )
5959
60- from dice_ml .constants import _PostHocSparsityTypes
6160 if posthoc_sparsity_algorithm not in _PostHocSparsityTypes .ALL :
6261 raise UserConfigValidationException (
6362 'The posthoc_sparsity_algorithm should be {0} and not {1}' .format (
@@ -72,6 +71,16 @@ def _validate_counterfactual_configuration(
7271 raise UserConfigValidationException ('The posthoc_sparsity_param should lie between {0} and {1}' .format (
7372 str (0.0 ), str (1.0 )))
7473
74+ if self .model is not None and self .model .model_type == ModelTypes .Classifier :
75+ if desired_range is not None :
76+ raise UserConfigValidationException (
77+ 'The desired_range parameter should not be set for classification task' )
78+
79+ if self .model is not None and self .model .model_type == ModelTypes .Regressor :
80+ if desired_range is None :
81+ raise UserConfigValidationException (
82+ 'The desired_range parameter should be set for regression task' )
83+
7584 def generate_counterfactuals (self , query_instances , total_CFs ,
7685 desired_class = "opposite" , desired_range = None ,
7786 permitted_range = None , features_to_vary = "all" ,
@@ -104,9 +113,17 @@ def generate_counterfactuals(self, query_instances, total_CFs,
104113 :returns: A CounterfactualExplanations object that contains the list of
105114 counterfactual examples per query_instance as one of its attributes.
106115 """
107- if total_CFs <= 0 :
108- raise UserConfigValidationException (
109- "The number of counterfactuals generated per query instance (total_CFs) should be a positive integer." )
116+ self ._validate_counterfactual_configuration (
117+ query_instances = query_instances ,
118+ total_CFs = total_CFs ,
119+ desired_class = desired_class ,
120+ desired_range = desired_range ,
121+ permitted_range = permitted_range , features_to_vary = features_to_vary ,
122+ stopping_threshold = stopping_threshold , posthoc_sparsity_param = posthoc_sparsity_param ,
123+ posthoc_sparsity_algorithm = posthoc_sparsity_algorithm , verbose = verbose ,
124+ kwargs = kwargs
125+ )
126+
110127 cf_examples_arr = []
111128 query_instances_list = []
112129 if isinstance (query_instances , pd .DataFrame ):
0 commit comments