Skip to content

Commit 6b5a521

Browse files
authored
Merge pull request #296 from interpretml/gaugup/AddEmptyFeaturesToVaryListValidations
Raise user config validation exception when features_to_vary list is empty
2 parents 5e70ef4 + 84e8763 commit 6b5a521

2 files changed

Lines changed: 13 additions & 2 deletions

File tree

dice_ml/explainer_interfaces/explainer_base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ def _validate_counterfactual_configuration(
5858
raise UserConfigValidationException(
5959
"The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.")
6060

61+
if features_to_vary != "all":
62+
if len(features_to_vary) == 0:
63+
raise UserConfigValidationException("Some features need to be varied for generating counterfactuals.")
64+
6165
if posthoc_sparsity_algorithm not in _PostHocSparsityTypes.ALL:
6266
raise UserConfigValidationException(
6367
'The posthoc_sparsity_algorithm should be {0} and not {1}'.format(

tests/test_dice_interface/test_explainer_base.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,12 @@ def test_generate_counterfactuals_user_config_validations(
547547
explainer_function(query_instances=sample_custom_query_2,
548548
total_CFs=10, desired_range=[0, 10])
549549

550+
with pytest.raises(
551+
UserConfigValidationException,
552+
match=r'Some features need to be varied for generating counterfactuals.'):
553+
explainer_function(query_instances=sample_custom_query_2,
554+
total_CFs=10, features_to_vary=[])
555+
550556
@pytest.mark.parametrize('explainer_function',
551557
['generate_counterfactuals', 'local_feature_importance',
552558
'feature_importance', 'global_feature_importance'])
@@ -572,6 +578,9 @@ def test_generate_counterfactuals_user_config_validations_regression(
572578
explainer_function(query_instances=sample_custom_query_1,
573579
total_CFs=10, desired_range=[4, 3])
574580

581+
582+
@pytest.mark.parametrize("method", ['random', 'genetic', 'kdtree'])
583+
class TestExplainerBaseDataValidations:
575584
def test_global_feature_importance_error_conditions_with_insufficient_query_points(
576585
self, method,
577586
sample_custom_query_1,
@@ -665,5 +674,3 @@ def test_local_feature_importance_error_conditions_with_insufficient_cfs_per_que
665674
exp.local_feature_importance(
666675
query_instances=sample_custom_query_1,
667676
total_CFs=1)
668-
669-
# class TestExplainerBaseDataValidations:

0 commit comments

Comments
 (0)