Skip to content

Commit 2acac6b

Browse files
committed
Merge branch 'master' into gaugup/SerializeDeserializeExplainers
2 parents e56562a + edc5415 commit 2acac6b

16 files changed

Lines changed: 126 additions & 114 deletions

dice_ml/explainer_interfaces/dice_tensorflow1.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def do_cf_initializations(self, total_CFs, algorithm, features_to_vary):
225225
def predict_fn(self, input_instance):
226226
"""prediction function"""
227227
temp_preds = self.dice_sess.run(self.output_tensor, feed_dict={self.input_tensor: input_instance})
228-
return np.array([preds[(self.num_ouput_nodes-1):] for preds in temp_preds])
228+
return np.array([preds[(self.num_output_nodes-1):] for preds in temp_preds])
229229

230230
def predict_fn_for_sparsity(self, input_instance):
231231
"""prediction function for sparsity correction"""
@@ -239,22 +239,22 @@ def compute_yloss(self, method):
239239
if method == "l2_loss":
240240
temp_loss = tf.square(tf.subtract(
241241
self.model.get_output(self.cfs_frozen[i]), self.target_cf))
242-
temp_loss = temp_loss[:, (self.num_ouput_nodes-1):][0][0]
242+
temp_loss = temp_loss[:, (self.num_output_nodes-1):][0][0]
243243
elif method == "log_loss":
244244
temp_logits = tf.log(
245245
tf.divide(
246246
tf.abs(tf.subtract(self.model.get_output(self.cfs_frozen[i]), 0.000001)),
247247
tf.subtract(1.0, tf.abs(tf.subtract(self.model.get_output(
248248
self.cfs_frozen[i]), 0.000001)))))
249-
temp_logits = temp_logits[:, (self.num_ouput_nodes-1):]
249+
temp_logits = temp_logits[:, (self.num_output_nodes-1):]
250250
temp_loss = tf.nn.sigmoid_cross_entropy_with_logits(
251251
logits=temp_logits, labels=self.target_cf)[0][0]
252252
elif method == "hinge_loss":
253253
temp_logits = tf.log(
254254
tf.divide(
255255
tf.abs(tf.subtract(self.model.get_output(self.cfs_frozen[i]), 0.000001)),
256256
tf.subtract(1.0, tf.abs(tf.subtract(self.model.get_output(self.cfs_frozen[i]), 0.000001)))))
257-
temp_logits = temp_logits[:, (self.num_ouput_nodes-1):]
257+
temp_logits = temp_logits[:, (self.num_output_nodes-1):]
258258
temp_loss = tf.losses.hinge_loss(
259259
logits=temp_logits, labels=self.target_cf)
260260

requirements-linting.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ flake8-breakpoint
55
flake8-builtins==1.5.3
66
flake8-logging-format==0.6.0
77
flake8-nb==0.4.0
8+
flake8-pytest-style
89
isort
910
packaging

tests/conftest.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from dice_ml.utils import helpers
1010

1111

12-
@pytest.fixture
12+
@pytest.fixture()
1313
def binary_classification_exp_object(method="random"):
1414
backend = 'sklearn'
1515
dataset = helpers.load_custom_testing_dataset_binary()
@@ -20,7 +20,7 @@ def binary_classification_exp_object(method="random"):
2020
return exp
2121

2222

23-
@pytest.fixture
23+
@pytest.fixture()
2424
def binary_classification_exp_object_out_of_order(method="random"):
2525
backend = 'sklearn'
2626
dataset = helpers.load_outcome_not_last_column_dataset()
@@ -31,7 +31,7 @@ def binary_classification_exp_object_out_of_order(method="random"):
3131
return exp
3232

3333

34-
@pytest.fixture
34+
@pytest.fixture()
3535
def multi_classification_exp_object(method="random"):
3636
backend = 'sklearn'
3737
dataset = helpers.load_custom_testing_dataset_multiclass()
@@ -42,7 +42,7 @@ def multi_classification_exp_object(method="random"):
4242
return exp
4343

4444

45-
@pytest.fixture
45+
@pytest.fixture()
4646
def regression_exp_object(method="random"):
4747
backend = 'sklearn'
4848
dataset = helpers.load_custom_testing_dataset_regression()
@@ -81,7 +81,7 @@ def sklearn_regression_model_interface():
8181
return m
8282

8383

84-
@pytest.fixture
84+
@pytest.fixture()
8585
def public_data_object():
8686
"""
8787
Returns a public data object for the adult income dataset
@@ -90,7 +90,7 @@ def public_data_object():
9090
return dice_ml.Data(dataframe=dataset, continuous_features=['age', 'hours_per_week'], outcome_name='income')
9191

9292

93-
@pytest.fixture
93+
@pytest.fixture()
9494
def private_data_object():
9595
"""
9696
Returns a private data object containing meta information about the adult income dataset.
@@ -110,7 +110,7 @@ def private_data_object():
110110
return dice_ml.Data(features=features_dict, outcome_name='income')
111111

