Skip to content

Commit 21bb3b4

Browse files
committed
Merge branch 'master' into gaugup/SerializeDeserializeExplainers
2 parents 4e8f4c2 + 6b35253 commit 21bb3b4

7 files changed

Lines changed: 135 additions & 21 deletions

File tree

dice_ml/data_interfaces/public_data_interface.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,21 +49,10 @@ def __init__(self, params):
4949
name) for name in self.categorical_feature_names if name in self.data_df]
5050

5151
self._validate_and_set_continuous_features_precision(params=params)
52-
53-
if len(self.categorical_feature_names) > 0:
54-
for feature in self.categorical_feature_names:
55-
self.data_df[feature] = self.data_df[feature].apply(str)
56-
self.data_df[self.categorical_feature_names] = self.data_df[self.categorical_feature_names].astype(
57-
'category')
58-
59-
if len(self.continuous_feature_names) > 0:
60-
for feature in self.continuous_feature_names:
61-
if self.get_data_type(feature) == 'float':
62-
self.data_df[feature] = self.data_df[feature].astype(
63-
np.float32)
64-
else:
65-
self.data_df[feature] = self.data_df[feature].astype(
66-
np.int32)
52+
self.data_df = self._set_feature_dtypes(
53+
self.data_df,
54+
self.categorical_feature_names,
55+
self.continuous_feature_names)
6756

6857
# should move the below snippet to gradient based dice interfaces
6958
# self.one_hot_encoded_data = self.one_hot_encode_data(self.data_df)
@@ -149,6 +138,25 @@ def _validate_and_set_permitted_range(self, params):
149138
)
150139
self.permitted_range, _ = self.get_features_range(input_permitted_range)
151140

141+
def _set_feature_dtypes(self, data_df, categorical_feature_names,
142+
continuous_feature_names):
143+
"""Set the correct type of each feature column."""
144+
if len(categorical_feature_names) > 0:
145+
for feature in categorical_feature_names:
146+
data_df[feature] = data_df[feature].apply(str)
147+
data_df[categorical_feature_names] = data_df[categorical_feature_names].astype(
148+
'category')
149+
150+
if len(continuous_feature_names) > 0:
151+
for feature in continuous_feature_names:
152+
if self.get_data_type(feature) == 'float':
153+
data_df[feature] = data_df[feature].astype(
154+
np.float32)
155+
else:
156+
data_df[feature] = data_df[feature].astype(
157+
np.int32)
158+
return data_df
159+
152160
def check_features_to_vary(self, features_to_vary):
153161
if features_to_vary is not None and features_to_vary != 'all':
154162
not_training_features = set(features_to_vary) - set(self.feature_names)
@@ -546,6 +554,10 @@ def prepare_query_instance(self, query_instance):
546554
raise ValueError("Query instance should be a dict, a pandas dataframe, a list, or a list of dicts")
547555

548556
test = test.reset_index(drop=True)
557+
# encode categorical and numerical columns
558+
test = self._set_feature_dtypes(test,
559+
self.categorical_feature_names,
560+
self.continuous_feature_names)
549561
return test
550562

551563
# TODO: create a new method, get_LE_min_max_normalized_data() to get label-encoded and normalized data. Keep this

