Skip to content

Commit eb4de88

Browse files
committed
[WIP] Add capability to serialize and de-serialize dice-ml explainers
Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>
1 parent 195b638 commit eb4de88

2 files changed

Lines changed: 28 additions & 0 deletions

File tree

dice_ml/explainer_interfaces/explainer_base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections.abc import Iterable
77

88
import numpy as np
9+
import pickle
910
import pandas as pd
1011
from sklearn.neighbors import KDTree
1112
from tqdm import tqdm
@@ -805,3 +806,17 @@ def _check_any_counterfactuals_computed(self, cf_examples_arr):
805806
if no_cf_generated:
806807
raise UserConfigValidationException(
807808
"No counterfactuals found for any of the query points! Kindly check your configuration.")
809+
810+
def serialize_explainer(self, path):
811+
"""Serialize the explainer to the file specified by path."""
812+
with open(path, "wb") as pickle_file:
813+
pickle.dump(self, pickle_file)
814+
815+
@staticmethod
816+
def deserialize_explainer(path):
817+
"""Reload the explainer into the memroy by reading the file specified by path."""
818+
deserialized_exp = None
819+
with open(path, "rb") as pickle_file:
820+
deserialized_exp = pickle.load(pickle_file)
821+
822+
return deserialized_exp

tests/test_dice_interface/test_dice_random.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import dice_ml
44
from dice_ml.counterfactual_explanations import CounterfactualExplanations
55
from dice_ml.diverse_counterfactuals import CounterfactualExamples
6+
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
67
from dice_ml.utils import helpers
78
from dice_ml.utils.exception import UserConfigValidationException
89

@@ -55,6 +56,18 @@ def test_random_counterfactual_explanations_output(self, desired_class, sample_c
5556
assert len(counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0]
5657
assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs
5758

59+
self.exp.serialize_explainer("random.pkl")
60+
new_exp = ExplainerBase.deserialize_explainer("random.pkl")
61+
62+
assert new_exp is not None
63+
counterfactual_explanations = new_exp.generate_counterfactuals(
64+
query_instances=sample_custom_query_1, desired_class=desired_class,
65+
total_CFs=total_CFs)
66+
67+
assert counterfactual_explanations is not None
68+
assert len(counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0]
69+
assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs
70+
5871
# When invalid desired_class is given
5972
@pytest.mark.parametrize("desired_class, desired_range, total_CFs, features_to_vary, permitted_range",
6073
[(7, None, 3, "all", None)])

0 commit comments

Comments
 (0)