112112

113-
@pytest.fixture
113+
@pytest.fixture()
114114
def sample_adultincome_query():
115115
"""
116116
Returns a sample query instance for adult income dataset
@@ -119,63 +119,63 @@ def sample_adultincome_query():
119119
'race': 'White', 'gender': 'Female', 'hours_per_week': 45}
120120

121121

122-
@pytest.fixture
122+
@pytest.fixture()
123123
def sample_custom_query_1():
124124
"""
125125
Returns a sample query instance for the custom dataset
126126
"""
127127
return pd.DataFrame({'Categorical': ['a'], 'Numerical': [25]})
128128

129129

130-
@pytest.fixture
130+
@pytest.fixture()
131131
def sample_custom_query_2():
132132
"""
133133
Returns a sample query instance for the custom dataset
134134
"""
135135
return pd.DataFrame({'Categorical': ['b'], 'Numerical': [25]})
136136

137137

138-
@pytest.fixture
138+
@pytest.fixture()
139139
def sample_custom_query_3():
140140
"""
141141
Returns a sample query instance for the custom dataset
142142
"""
143143
return pd.DataFrame({'Categorical': ['d'], 'Numerical': [1000000]})
144144

145145

146-
@pytest.fixture
146+
@pytest.fixture()
147147
def sample_custom_query_4():
148148
"""
149149
Returns a sample query instance for the custom dataset
150150
"""
151151
return pd.DataFrame({'Categorical': ['c'], 'Numerical': [13]})
152152

153153

154-
@pytest.fixture
154+
@pytest.fixture()
155155
def sample_custom_query_5():
156156
"""
157157
Returns a sample query instance for the custom dataset
158158
"""
159159
return pd.DataFrame({'X': ['d'], 'Numerical': [25]})
160160

161161

162-
@pytest.fixture
162+
@pytest.fixture()
163163
def sample_custom_query_6():
164164
"""
165165
Returns a sample query instance for the custom dataset including Outcome
166166
"""
167167
return pd.DataFrame({'Categorical': ['c'], 'Numerical': [13], 'Outcome': 0})
168168

169169

170-
@pytest.fixture
170+
@pytest.fixture()
171171
def sample_custom_query_index():
172172
"""
173173
Returns a sample query instance for the custom dataset
174174
"""
175175
return pd.DataFrame({'Categorical': ['a'], 'Numerical': [88]})
176176

177177

178-
@pytest.fixture
178+
@pytest.fixture()
179179
def sample_custom_query_10():
180180
"""
181181
Returns a sample query instance for the custom dataset
@@ -188,7 +188,7 @@ def sample_custom_query_10():
188188
)
189189

190190

