Skip to content

Commit f95c777

Browse files
authored
Remove unused method and refactor tests (#126)
Signed-off-by: gaugup <gaugup@microsoft.com>
1 parent 1f287ab commit f95c777

2 files changed

Lines changed: 97 additions & 91 deletions

File tree

dice_ml/counterfactual_explanations.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,6 @@ def __eq__(self, other_cf):
6767
self.metadata == other_cf.metadata
6868
return False
6969

70-
@property
71-
def __dict__(self):
72-
return {'cf_examples_list': self.cf_examples_list,
73-
'local_importance': self.local_importance,
74-
'summary_importance': self.summary_importance,
75-
'metadata': self.metadata}
76-
7770
@property
7871
def cf_examples_list(self):
7972
return self._cf_examples_list
@@ -220,8 +213,8 @@ def to_json(self):
220213
entire_dict, version=serialization_version)
221214
return json.dumps(entire_dict)
222215
else:
223-
raise Exception("Unsupported serialization version {}".format(
224-
serialization_version))
216+
raise UserConfigValidationException(
217+
"Unsupported serialization version {}".format(serialization_version))
225218

226219
@staticmethod
227220
def from_json(json_str):

tests/test_counterfactual_explanations.py

Lines changed: 95 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -10,32 +10,6 @@
1010

1111
class TestCounterfactualExplanations:
1212

13-
@pytest.mark.parametrize("version", ['1.0', '2.0'])
14-
def test_serialization_deserialization_counterfactual_explanations_class(self, version):
15-
16-
counterfactual_explanations = CounterfactualExplanations(
17-
cf_examples_list=[],
18-
local_importance=None,
19-
summary_importance=None,
20-
version=version)
21-
assert counterfactual_explanations.cf_examples_list is not None
22-
assert len(counterfactual_explanations.cf_examples_list) == 0
23-
assert counterfactual_explanations.summary_importance is None
24-
assert counterfactual_explanations.local_importance is None
25-
assert counterfactual_explanations.metadata is not None
26-
assert counterfactual_explanations.metadata['version'] is not None
27-
assert counterfactual_explanations.metadata['version'] == version
28-
29-
counterfactual_explanations_as_json = counterfactual_explanations.to_json()
30-
assert counterfactual_explanations_as_json is not None
31-
32-
recovered_counterfactual_explanations = CounterfactualExplanations.from_json(
33-
counterfactual_explanations_as_json)
34-
35-
assert recovered_counterfactual_explanations is not None
36-
assert recovered_counterfactual_explanations.metadata['version'] == version
37-
assert counterfactual_explanations == recovered_counterfactual_explanations
38-
3913
def test_sorted_summary_importance_counterfactual_explanations(self):
4014

4115
unsorted_summary_importance = {
@@ -132,21 +106,6 @@ def test_sorted_local_importance_counterfactual_explanations(self):
132106
assert list(unsorted_local_importance[index].keys()) != list(counterfactual_explanations.local_importance[index].keys())
133107
assert list(sorted_local_importance[index].keys()) == list(counterfactual_explanations.local_importance[index].keys())
134108

135-
@pytest.mark.parametrize('version', ['3.0', ''])
136-
def test_unsupported_versions_json_input(self, version):
137-
json_str = json.dumps({'metadata': {'version': version}})
138-
with pytest.raises(UserConfigValidationException) as ucve:
139-
CounterfactualExplanations.from_json(json_str)
140-
141-
assert "Incompatible version {} found in json input".format(version) in str(ucve)
142-
143-
json_str = json.dumps({'metadata': {'versio': version}})
144-
with pytest.raises(UserConfigValidationException) as ucve:
145-
CounterfactualExplanations.from_json(json_str)
146-
147-
assert "No version field in the json input" in str(ucve)
148-
149-
150109

151110
@pytest.fixture
152111
def random_binary_classification_exp_object():
@@ -165,90 +124,144 @@ def _initiate_exp_object(self, random_binary_classification_exp_object):
165124
self.exp = random_binary_classification_exp_object # explainer object
166125
self.data_df_copy = self.exp.data_interface.data_df.copy()
167126

168-
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 2)])
169127
@pytest.mark.parametrize("version", ['1.0', '2.0'])
128+
def verify_counterfactual_explanations(self, counterfactual_explanations,
129+
total_CFs, num_query_points, version,
130+
local_importance_available=False,
131+
summary_importance_available=False):
132+
assert counterfactual_explanations is not None
133+
assert counterfactual_explanations.cf_examples_list is not None
134+
assert len(counterfactual_explanations.cf_examples_list) == num_query_points
135+
if total_CFs is not None:
136+
assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs
137+
assert counterfactual_explanations.metadata is not None
138+
assert counterfactual_explanations.metadata['version'] is not None
139+
counterfactual_explanations.metadata['version'] = version
140+
if local_importance_available:
141+
assert counterfactual_explanations.local_importance is not None
142+
assert len(counterfactual_explanations.local_importance) == num_query_points
143+
else:
144+
assert counterfactual_explanations.local_importance is None
145+
if summary_importance_available:
146+
assert counterfactual_explanations.summary_importance is not None
147+
else:
148+
assert counterfactual_explanations.summary_importance is None
149+
150+
@pytest.mark.parametrize("version", ['1.0', '2.0'])
151+
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 2)])
170152
def test_random_counterfactual_explanations_output(self, desired_class,
171153
sample_custom_query_1, total_CFs,
172154
version):
173155
counterfactual_explanations = self.exp.generate_counterfactuals(
174156
query_instances=sample_custom_query_1, desired_class=desired_class,
175157
total_CFs=total_CFs)
176158

