Skip to content

Commit 53c694f

Browse files
committed
Add tests
Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>
1 parent 2acac6b commit 53c694f

2 files changed

Lines changed: 49 additions & 13 deletions

File tree

tests/test_dice_interface/test_dice_random.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import dice_ml
44
from dice_ml.counterfactual_explanations import CounterfactualExplanations
55
from dice_ml.diverse_counterfactuals import CounterfactualExamples
6-
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
76
from dice_ml.utils import helpers
87
from dice_ml.utils.exception import UserConfigValidationException
98

@@ -56,18 +55,6 @@ def test_random_counterfactual_explanations_output(self, desired_class, sample_c
5655
assert len(counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0]
5756
assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs
5857

59-
self.exp.serialize_explainer("random.pkl")
60-
new_exp = ExplainerBase.deserialize_explainer("random.pkl")
61-
62-
assert new_exp is not None
63-
counterfactual_explanations = new_exp.generate_counterfactuals(
64-
query_instances=sample_custom_query_1, desired_class=desired_class,
65-
total_CFs=total_CFs)
66-
67-
assert counterfactual_explanations is not None
68-
assert len(counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0]
69-
assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs
70-
7158
# When invalid desired_class is given
7259
@pytest.mark.parametrize(("desired_class", "desired_range", "total_CFs", "features_to_vary", "permitted_range"),
7360
[(7, None, 3, "all", None)])

tests/test_dice_interface/test_explainer_base.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,22 @@ def test_desired_class(
256256
assert all(ans.cf_examples_list[0].final_cfs_df_sparse[exp.data_interface.outcome_name].values ==
257257
[desired_class] * 2)
258258

259+
exp.serialize_explainer(method + '.pkl')
260+
new_exp = ExplainerBase.deserialize_explainer(method + '.pkl')
261+
262+
ans = new_exp.generate_counterfactuals(query_instances=sample_custom_query_2,
263+
features_to_vary='all',
264+
total_CFs=2, desired_class=desired_class,
265+
proximity_weight=0.2, sparsity_weight=0.2,
266+
diversity_weight=5.0,
267+
categorical_penalty=0.1,
268+
permitted_range=None)
269+
if method != 'kdtree':
270+
assert all(ans.cf_examples_list[0].final_cfs_df[new_exp.data_interface.outcome_name].values == [desired_class] * 2)
271+
else:
272+
assert all(ans.cf_examples_list[0].final_cfs_df_sparse[new_exp.data_interface.outcome_name].values ==
273+
[desired_class] * 2)
274+
259275
@pytest.mark.parametrize(("desired_class", "total_CFs", "permitted_range"),
260276
[(1, 1, {'Numerical': [10, 150]})])
261277
def test_permitted_range(
@@ -349,6 +365,29 @@ def test_desired_class(
349365
[desired_class] * total_CFs)
350366
assert all(i == desired_class for i in exp.cfs_preds)
351367

368+
exp.serialize_explainer(method + '.pkl')
369+
new_exp = ExplainerBase.deserialize_explainer(method + '.pkl')
370+
371+
if method != 'genetic':
372+
ans = new_exp.generate_counterfactuals(
373+
query_instances=sample_custom_query_2,
374+
total_CFs=total_CFs, desired_class=desired_class)
375+
else:
376+
ans = new_exp.generate_counterfactuals(
377+
query_instances=sample_custom_query_2,
378+
total_CFs=total_CFs, desired_class=desired_class,
379+
initialization=genetic_initialization)
380+
381+
assert ans is not None
382+
if method != 'kdtree':
383+
assert all(
384+
ans.cf_examples_list[0].final_cfs_df[new_exp.data_interface.outcome_name].values == [desired_class] * total_CFs)
385+
else:
386+
assert all(
387+
ans.cf_examples_list[0].final_cfs_df_sparse[new_exp.data_interface.outcome_name].values ==
388+
[desired_class] * total_CFs)
389+
assert all(i == desired_class for i in new_exp.cfs_preds)
390+
352391
# When no elements in the desired_class are present in the training data
353392
@pytest.mark.parametrize(("desired_class", "total_CFs"), [(100, 3), ('opposite', 3)])
354393
def test_unsupported_multiclass(
@@ -422,6 +461,16 @@ def test_numeric_categories(self, desired_range, method, create_housing_data):
422461

423462
assert cf_explanation is not None
424463

464+
exp.serialize_explainer("explainer.pkl")
465+
new_exp = ExplainerBase.deserialize_explainer("explainer.pkl")
466+
467+
cf_explanation = new_exp.generate_counterfactuals(
468+
query_instances=x_test.iloc[0:1],
469+
total_CFs=10,
470+
desired_range=desired_range)
471+
472+
assert cf_explanation is not None
473+
425474

426475
class TestExplainerBase:
427476

0 commit comments

Comments
 (0)