Skip to content

Commit 9ea27bd

Browse files
authored
Merge pull request #271 from interpretml/gaugup/AddMoreDesiredRangeValidations
Add configuration validation for desired_range
2 parents 67b5a66 + 7f4014f commit 9ea27bd

3 files changed

Lines changed: 22 additions & 4 deletions

File tree

dice_ml/explainer_interfaces/dice_KD.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,6 @@ def _generate_counterfactuals(self, query_instance, total_CFs, desired_range=Non
8989

9090
query_instance[self.data_interface.outcome_name] = test_pred
9191
desired_class = self.misc_init(stopping_threshold, desired_class, desired_range, test_pred)
92-
if desired_range is not None:
93-
if desired_range[0] > desired_range[1]:
94-
raise ValueError("Invalid Range!")
9592

9693
if desired_class == "opposite" and self.model.model_type == ModelTypes.Classifier:
9794
if self.num_output_nodes == 2:

dice_ml/explainer_interfaces/explainer_base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,14 @@ def _validate_counterfactual_configuration(
8181
raise UserConfigValidationException(
8282
'The desired_range parameter should be set for regression task')
8383

84+
if desired_range is not None:
85+
if len(desired_range) != 2:
86+
raise UserConfigValidationException(
87+
"The parameter desired_range needs to have two numbers in ascending order.")
88+
if desired_range[0] > desired_range[1]:
89+
raise UserConfigValidationException(
90+
"The range provided in desired_range should be in ascending order.")
91+
8492
def generate_counterfactuals(self, query_instances, total_CFs,
8593
desired_class="opposite", desired_range=None,
8694
permitted_range=None, features_to_vary="all",
@@ -96,7 +104,8 @@ def generate_counterfactuals(self, query_instances, total_CFs,
96104
:param desired_class: Desired counterfactual class - can take 0 or 1. Default value
97105
is "opposite" to the outcome class of query_instance for binary classification.
98106
:param desired_range: For regression problems. Contains the outcome range to
99-
generate counterfactuals in.
107+
generate counterfactuals in. This should be a list of two numbers in
108+
ascending order.
100109
:param permitted_range: Dictionary with feature names as keys and permitted range in list as values.
101110
Defaults to the range inferred from training data.
102111
If None, uses the parameters initialized in data_interface.

tests/test_dice_interface/test_explainer_base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,18 @@ def test_generate_counterfactuals_user_config_validations_regression(
510510
explainer_function(query_instances=sample_custom_query_1,
511511
total_CFs=10)
512512

513+
with pytest.raises(
514+
UserConfigValidationException,
515+
match=r'The parameter desired_range needs to have two numbers in ascending order.'):
516+
explainer_function(query_instances=sample_custom_query_1,
517+
total_CFs=10, desired_range=[1, 3, 4])
518+
519+
with pytest.raises(
520+
UserConfigValidationException,
521+
match=r'The range provided in desired_range should be in ascending order.'):
522+
explainer_function(query_instances=sample_custom_query_1,
523+
total_CFs=10, desired_range=[4, 3])
524+
513525
def test_global_feature_importance_error_conditions_with_insufficient_query_points(
514526
self, method,
515527
sample_custom_query_1,

0 commit comments

Comments
 (0)