191-
@pytest.fixture
191+
@pytest.fixture()
192192
def sample_counterfactual_example_dummy():
193193
"""
194194
Returns a sample counterfactual example
@@ -208,7 +208,7 @@ def sample_counterfactual_example_dummy():
208208
)
209209

210210

211-
@pytest.fixture
211+
@pytest.fixture()
212212
def create_iris_data():
213213
iris = load_iris()
214214
x_train, x_test, y_train, y_test = train_test_split(
@@ -218,7 +218,7 @@ def create_iris_data():
218218
return x_train, x_test, y_train, y_test, feature_names, classes
219219

220220

221-
@pytest.fixture
221+
@pytest.fixture()
222222
def create_housing_data():
223223
housing = fetch_california_housing()
224224
x_train, x_test, y_train, y_test = train_test_split(

tests/test_counterfactual_explanations.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def test_sorted_local_importance_counterfactual_explanations(self):
109109
list(counterfactual_explanations.local_importance[index].keys())
110110

111111

112-
@pytest.fixture
112+
@pytest.fixture()
113113
def random_binary_classification_exp_object():
114114
backend = 'sklearn'
115115
dataset = helpers.load_custom_testing_dataset()
@@ -150,7 +150,7 @@ def verify_counterfactual_explanations(self, counterfactual_explanations,
150150
assert counterfactual_explanations.summary_importance is None
151151

152152
@pytest.mark.parametrize("version", ['1.0', '2.0'])
153-
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 2)])
153+
@pytest.mark.parametrize(("desired_class", "total_CFs"), [(0, 2)])
154154
def test_counterfactual_explanations_output(self, desired_class,
155155
sample_custom_query_1, total_CFs,
156156
version):
@@ -172,7 +172,7 @@ def test_counterfactual_explanations_output(self, desired_class,
172172
assert recovered_counterfactual_explanations == counterfactual_explanations
173173

174174
@pytest.mark.parametrize("version", ['1.0', '2.0'])
175-
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 10)])
175+
@pytest.mark.parametrize(("desired_class", "total_CFs"), [(0, 10)])
176176
def test_local_importance_output(self, desired_class, sample_custom_query_1,
177177
total_CFs, version):
178178
counterfactual_explanations = self.exp.local_feature_importance(
@@ -195,7 +195,7 @@ def test_local_importance_output(self, desired_class, sample_custom_query_1,
195195
assert recovered_counterfactual_explanations == counterfactual_explanations
196196

197197
@pytest.mark.parametrize("version", ['1.0', '2.0'])
198-
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 10)])
198+
@pytest.mark.parametrize(("desired_class", "total_CFs"), [(0, 10)])
199199
def test_summary_importance_output(self, desired_class, sample_custom_query_10,
200200
total_CFs, version):
201201
counterfactual_explanations = self.exp.global_feature_importance(
@@ -242,7 +242,7 @@ def test_empty_counterfactual_explanations_object(self, version):
242242
assert counterfactual_explanations == recovered_counterfactual_explanations
243243

244244
@pytest.mark.parametrize("version", ['1.0', '2.0'])
245-
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 2)])
245+
@pytest.mark.parametrize(("desired_class", "total_CFs"), [(0, 2)])
246246
def test_no_counterfactuals_found(self, desired_class,
247247
sample_custom_query_1, total_CFs,
248248
version):
@@ -261,7 +261,7 @@ def test_no_counterfactuals_found(self, desired_class,
261261
assert counterfactual_explanations == recovered_counterfactual_explanations
262262

263263
@pytest.mark.parametrize("version", ['1.0', '2.0'])
264-
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 10)])
264+
@pytest.mark.parametrize(("desired_class", "total_CFs"), [(0, 10)])
265265
def test_no_counterfactuals_found_local_importance(self, desired_class,
266266
sample_custom_query_1, total_CFs,
267267
version):
@@ -282,7 +282,7 @@ def test_no_counterfactuals_found_local_importance(self, desired_class,
282282
assert counterfactual_explanations == recovered_counterfactual_explanations
283283

284284
@pytest.mark.parametrize("version", ['1.0', '2.0'])
285-
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 10)])
285+
@pytest.mark.parametrize(("desired_class", "total_CFs"), [(0, 10)])
286286
def test_no_counterfactuals_found_summary_importance(self, desired_class,
287287
sample_custom_query_10, total_CFs,
288288
version):

tests/test_data_interface/test_private_data_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import dice_ml
66

77

8-
@pytest.fixture
8+
@pytest.fixture()
99
def data_object():
1010
features_dict = OrderedDict(
1111
[('age', [17, 90]),

tests/test_data_interface/test_public_data_interface.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from dice_ml.utils.exception import UserConfigValidationException
1111

1212

13-
@pytest.fixture
13+
@pytest.fixture()
1414
def data_object():
1515
dataset = helpers.load_adult_income_dataset()
1616
return dice_ml.Data(dataframe=dataset, continuous_features=['age', 'hours_per_week'],
@@ -67,16 +67,19 @@ def test_invalid_continuous_features(self, data_type):
6767
iris = load_iris(as_frame=True)
6868
dataset = iris.frame
6969

70+
import re
7071
if data_type == DataTypeCombinations.Incorrect:
71-
with pytest.raises(ValueError) as ve:
72+
with pytest.raises(
73+
ValueError,
74+
match=re.escape("should provide the name(s) of continuous features in the data as a list")):
7275
dice_ml.Data(dataframe=dataset, continuous_features=np.array(iris.feature_names),
7376
outcome_name='target')
74-
assert "should provide the name(s) of continuous features in the data as a list" in str(ve)
7577
elif data_type == DataTypeCombinations.AsNone:
76-
with pytest.raises(ValueError) as ve:
78+
with pytest.raises(
79+
ValueError,
80+
match=re.escape("should provide the name(s) of continuous features in the data as a list")):
7781
dice_ml.Data(dataframe=dataset, continuous_features=None,
7882
outcome_name='target')
79-
assert "should provide the name(s) of continuous features in the data as a list" in str(ve)
8083
else:
8184
with pytest.raises(
8285
ValueError,

tests/test_dice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def test_minimum_query_instances(self):
5454
pytest.importorskip('sklearn')
5555
backend = 'sklearn'
5656
exp = self._get_exp(backend)
57+
query_instances = helpers.load_adult_income_dataset().drop("income", axis=1)[0:1]
5758
with pytest.raises(UserConfigValidationException):
58-
query_instances = helpers.load_adult_income_dataset().drop("income", axis=1)[0:1]
5959
exp.global_feature_importance(query_instances)
6060

6161
def test_unsupported_sampling_strategy(self):

0 commit comments

Comments (0)