dice_ml/explainer_interfaces/dice_KD.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,10 @@ def _generate_counterfactuals(self, query_instance, total_CFs, desired_range=Non
8282

8383
# Prepares user defined query_instance for DiCE.
8484
query_instance_orig = query_instance.copy()
85-
query_instance = self.data_interface.prepare_query_instance(query_instance=query_instance)
85+
query_instance_orig = self.data_interface.prepare_query_instance(
86+
query_instance=query_instance_orig)
87+
query_instance = self.data_interface.prepare_query_instance(
88+
query_instance=query_instance)
8689

8790
# find the predicted value of query_instance
8891
test_pred = self.predict_fn(query_instance)[0]
@@ -103,7 +106,6 @@ def _generate_counterfactuals(self, query_instance, total_CFs, desired_range=Non
103106
# Partitioned dataset and KD Tree for each class (binary) of the dataset
104107
self.dataset_with_predictions, self.KD_tree, self.predictions = \
105108
self.build_KD_tree(data_df_copy, desired_range, desired_class, self.predicted_outcome_name)
106-
107109
query_instance, cfs_preds = self.find_counterfactuals(data_df_copy,
108110
query_instance, query_instance_orig,
109111
desired_range,
@@ -224,7 +226,6 @@ def find_counterfactuals(self, data_df_copy, query_instance, query_instance_orig
224226
for col in pd.get_dummies(data_df_copy[self.data_interface.feature_names]).columns:
225227
if col not in query_instance_df_dummies.columns:
226228
query_instance_df_dummies[col] = 0
227-
228229
self.final_cfs, cfs_preds = self.vary_valid(query_instance_df_dummies,
229230
total_CFs,
230231
features_to_vary,

dice_ml/explainer_interfaces/dice_genetic.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,10 @@ def _generate_counterfactuals(self, query_instance, total_CFs, initialization="k
268268

269269
# Prepares user defined query_instance for DiCE.
270270
query_instance_orig = query_instance
271-
query_instance = self.data_interface.prepare_query_instance(query_instance=query_instance)
271+
query_instance_orig = self.data_interface.prepare_query_instance(
272+
query_instance=query_instance_orig)
273+
query_instance = self.data_interface.prepare_query_instance(
274+
query_instance=query_instance)
272275
query_instance = self.label_encode(query_instance)
273276
query_instance = np.array(query_instance.values[0])
274277
self.x1 = query_instance

dice_ml/utils/helpers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import numpy as np
88
import pandas as pd
99
from sklearn.model_selection import train_test_split
10-
# for data transformations
1110
from sklearn.preprocessing import FunctionTransformer
1211

1312
import dice_ml
@@ -129,6 +128,13 @@ def get_custom_dataset_modelpath_pipeline():
129128
return modelpath
130129

131130

131+
def get_custom_vars_dataset_modelpath_pipeline():
132+
pkg_path = dice_ml.__path__[0]
133+
model_ext = '.sav'
134+
modelpath = os.path.join(pkg_path, 'utils', 'sample_trained_models', 'custom_vars'+model_ext)
135+
return modelpath
136+
137+
132138
def get_custom_dataset_modelpath_pipeline_binary():
133139
pkg_path = dice_ml.__path__[0]
134140
model_ext = '.sav'
@@ -168,7 +174,6 @@ def get_base_gen_cf_initialization(data_interface, encoded_size, cont_minx, cont
168174
wm1, wm2, wm3, learning_rate):
169175
# Dice Imports - TODO: keep this method for VAE as a spearate module or move it to feasible_base_vae.py.
170176
# Check dependencies.
171-
# Pytorch
172177
from torch import optim
173178

174179
from dice_ml.utils.sample_architecture.vae_model import CF_VAE
40.9 KB
Binary file not shown.

tests/conftest.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1+
import pickle
12
from collections import OrderedDict
23

34
import pandas as pd
45
import pytest
6+
from sklearn.compose import ColumnTransformer
57
from sklearn.datasets import fetch_california_housing, load_iris
8+
from sklearn.ensemble import RandomForestClassifier
9+
from sklearn.impute import SimpleImputer
610
from sklearn.model_selection import train_test_split
11+
from sklearn.pipeline import Pipeline
12+
from sklearn.preprocessing import OneHotEncoder, StandardScaler
713

814
import dice_ml
915
from dice_ml.utils import helpers
@@ -110,6 +116,33 @@ def private_data_object():
110116
return dice_ml.Data(features=features_dict, outcome_name='income')
111117

112118

119+
@pytest.fixture()
120+
def load_custom_vars_testing_dataset():
121+
data = [['a', 0, 10, 0], ['b', 1, 10000, 0], ['c', 0, 14, 0], ['a', 2, 88, 0], ['c', 1, 14, 0]]
122+
return pd.DataFrame(data, columns=['Categorical', 'CategoricalNum', 'Numerical', 'Outcome'])
123+
124+
125+
@pytest.fixture()
126+
def _save_custom_vars_dataset_model():
127+
numeric_trans = Pipeline(steps=[('imputer', SimpleImputer(strategy='median')),
128+
('scaler', StandardScaler())])
129+
cat_trans = Pipeline(steps=[('imputer',
130+
SimpleImputer(fill_value='missing',
131+
strategy='constant')),
132+
('onehot', OneHotEncoder(handle_unknown='ignore'))])
133+
transformations = ColumnTransformer(transformers=[('num', numeric_trans,
134+
['Numerical']),
135+
('cat', cat_trans,
136+
pd.Index(['Categorical', 'CategoricalNum'], dtype='object'))])
137+
clf = Pipeline(steps=[('preprocessor', transformations),
138+
('regressor', RandomForestClassifier())])
139+
dataset = load_custom_vars_testing_dataset()
140+
model = clf.fit(dataset[["Categorical", "CategoricalNum", "Numerical"]],
141+
dataset["Outcome"])
142+
modelpath = helpers.get_custom_vars_dataset_modelpath_pipeline()
143+
pickle.dump(model, open(modelpath, 'wb'))
144+
145+
113146
@pytest.fixture()
114147
def sample_adultincome_query():
115148
"""
@@ -188,6 +221,14 @@ def sample_custom_query_10():
188221
)
189222

190223

224+
@pytest.fixture()
225+
def sample_custom_vars_query_1():
226+
"""
227+
Returns a sample query instance for the custom dataset
228+
"""
229+
return pd.DataFrame({'Categorical': ['a'], 'CategoricalNum': [0], 'Numerical': [25]})
230+
231+
191232
@pytest.fixture()
192233
def sample_counterfactual_example_dummy():
193234
"""

tests/test_dice_interface/test_dice_KD.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,17 @@ def KD_binary_classification_exp_object():
1818
return exp
1919

2020

21+
@pytest.fixture()
22+
def KD_binary_vars_classification_exp_object(load_custom_vars_testing_dataset):
23+
backend = 'sklearn'
24+
dataset = load_custom_vars_testing_dataset
25+
d = dice_ml.Data(dataframe=dataset, continuous_features=['Numerical'], outcome_name='Outcome')
26+
ML_modelpath = helpers.get_custom_vars_dataset_modelpath_pipeline()
27+
m = dice_ml.Model(model_path=ML_modelpath, backend=backend)
28+
exp = dice_ml.Dice(d, m, method='kdtree')
29+
return exp
30+
31+
2132
@pytest.fixture()
2233
def KD_multi_classification_exp_object():
2334
backend = 'sklearn'
@@ -194,3 +205,44 @@ def test_KD_tree_counterfactual_explanations_output(self, desired_range, sample_
194205
def test_zero_cfs(self, desired_class, desired_range, sample_custom_query_4, total_CFs):
195206
self.exp_regr._generate_counterfactuals(query_instance=sample_custom_query_4, total_CFs=total_CFs,
196207
desired_range=desired_range)
208+
209+
210+
class TestDiceKDBinaryVarsClassificationMethods:
211+
@pytest.fixture(autouse=True)
212+
def _initiate_exp_object(self, KD_binary_vars_classification_exp_object):
213+
self.exp = KD_binary_vars_classification_exp_object # explainer object
214+
self.data_df_copy = self.exp.data_interface.data_df.copy()
215+
216+
# When a query's feature value is not within the permitted range and the feature is not allowed to vary
217+
@pytest.mark.parametrize(("desired_range", "desired_class", "total_CFs", "features_to_vary", "permitted_range"),
218+
[(None, 0, 4, ['Numerical'], {'CategoricalNum': ['1', '2']})])
219+
def test_invalid_query_instance(self, desired_range, desired_class, sample_custom_vars_query_1, total_CFs,
220+
features_to_vary, permitted_range):
221+
self.exp.dataset_with_predictions, self.exp.KD_tree, self.exp.predictions = \
222+
self.exp.build_KD_tree(self.data_df_copy, desired_range, desired_class, self.exp.predicted_outcome_name)
223+
224+
with pytest.raises(ValueError, match="is outside the permitted range and isn't allowed to vary"):
225+
self.exp._generate_counterfactuals(query_instance=sample_custom_vars_query_1, total_CFs=total_CFs,
226+
features_to_vary=features_to_vary, permitted_range=permitted_range)
227+
228+
# Verifying the output of the KD tree
229+
@pytest.mark.parametrize(("desired_class", "total_CFs"), [(0, 1)])
230+
@pytest.mark.parametrize('posthoc_sparsity_algorithm', ['linear', 'binary', None])
231+
def test_KD_tree_output(self, desired_class, sample_custom_vars_query_1, total_CFs, posthoc_sparsity_algorithm):
232+
self.exp._generate_counterfactuals(query_instance=sample_custom_vars_query_1, desired_class=desired_class,
233+
total_CFs=total_CFs,
234+
posthoc_sparsity_algorithm=posthoc_sparsity_algorithm)
235+
self.exp.final_cfs_df.Numerical = self.exp.final_cfs_df.Numerical.astype(int)
236+
expected_output = self.exp.data_interface.data_df
237+
238+
assert all(self.exp.final_cfs_df.Numerical == expected_output.Numerical[0])
239+
assert all(self.exp.final_cfs_df.Categorical == expected_output.Categorical[0])
240+
241+
# Verifying the output of the KD tree
242+
@pytest.mark.parametrize(("desired_class", "total_CFs"), [(0, 1)])
243+
def test_KD_tree_counterfactual_explanations_output(self, desired_class, sample_custom_vars_query_1, total_CFs):
244+
counterfactual_explanations = self.exp.generate_counterfactuals(
245+
query_instances=sample_custom_vars_query_1, desired_class=desired_class,
246+
total_CFs=total_CFs)
247+
248+
assert counterfactual_explanations is not None

0 commit comments

Comments
 (0)