Skip to content

Commit 8c92630

Browse files
gaugupamit-sharma
andauthored
Add additional parameters to generate_counterfactuals() function (#251)
* Add additional parameters to generate_counterfactuals() function Signed-off-by: gaugup <gaugup@microsoft.com> * added the docs for methods * Fix flake8 error Signed-off-by: Gaurav Gupta <gaugup@microsoft.com> Co-authored-by: Amit Sharma <amit_sharma@live.com>
1 parent e6a1dc3 commit 8c92630

2 files changed

Lines changed: 14 additions & 0 deletions

File tree

dice_ml/explainer_interfaces/explainer_base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ def generate_counterfactuals(self, query_instances, total_CFs,
5050
desired_class="opposite", desired_range=None,
5151
permitted_range=None, features_to_vary="all",
5252
stopping_threshold=0.5, posthoc_sparsity_param=0.1,
53+
proximity_weight=0.2, sparsity_weight=0.2, diversity_weight=5.0,
54+
categorical_penalty=0.1,
5355
posthoc_sparsity_algorithm="linear", verbose=False, **kwargs):
5456
"""General method for generating counterfactuals.
5557
@@ -65,6 +67,15 @@ def generate_counterfactuals(self, query_instances, total_CFs,
6567
If None, uses the parameters initialized in data_interface.
6668
:param features_to_vary: Either a string "all" or a list of feature names to vary.
6769
:param stopping_threshold: Minimum threshold for counterfactuals target class probability.
70+
:param proximity_weight: A positive float. Larger this weight, more close the counterfactuals are to the
71+
query_instance. Used by ['genetic', 'gradientdescent'],
72+
ignored by ['random', 'kdtree'] methods.
73+
:param sparsity_weight: A positive float. Larger this weight, less features are changed from the query_instance.
74+
Used by ['genetic', 'kdtree'], ignored by ['random', 'gradientdescent'] methods.
75+
:param diversity_weight: A positive float. Larger this weight, more diverse the counterfactuals are.
76+
Used by ['genetic', 'gradientdescent'], ignored by ['random', 'kdtree'] methods.
77+
:param categorical_penalty: A positive float. A weight to ensure that all levels of a categorical variable sums to 1.
78+
Used by ['genetic', 'gradientdescent'], ignored by ['random', 'kdtree'] methods.
6879
:param posthoc_sparsity_param: Parameter for the post-hoc operation on continuous features to enhance sparsity.
6980
:param posthoc_sparsity_algorithm: Perform either linear or binary search. Takes "linear" or "binary".
7081
Prefer binary search when a feature range is large (for instance,

tests/test_dice_interface/test_explainer_base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,9 @@ def test_desired_class(
365365
ans = exp.generate_counterfactuals(query_instances=sample_custom_query_2,
366366
features_to_vary='all',
367367
total_CFs=2, desired_class=desired_class,
368+
proximity_weight=0.2, sparsity_weight=0.2,
369+
diversity_weight=5.0,
370+
categorical_penalty=0.1,
368371
permitted_range=None)
369372
if method != 'kdtree':
370373
assert all(ans.cf_examples_list[0].final_cfs_df[exp.data_interface.outcome_name].values == [desired_class] * 2)

0 commit comments

Comments
 (0)