177-
assert counterfactual_explanations is not None
178-
assert len(counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0]
179-
assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs
180-
assert counterfactual_explanations.local_importance is None
181-
assert counterfactual_explanations.summary_importance is None
159+
self.verify_counterfactual_explanations(counterfactual_explanations, total_CFs,
160+
sample_custom_query_1.shape[0], version)
182161

183-
counterfactual_explanations.metadata['version'] = version
184162
json_output = counterfactual_explanations.to_json()
185163
assert json_output is not None
186164
assert json.loads(json_output).get('metadata').get('version') == version
187165

188166
recovered_counterfactual_explanations = CounterfactualExplanations.from_json(json_output)
189-
assert recovered_counterfactual_explanations is not None
190-
assert recovered_counterfactual_explanations == counterfactual_explanations
191-
assert recovered_counterfactual_explanations.metadata['version'] == version
167+
self.verify_counterfactual_explanations(counterfactual_explanations, total_CFs,
168+
sample_custom_query_1.shape[0], version)
192169

193-
assert len(recovered_counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0]
194-
assert recovered_counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs
195-
assert recovered_counterfactual_explanations.local_importance is None
196-
assert recovered_counterfactual_explanations.summary_importance is None
170+
assert recovered_counterfactual_explanations == counterfactual_explanations
197171

198-
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 10)])
199172
@pytest.mark.parametrize("version", ['1.0', '2.0'])
173+
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 10)])
200174
def test_random_local_importance_output(self, desired_class, sample_custom_query_1,
201175
total_CFs, version):
202176
counterfactual_explanations = self.exp.local_feature_importance(
203177
query_instances=sample_custom_query_1, desired_class=desired_class,
204178
total_CFs=total_CFs)
205179

206-
assert counterfactual_explanations is not None
207-
assert len(counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0]
208-
assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs
209-
assert counterfactual_explanations.local_importance is not None
210-
assert counterfactual_explanations.summary_importance is None
180+
self.verify_counterfactual_explanations(counterfactual_explanations, total_CFs,
181+
sample_custom_query_1.shape[0], version,
182+
local_importance_available=True)
211183

212-
counterfactual_explanations.metadata['version'] = version
213184
json_output = counterfactual_explanations.to_json()
214185
assert json_output is not None
215186
assert json.loads(json_output).get('metadata').get('version') == version
216187

217188
recovered_counterfactual_explanations = CounterfactualExplanations.from_json(json_output)
218-
assert recovered_counterfactual_explanations is not None
219-
assert recovered_counterfactual_explanations == counterfactual_explanations
220-
assert recovered_counterfactual_explanations.metadata['version'] == version
189+
self.verify_counterfactual_explanations(counterfactual_explanations, total_CFs,
190+
sample_custom_query_1.shape[0], version,
191+
local_importance_available=True)
221192

222-
assert len(recovered_counterfactual_explanations.cf_examples_list) == sample_custom_query_1.shape[0]
223-
assert recovered_counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs
224-
assert recovered_counterfactual_explanations.local_importance is not None
225-
assert counterfactual_explanations.summary_importance is None
193+
assert recovered_counterfactual_explanations == counterfactual_explanations
226194

