Skip to content

Commit 5e70ef4

Browse files
authored
Merge pull request #289 from interpretml/gaugup/SerializeDeserializeExplainers
Add capability to serialize and de-serialize dice-ml explainers
2 parents 6b35253 + d19a916 commit 5e70ef4

2 files changed

Lines changed: 65 additions & 0 deletions

File tree

dice_ml/explainer_interfaces/explainer_base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Subclasses implement interfaces for different ML frameworks such as TensorFlow or PyTorch.
33
All methods are in dice_ml.explainer_interfaces"""
44

5+
import pickle
56
from abc import ABC, abstractmethod
67
from collections.abc import Iterable
78

@@ -805,3 +806,17 @@ def _check_any_counterfactuals_computed(self, cf_examples_arr):
805806
if no_cf_generated:
806807
raise UserConfigValidationException(
807808
"No counterfactuals found for any of the query points! Kindly check your configuration.")
809+
810+
def serialize_explainer(self, path):
811+
"""Serialize the explainer to the file specified by path."""
812+
with open(path, "wb") as pickle_file:
813+
pickle.dump(self, pickle_file)
814+
815+
@staticmethod
816+
def deserialize_explainer(path):
817+
"""Reload the explainer into the memory by reading the file specified by path."""
818+
deserialized_exp = None
819+
with open(path, "rb") as pickle_file:
820+
deserialized_exp = pickle.load(pickle_file)
821+
822+
return deserialized_exp

tests/test_dice_interface/test_explainer_base.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,22 @@ def test_desired_class(
256256
assert all(ans.cf_examples_list[0].final_cfs_df_sparse[exp.data_interface.outcome_name].values ==
257257
[desired_class] * 2)
258258

259+
exp.serialize_explainer(method + '.pkl')
260+
new_exp = ExplainerBase.deserialize_explainer(method + '.pkl')
261+
262+
ans = new_exp.generate_counterfactuals(query_instances=sample_custom_query_2,
263+
features_to_vary='all',
264+
total_CFs=2, desired_class=desired_class,
265+
proximity_weight=0.2, sparsity_weight=0.2,
266+
diversity_weight=5.0,
267+
categorical_penalty=0.1,
268+
permitted_range=None)
269+
if method != 'kdtree':
270+
assert all(ans.cf_examples_list[0].final_cfs_df[new_exp.data_interface.outcome_name].values == [desired_class] * 2)
271+
else:
272+
assert all(ans.cf_examples_list[0].final_cfs_df_sparse[new_exp.data_interface.outcome_name].values ==
273+
[desired_class] * 2)
274+
259275
@pytest.mark.parametrize(("desired_class", "total_CFs", "permitted_range"),
260276
[(1, 1, {'Numerical': [10, 150]})])
261277
def test_permitted_range(
@@ -349,6 +365,30 @@ def test_desired_class(
349365
[desired_class] * total_CFs)
350366
assert all(i == desired_class for i in exp.cfs_preds)
351367

368+
exp.serialize_explainer(method + '.pkl')
369+
new_exp = ExplainerBase.deserialize_explainer(method + '.pkl')
370+
371+
if method != 'genetic':
372+
ans = new_exp.generate_counterfactuals(
373+
query_instances=sample_custom_query_2,
374+
total_CFs=total_CFs, desired_class=desired_class)
375+
else:
376+
ans = new_exp.generate_counterfactuals(
377+
query_instances=sample_custom_query_2,
378+
total_CFs=total_CFs, desired_class=desired_class,
379+
initialization=genetic_initialization)
380+
381+
assert ans is not None
382+
if method != 'kdtree':
383+
assert all(
384+
ans.cf_examples_list[0].final_cfs_df[
385+
new_exp.data_interface.outcome_name].values == [desired_class] * total_CFs)
386+
else:
387+
assert all(
388+
ans.cf_examples_list[0].final_cfs_df_sparse[new_exp.data_interface.outcome_name].values ==
389+
[desired_class] * total_CFs)
390+
assert all(i == desired_class for i in new_exp.cfs_preds)
391+
352392
# When no elements in the desired_class are present in the training data
353393
@pytest.mark.parametrize(("desired_class", "total_CFs"), [(100, 3), ('opposite', 3)])
354394
def test_unsupported_multiclass(
@@ -422,6 +462,16 @@ def test_numeric_categories(self, desired_range, method, create_housing_data):
422462

423463
assert cf_explanation is not None
424464

465+
exp.serialize_explainer("explainer.pkl")
466+
new_exp = ExplainerBase.deserialize_explainer("explainer.pkl")
467+
468+
cf_explanation = new_exp.generate_counterfactuals(
469+
query_instances=x_test.iloc[0:1],
470+
total_CFs=10,
471+
desired_range=desired_range)
472+
473+
assert cf_explanation is not None
474+
425475

426476
class TestExplainerBase:
427477

0 commit comments

Comments
 (0)