
Commit 1651751

Refactoring deep learning models to work with any explainer (#261)
* refactoring pytorch model to work with any other method
* updated pytorch explainer and model to work with all methods
* updated tf methods to work with transformers
* added torch as a test dependency
* fixed random seed to have a fixed pytorch model

Signed-off-by: Amit Sharma <amit_sharma@live.com>
1 parent 2e9d093 · commit 1651751

32 files changed · 490 additions and 484 deletions
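
For orientation, the sketch below shows how a PyTorch model might be plugged into DiCE once this refactoring is in place. The dataset, column names, file path and network architecture are illustrative assumptions, not part of this commit; only the backend='PYT' model interface and the 'gradient' method name come from the library itself.

import dice_ml
import pandas as pd
import torch

# Illustrative training data (hypothetical path and column names).
train_df = pd.read_csv("adult_income.csv")
d = dice_ml.Data(dataframe=train_df,
                 continuous_features=['age', 'hours_per_week'],
                 outcome_name='income')

# Any trained torch.nn.Module should do; this architecture is a placeholder.
ann_model = torch.nn.Sequential(
    torch.nn.Linear(29, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 1),
    torch.nn.Sigmoid())

# backend='PYT' selects the PyTorch model interface; after this refactoring the
# same model object is meant to work with gradient and model-agnostic explainers alike.
m = dice_ml.Model(model=ann_model, backend='PYT')

# 'gradient' is the method name that the new SamplingStrategy.Gradient constant maps to.
exp = dice_ml.Dice(d, m, method='gradient')
cf = exp.generate_counterfactuals(train_df.drop(columns='income').iloc[0:1],
                                  total_CFs=4, desired_class="opposite")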

dice_ml/constants.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ class SamplingStrategy:
     Random = 'random'
     Genetic = 'genetic'
     KdTree = 'kdtree'
+    Gradient = 'gradient'


 class ModelTypes:
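
The new constant mirrors the existing strategy names, so callers can request the gradient-based explainer either by the literal string or by the constant; a one-line sketch, reusing d and m from the example above:

from dice_ml.constants import SamplingStrategy

# SamplingStrategy.Gradient == 'gradient'
exp = dice_ml.Dice(d, m, method=SamplingStrategy.Gradient)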

dice_ml/data_interfaces/base_data_interface.py

Lines changed: 43 additions & 0 deletions
@@ -2,6 +2,9 @@

 from abc import ABC, abstractmethod

+from dice_ml.utils.exception import (SystemException,
+                                     UserConfigValidationException)
+

 class _BaseData(ABC):

@@ -27,6 +30,46 @@ def set_continuous_feature_indexes(self, query_instance):
         self.continuous_feature_indexes = [query_instance.columns.get_loc(name) for name in
                                            self.continuous_feature_names]

+    def check_features_to_vary(self, features_to_vary):
+        if features_to_vary is not None and features_to_vary != 'all':
+            not_training_features = set(features_to_vary) - set(self.feature_names)
+            if len(not_training_features) > 0:
+                raise UserConfigValidationException("Got features {0} which are not present in training data".format(
+                    not_training_features))
+
+    def check_permitted_range(self, permitted_range):
+        if permitted_range is not None:
+            permitted_range_features = list(permitted_range)
+            not_training_features = set(permitted_range_features) - set(self.feature_names)
+            if len(not_training_features) > 0:
+                raise UserConfigValidationException("Got features {0} which are not present in training data".format(
+                    not_training_features))
+
+            for feature in permitted_range_features:
+                if feature in self.categorical_feature_names:
+                    train_categories = self.permitted_range[feature]
+                    for test_category in permitted_range[feature]:
+                        if test_category not in train_categories:
+                            raise UserConfigValidationException(
+                                'The category {0} does not occur in the training data for feature {1}.'
+                                ' Allowed categories are {2}'.format(test_category, feature, train_categories))
+
+    def _validate_and_set_permitted_range(self, params, features_dict=None):
+        """Validate and set the dictionary of permitted ranges for continuous features."""
+        input_permitted_range = None
+        if 'permitted_range' in params:
+            input_permitted_range = params['permitted_range']
+
+            if not hasattr(self, 'feature_names'):
+                raise SystemException('Feature names not correctly set in public data interface')
+
+            for input_permitted_range_feature_name in input_permitted_range:
+                if input_permitted_range_feature_name not in self.feature_names:
+                    raise UserConfigValidationException(
+                        "permitted_range contains some feature names which are not part of columns in dataframe"
+                    )
+        self.permitted_range, _ = self.get_features_range(input_permitted_range, features_dict)
+
     @abstractmethod
     def __init__(self, params):
         """The init method needs to be implemented by the inherting classes."""

dice_ml/data_interfaces/private_data_interface.py

Lines changed: 33 additions & 27 deletions
@@ -3,6 +3,7 @@
 import collections
 import logging
 import sys
+from collections import defaultdict

 import numpy as np
 import pandas as pd
@@ -46,21 +47,18 @@ def __init__(self, params):
         self._validate_and_set_type_and_precision(params=params)

         self.continuous_feature_names = []
-        self.permitted_range = {}
         self.categorical_feature_names = []
         self.categorical_levels = {}

         for feature in features_dict:
             if type(features_dict[feature][0]) is int:  # continuous feature
                 self.continuous_feature_names.append(feature)
-                self.permitted_range[feature] = features_dict[feature]
             else:
                 self.categorical_feature_names.append(feature)
                 self.categorical_levels[feature] = features_dict[feature]

         self._validate_and_set_mad(params=params)
-
-        # self.continuous_feature_names + self.categorical_feature_names
+        self._validate_and_set_permitted_range(params=params, features_dict=features_dict)
         self.feature_names = list(features_dict.keys())

         self.continuous_feature_indexes = [list(features_dict.keys()).index(
@@ -73,20 +71,6 @@ def __init__(self, params):
             if feature_name not in self.type_and_precision:
                 self.type_and_precision[feature_name] = 'int'

-        # # Initializing a label encoder to obtain label-encoded values for categorical variables
-        # self.labelencoder = {}
-        #
-        # self.label_encoded_data = {}
-        #
-        # for column in self.categorical_feature_names:
-        # self.labelencoder[column] = LabelEncoder()
-        # self.label_encoded_data[column] = \
-        # self.labelencoder[column].fit_transform(self.categorical_levels[column])
-
-        # self.max_range = -np.inf
-        # for feature in self.continuous_feature_names:
-        # self.max_range = max(self.max_range, self.permitted_range[feature][1])
-
         self._validate_and_set_data_name(params=params)

     def _validate_and_set_type_and_precision(self, params):
@@ -176,7 +160,22 @@ def get_valid_mads(self, normalized=False, display_warnings=False, return_mads=True):
         if return_mads:
             return mads

-    def create_ohe_params(self):
+    def get_features_range(self, permitted_range_input=None, features_dict=None):
+        ranges = {}
+        # Getting default ranges based on the dataset
+        for feature in features_dict:
+            if type(features_dict[feature][0]) is int:  # continuous feature
+                ranges[feature] = features_dict[feature]
+            else:
+                ranges[feature] = features_dict[feature]
+        feature_ranges_orig = ranges.copy()
+        # Overwriting the ranges for a feature if input provided
+        if permitted_range_input is not None:
+            for feature_name, feature_range in permitted_range_input.items():
+                ranges[feature_name] = feature_range
+        return ranges, feature_ranges_orig
+
+    def create_ohe_params(self, one_hot_encoded_data=None):
         if len(self.categorical_feature_names) > 0:
             # simulating sklearn's one-hot-encoding
             # continuous features on the left
@@ -265,16 +264,22 @@ def from_dummies(self, data, prefix_sep='_'):
             out.drop(cols, axis=1, inplace=True)
         return out

-    def get_decimal_precisions(self):
+    def get_decimal_precisions(self, output_type="list"):
         """"Gets the precision of continuous features in the data."""
+        precisions_dict = defaultdict(int)
         precisions = [0]*len(self.continuous_feature_names)
         for ix, feature_name in enumerate(self.continuous_feature_names):
             type_prec = self.type_and_precision[feature_name]
             if type_prec == 'int':
-                precisions[ix] = 0
+                prec = 0
             else:
-                precisions[ix] = self.type_and_precision[feature_name][1]
-        return precisions
+                prec = self.type_and_precision[feature_name][1]
+            precisions[ix] = prec
+            precisions_dict[feature_name] = prec
+        if output_type == "list":
+            return precisions
+        elif output_type == "dict":
+            return precisions_dict

     def get_decoded_data(self, data, encoding='one-hot'):
         """Gets the original data from encoded data."""
@@ -284,11 +289,11 @@ def get_decoded_data(self, data, encoding='one-hot'):
         index = [i for i in range(0, len(data))]
         if encoding == 'one-hot':
             if isinstance(data, pd.DataFrame):
-                return self.from_dummies(data)
+                return data
             elif isinstance(data, np.ndarray):
                 data = pd.DataFrame(data=data, index=index,
                                     columns=self.ohe_encoded_feature_names)
-                return self.from_dummies(data)
+                return data
             else:
                 raise ValueError("data should be a pandas dataframe or a numpy array")

@@ -347,7 +352,8 @@ def get_ohe_min_max_normalized_data(self, query_instance):
         """Transforms query_instance into one-hot-encoded and min-max normalized data. query_instance should be a dict,
         a dataframe, a list, or a list of dicts"""
         query_instance = self.prepare_query_instance(query_instance)
-        temp = self.ohe_base_df.append(query_instance, ignore_index=True, sort=False)
+        ohe_base_df = self.prepare_df_for_ohe_encoding()
+        temp = ohe_base_df.append(query_instance, ignore_index=True, sort=False)
         temp = self.one_hot_encode_data(temp)
         temp = temp.tail(query_instance.shape[0]).reset_index(drop=True)
         # returns a pandas dataframe
@@ -356,7 +362,7 @@
     def get_inverse_ohe_min_max_normalized_data(self, transformed_data):
         """Transforms one-hot-encoded and min-max normalized data into raw user-fed data format. transformed_data
         should be a dataframe or an array"""
-        raw_data = self.get_decoded_data(transformed_data, encoding='one-hot')
+        raw_data = self.from_dummies(transformed_data)
         raw_data = self.de_normalize_data(raw_data)
         precisions = self.get_decimal_precisions()
         for ix, feature in enumerate(self.continuous_feature_names):
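
With these changes the private interface derives permitted_range through get_features_range instead of populating it inline, and get_decimal_precisions can report precisions per feature name. A rough sketch, assuming a private dice_ml.Data object built only from feature metadata (the feature specs are illustrative):

import dice_ml

d_priv = dice_ml.Data(features={'age': [17, 90],
                                'education': ['Bachelors', 'HS-grad', 'Masters']},
                      outcome_name='income')

# New output_type switch added in this diff; "list" preserves the old behaviour.
print(d_priv.get_decimal_precisions(output_type="dict"))  # precision per continuous feature, e.g. age -> 0

# Default ranges come from the feature metadata; user input overrides them,
# and the second return value keeps the originals.
ranges, orig = d_priv.get_features_range(permitted_range_input={'age': [20, 60]},
                                         features_dict={'age': [17, 90],
                                                        'education': ['Bachelors', 'HS-grad', 'Masters']})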

dice_ml/data_interfaces/public_data_interface.py

Lines changed: 9 additions & 117 deletions
@@ -54,28 +54,7 @@ def __init__(self, params):
                                                 self.categorical_feature_names,
                                                 self.continuous_feature_names)

-        # should move the below snippet to gradient based dice interfaces
-        # self.one_hot_encoded_data = self.one_hot_encode_data(self.data_df)
-        # self.ohe_encoded_feature_names = [x for x in self.one_hot_encoded_data.columns.tolist(
-        # ) if x not in np.array([self.outcome_name])]
-
-        # should move the below snippet to model agnostic dice interfaces
-        # # Initializing a label encoder to obtain label-encoded values for categorical variables
-        # self.labelencoder = {}
-        #
-        # self.label_encoded_data = self.data_df.copy()
-        #
-        # for column in self.categorical_feature_names:
-        # self.labelencoder[column] = LabelEncoder()
-        # self.label_encoded_data[column] = self.labelencoder[column].fit_transform(self.data_df[column])
-
         self._validate_and_set_permitted_range(params=params)
-
-        # should move the below snippet to model agnostic dice interfaces
-        # self.max_range = -np.inf
-        # for feature in self.continuous_feature_names:
-        # self.max_range = max(self.max_range, self.permitted_range[feature][1])
-
         self._validate_and_set_data_name(params=params)

     def _validate_and_set_dataframe(self, params):
@@ -122,22 +101,6 @@ def _validate_and_set_continuous_features_precision(self, params):
         else:
             self.continuous_features_precision = None

-    def _validate_and_set_permitted_range(self, params):
-        """Validate and set the dictionary of permitted ranges for continuous features."""
-        input_permitted_range = None
-        if 'permitted_range' in params:
-            input_permitted_range = params['permitted_range']
-
-            if not hasattr(self, 'feature_names'):
-                raise SystemException('Feature names not correctly set in public data interface')
-
-            for input_permitted_range_feature_name in input_permitted_range:
-                if input_permitted_range_feature_name not in self.feature_names:
-                    raise UserConfigValidationException(
-                        "permitted_range contains some feature names which are not part of columns in dataframe"
-                    )
-        self.permitted_range, _ = self.get_features_range(input_permitted_range)
-
     def _set_feature_dtypes(self, data_df, categorical_feature_names,
                             continuous_feature_names):
         """Set the correct type of each feature column."""
@@ -157,38 +120,7 @@ def _set_feature_dtypes(self, data_df, categorical_feature_names,
                     np.int32)
         return data_df

-    def check_features_to_vary(self, features_to_vary):
-        if features_to_vary is not None and features_to_vary != 'all':
-            not_training_features = set(features_to_vary) - set(self.feature_names)
-            if len(not_training_features) > 0:
-                raise UserConfigValidationException("Got features {0} which are not present in training data".format(
-                    not_training_features))
-
-    def check_permitted_range(self, permitted_range):
-        if permitted_range is not None:
-            permitted_range_features = list(permitted_range)
-            not_training_features = set(permitted_range_features) - set(self.feature_names)
-            if len(not_training_features) > 0:
-                raise UserConfigValidationException("Got features {0} which are not present in training data".format(
-                    not_training_features))
-
-            for feature in permitted_range_features:
-                if feature in self.categorical_feature_names:
-                    train_categories = self.permitted_range[feature]
-                    for test_category in permitted_range[feature]:
-                        if test_category not in train_categories:
-                            raise UserConfigValidationException(
-                                'The category {0} does not occur in the training data for feature {1}.'
-                                ' Allowed categories are {2}'.format(test_category, feature, train_categories))
-
-    def check_mad_validity(self, feature_weights):
-        """checks feature MAD validity and throw warnings.
-        TODO: add comments as to where this is used if this function is necessary, else remove.
-        """
-        if feature_weights == "inverse_mad":
-            self.get_valid_mads(display_warnings=True, return_mads=False)
-
-    def get_features_range(self, permitted_range_input=None):
+    def get_features_range(self, permitted_range_input=None, features_dict=None):
         ranges = {}
         # Getting default ranges based on the dataset
         for feature_name in self.continuous_feature_names:
@@ -307,25 +239,6 @@ def get_minx_maxx(self, normalized=True):
                 minx[0][idx] = self.permitted_range[feature_name][0]
                 maxx[0][idx] = self.permitted_range[feature_name][1]
         return minx, maxx
-        # if encoding=='one-hot':
-        # minx = np.array([[0.0] * len(self.ohe_encoded_feature_names)])
-        # maxx = np.array([[1.0] * len(self.ohe_encoded_feature_names)])
-
-        # for idx, feature_name in enumerate(self.continuous_feature_names):
-        # max_value = self.train_df[feature_name].max()
-        # min_value = self.train_df[feature_name].min()
-
-        # if normalized:
-        # minx[0][idx] = (self.permitted_range[feature_name]
-        # [0] - min_value) / (max_value - min_value)
-        # maxx[0][idx] = (self.permitted_range[feature_name]
-        # [1] - min_value) / (max_value - min_value)
-        # else:
-        # minx[0][idx] = self.permitted_range[feature_name][0]
-        # maxx[0][idx] = self.permitted_range[feature_name][1]
-        # else:
-        # minx = np.array([[0.0] * len(self.feature_names)])
-        # maxx = np.array([[1.0] * len(self.feature_names)])

     def get_mads(self, normalized=False):
         """Computes Median Absolute Deviation of features."""
@@ -370,24 +283,17 @@ def get_quantiles_from_training_data(self, quantile=0.05, normalized=False):
                     list(set(normalized_train_df[feature].tolist())))), quantile)
         return quantiles

-    def create_ohe_params(self):
+    def create_ohe_params(self, one_hot_encoded_data):
         if len(self.categorical_feature_names) > 0:
-            one_hot_encoded_data = self.one_hot_encode_data(self.data_df)
             self.ohe_encoded_feature_names = [x for x in one_hot_encoded_data.columns.tolist(
             ) if x not in np.array([self.outcome_name])]
         else:
             # one-hot-encoded data is same as original data if there is no categorical features.
             self.ohe_encoded_feature_names = [feat for feat in self.feature_names]

-        # base dataframe for doing one-hot-encoding
-        # ohe_encoded_feature_names and ohe_base_df are created (and stored as data class's parameters)
-        # when get_data_params_for_gradient_dice() is called from gradient-based DiCE explainers
-        self.ohe_base_df = self.prepare_df_for_ohe_encoding()
-
     def get_data_params_for_gradient_dice(self):
         """Gets all data related params for DiCE."""

-        self.create_ohe_params()
         minx, maxx = self.get_minx_maxx(normalized=True)

         # get the column indexes of categorical and continuous features after one-hot-encoding
@@ -497,11 +403,11 @@ def get_decoded_data(self, data, encoding='one-hot'):
         index = [i for i in range(0, len(data))]
         if encoding == 'one-hot':
             if isinstance(data, pd.DataFrame):
-                return self.from_dummies(data)
+                return data
             elif isinstance(data, np.ndarray):
                 data = pd.DataFrame(data=data, index=index,
                                     columns=self.ohe_encoded_feature_names)
-                return self.from_dummies(data)
+                return data
             else:
                 raise ValueError("data should be a pandas dataframe or a numpy array")

@@ -560,35 +466,21 @@ def prepare_query_instance(self, query_instance):
                                       self.continuous_feature_names)
         return test

-        # TODO: create a new method, get_LE_min_max_normalized_data() to get label-encoded and normalized data. Keep this
-        # method only for converting query_instance to pd.DataFrame
-        # if encoding == 'label':
-        # for column in self.categorical_feature_names:
-        # test[column] = self.labelencoder[column].transform(test[column])
-        # return self.normalize_data(test, encoding)
-        #
-        # elif encoding == 'one-hot':
-        # temp = self.prepare_df_for_encoding()
-        # temp = temp.append(test, ignore_index=True, sort=False)
-        # temp = self.one_hot_encode_data(temp)
-        # temp = self.normalize_data(temp)
-        #
-        # return temp.tail(test.shape[0]).reset_index(drop=True)
-
     def get_ohe_min_max_normalized_data(self, query_instance):
         """Transforms query_instance into one-hot-encoded and min-max normalized data. query_instance should be a dict,
         a dataframe, a list, or a list of dicts"""
         query_instance = self.prepare_query_instance(query_instance)
-        temp = self.ohe_base_df.append(query_instance, ignore_index=True, sort=False)
+        ohe_base_df = self.prepare_df_for_ohe_encoding()
+        temp = ohe_base_df.append(query_instance, ignore_index=True, sort=False)
         temp = self.one_hot_encode_data(temp)
         temp = temp.tail(query_instance.shape[0]).reset_index(drop=True)
-        # returns a pandas dataframe
-        return self.normalize_data(temp)
+        # returns a pandas dataframe with all numeric values
+        return self.normalize_data(temp).apply(pd.to_numeric)

     def get_inverse_ohe_min_max_normalized_data(self, transformed_data):
         """Transforms one-hot-encoded and min-max normalized data into raw user-fed data format. transformed_data
         should be a dataframe or an array"""
-        raw_data = self.get_decoded_data(transformed_data, encoding='one-hot')
+        raw_data = self.from_dummies(transformed_data)
         raw_data = self.de_normalize_data(raw_data)
         precisions = self.get_decimal_precisions()
         for ix, feature in enumerate(self.continuous_feature_names):
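
Because ohe_base_df is no longer cached on the data object, the one-hot helpers rebuild the base frame on each call, and the normalized output is coerced to numeric before it reaches the gradient explainers. A rough round-trip sketch, again using d and train_df from the first example:

query = train_df.drop(columns='income').iloc[0:1]

# One-hot-encoded, min-max normalized, all-numeric frame (note the new .apply(pd.to_numeric)).
encoded = d.get_ohe_min_max_normalized_data(query)

# The inverse transform now goes through from_dummies() directly instead of get_decoded_data().
decoded = d.get_inverse_ohe_min_max_normalized_data(encoded)
print(decoded)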