227-
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 10)])
228195
@pytest.mark.parametrize("version", ['1.0', '2.0'])
196+
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 10)])
229197
def test_random_summary_importance_output(self, desired_class, sample_custom_query_10,
230198
total_CFs, version):
231199
counterfactual_explanations = self.exp.global_feature_importance(
232200
query_instances=sample_custom_query_10, desired_class=desired_class,
233201
total_CFs=total_CFs)
234202

235-
assert counterfactual_explanations is not None
236-
assert len(counterfactual_explanations.cf_examples_list) == sample_custom_query_10.shape[0]
237-
assert counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs
238-
assert counterfactual_explanations.local_importance is not None
239-
assert counterfactual_explanations.summary_importance is not None
203+
self.verify_counterfactual_explanations(counterfactual_explanations, total_CFs,
204+
sample_custom_query_10.shape[0], version,
205+
local_importance_available=True,
206+
summary_importance_available=True)
240207

241-
counterfactual_explanations.metadata['version'] = version
242208
json_output = counterfactual_explanations.to_json()
243209
assert json_output is not None
244210
assert json.loads(json_output).get('metadata').get('version') == version
245211

246212
recovered_counterfactual_explanations = CounterfactualExplanations.from_json(json_output)
247-
assert recovered_counterfactual_explanations is not None
213+
self.verify_counterfactual_explanations(counterfactual_explanations, total_CFs,
214+
sample_custom_query_10.shape[0], version,
215+
local_importance_available=True,
216+
summary_importance_available=True)
217+
248218
assert recovered_counterfactual_explanations == counterfactual_explanations
249-
assert recovered_counterfactual_explanations.metadata['version'] == version
250219

251-
assert len(recovered_counterfactual_explanations.cf_examples_list) == sample_custom_query_10.shape[0]
252-
assert recovered_counterfactual_explanations.cf_examples_list[0].final_cfs_df.shape[0] == total_CFs
253-
assert recovered_counterfactual_explanations.local_importance is not None
254-
assert counterfactual_explanations.summary_importance is not None
220+
@pytest.mark.parametrize("version", ['1.0', '2.0'])
221+
def test_empty_counterfactual_explanations_object(self, version):
222+
223+
counterfactual_explanations = CounterfactualExplanations(
224+
cf_examples_list=[],
225+
local_importance=None,
226+
summary_importance=None,
227+
version=version)
228+
self.verify_counterfactual_explanations(counterfactual_explanations, None,
229+
0, version)
230+
231+
counterfactual_explanations_as_json = counterfactual_explanations.to_json()
232+
assert counterfactual_explanations_as_json is not None
233+
234+
recovered_counterfactual_explanations = CounterfactualExplanations.from_json(
235+
counterfactual_explanations_as_json)
236+
237+
self.verify_counterfactual_explanations(recovered_counterfactual_explanations, None,
238+
0, version)
239+
240+
assert counterfactual_explanations == recovered_counterfactual_explanations
241+
242+
@pytest.mark.parametrize('unsupported_version', ['3.0', ''])
243+
def test_unsupported_versions_from_json(self, unsupported_version):
244+
json_str = json.dumps({'metadata': {'version': unsupported_version}})
245+
with pytest.raises(UserConfigValidationException) as ucve:
246+
CounterfactualExplanations.from_json(json_str)
247+
248+
assert "Incompatible version {} found in json input".format(unsupported_version) in str(ucve)
249+
250+
json_str = json.dumps({'metadata': {'versio': unsupported_version}})
251+
with pytest.raises(UserConfigValidationException) as ucve:
252+
CounterfactualExplanations.from_json(json_str)
253+
254+
assert "No version field in the json input" in str(ucve)
255+
256+
@pytest.mark.parametrize('unsupported_version', ['3.0', ''])
257+
def test_unsupported_versions_to_json(self, unsupported_version):
258+
counterfactual_explanations = CounterfactualExplanations(
259+
cf_examples_list=[],
260+
local_importance=None,
261+
summary_importance=None,
262+
version=unsupported_version)
263+
264+
with pytest.raises(UserConfigValidationException) as ucve:
265+
counterfactual_explanations.to_json()
266+
267+
assert "Unsupported serialization version {}".format(unsupported_version) in str(ucve)

0 commit comments

Comments
 (0)