1010
1111class TestCounterfactualExplanations :
1212
13- @pytest .mark .parametrize ("version" , ['1.0' , '2.0' ])
14- def test_serialization_deserialization_counterfactual_explanations_class (self , version ):
15-
16- counterfactual_explanations = CounterfactualExplanations (
17- cf_examples_list = [],
18- local_importance = None ,
19- summary_importance = None ,
20- version = version )
21- assert counterfactual_explanations .cf_examples_list is not None
22- assert len (counterfactual_explanations .cf_examples_list ) == 0
23- assert counterfactual_explanations .summary_importance is None
24- assert counterfactual_explanations .local_importance is None
25- assert counterfactual_explanations .metadata is not None
26- assert counterfactual_explanations .metadata ['version' ] is not None
27- assert counterfactual_explanations .metadata ['version' ] == version
28-
29- counterfactual_explanations_as_json = counterfactual_explanations .to_json ()
30- assert counterfactual_explanations_as_json is not None
31-
32- recovered_counterfactual_explanations = CounterfactualExplanations .from_json (
33- counterfactual_explanations_as_json )
34-
35- assert recovered_counterfactual_explanations is not None
36- assert recovered_counterfactual_explanations .metadata ['version' ] == version
37- assert counterfactual_explanations == recovered_counterfactual_explanations
38-
3913 def test_sorted_summary_importance_counterfactual_explanations (self ):
4014
4115 unsorted_summary_importance = {
@@ -132,21 +106,6 @@ def test_sorted_local_importance_counterfactual_explanations(self):
132106 assert list (unsorted_local_importance [index ].keys ()) != list (counterfactual_explanations .local_importance [index ].keys ())
133107 assert list (sorted_local_importance [index ].keys ()) == list (counterfactual_explanations .local_importance [index ].keys ())
134108
135- @pytest .mark .parametrize ('version' , ['3.0' , '' ])
136- def test_unsupported_versions_json_input (self , version ):
137- json_str = json .dumps ({'metadata' : {'version' : version }})
138- with pytest .raises (UserConfigValidationException ) as ucve :
139- CounterfactualExplanations .from_json (json_str )
140-
141- assert "Incompatible version {} found in json input" .format (version ) in str (ucve )
142-
143- json_str = json .dumps ({'metadata' : {'versio' : version }})
144- with pytest .raises (UserConfigValidationException ) as ucve :
145- CounterfactualExplanations .from_json (json_str )
146-
147- assert "No version field in the json input" in str (ucve )
148-
149-
150109
151110@pytest .fixture
152111def random_binary_classification_exp_object ():
@@ -165,90 +124,144 @@ def _initiate_exp_object(self, random_binary_classification_exp_object):
165124 self .exp = random_binary_classification_exp_object # explainer object
166125 self .data_df_copy = self .exp .data_interface .data_df .copy ()
167126
168- @pytest .mark .parametrize ("desired_class, total_CFs" , [(0 , 2 )])
169127 @pytest .mark .parametrize ("version" , ['1.0' , '2.0' ])
128+ def verify_counterfactual_explanations (self , counterfactual_explanations ,
129+ total_CFs , num_query_points , version ,
130+ local_importance_available = False ,
131+ summary_importance_available = False ):
132+ assert counterfactual_explanations is not None
133+ assert counterfactual_explanations .cf_examples_list is not None
134+ assert len (counterfactual_explanations .cf_examples_list ) == num_query_points
135+ if total_CFs is not None :
136+ assert counterfactual_explanations .cf_examples_list [0 ].final_cfs_df .shape [0 ] == total_CFs
137+ assert counterfactual_explanations .metadata is not None
138+ assert counterfactual_explanations .metadata ['version' ] is not None
139+ counterfactual_explanations .metadata ['version' ] = version
140+ if local_importance_available :
141+ assert counterfactual_explanations .local_importance is not None
142+ assert len (counterfactual_explanations .local_importance ) == num_query_points
143+ else :
144+ assert counterfactual_explanations .local_importance is None
145+ if summary_importance_available :
146+ assert counterfactual_explanations .summary_importance is not None
147+ else :
148+ assert counterfactual_explanations .summary_importance is None
149+
150+ @pytest .mark .parametrize ("version" , ['1.0' , '2.0' ])
151+ @pytest .mark .parametrize ("desired_class, total_CFs" , [(0 , 2 )])
170152 def test_random_counterfactual_explanations_output (self , desired_class ,
171153 sample_custom_query_1 , total_CFs ,
172154 version ):
173155 counterfactual_explanations = self .exp .generate_counterfactuals (
174156 query_instances = sample_custom_query_1 , desired_class = desired_class ,
175157 total_CFs = total_CFs )
176158
177- assert counterfactual_explanations is not None
178- assert len (counterfactual_explanations .cf_examples_list ) == sample_custom_query_1 .shape [0 ]
179- assert counterfactual_explanations .cf_examples_list [0 ].final_cfs_df .shape [0 ] == total_CFs
180- assert counterfactual_explanations .local_importance is None
181- assert counterfactual_explanations .summary_importance is None
159+ self .verify_counterfactual_explanations (counterfactual_explanations , total_CFs ,
160+ sample_custom_query_1 .shape [0 ], version )
182161
183- counterfactual_explanations .metadata ['version' ] = version
184162 json_output = counterfactual_explanations .to_json ()
185163 assert json_output is not None
186164 assert json .loads (json_output ).get ('metadata' ).get ('version' ) == version
187165
188166 recovered_counterfactual_explanations = CounterfactualExplanations .from_json (json_output )
189- assert recovered_counterfactual_explanations is not None
190- assert recovered_counterfactual_explanations == counterfactual_explanations
191- assert recovered_counterfactual_explanations .metadata ['version' ] == version
167+ self .verify_counterfactual_explanations (counterfactual_explanations , total_CFs ,
168+ sample_custom_query_1 .shape [0 ], version )
192169
193- assert len (recovered_counterfactual_explanations .cf_examples_list ) == sample_custom_query_1 .shape [0 ]
194- assert recovered_counterfactual_explanations .cf_examples_list [0 ].final_cfs_df .shape [0 ] == total_CFs
195- assert recovered_counterfactual_explanations .local_importance is None
196- assert recovered_counterfactual_explanations .summary_importance is None
170+ assert recovered_counterfactual_explanations == counterfactual_explanations
197171
198- @pytest .mark .parametrize ("desired_class, total_CFs" , [(0 , 10 )])
199172 @pytest .mark .parametrize ("version" , ['1.0' , '2.0' ])
173+ @pytest .mark .parametrize ("desired_class, total_CFs" , [(0 , 10 )])
200174 def test_random_local_importance_output (self , desired_class , sample_custom_query_1 ,
201175 total_CFs , version ):
202176 counterfactual_explanations = self .exp .local_feature_importance (
203177 query_instances = sample_custom_query_1 , desired_class = desired_class ,
204178 total_CFs = total_CFs )
205179
206- assert counterfactual_explanations is not None
207- assert len (counterfactual_explanations .cf_examples_list ) == sample_custom_query_1 .shape [0 ]
208- assert counterfactual_explanations .cf_examples_list [0 ].final_cfs_df .shape [0 ] == total_CFs
209- assert counterfactual_explanations .local_importance is not None
210- assert counterfactual_explanations .summary_importance is None
180+ self .verify_counterfactual_explanations (counterfactual_explanations , total_CFs ,
181+ sample_custom_query_1 .shape [0 ], version ,
182+ local_importance_available = True )
211183
212- counterfactual_explanations .metadata ['version' ] = version
213184 json_output = counterfactual_explanations .to_json ()
214185 assert json_output is not None
215186 assert json .loads (json_output ).get ('metadata' ).get ('version' ) == version
216187
217188 recovered_counterfactual_explanations = CounterfactualExplanations .from_json (json_output )
218- assert recovered_counterfactual_explanations is not None
219- assert recovered_counterfactual_explanations == counterfactual_explanations
220- assert recovered_counterfactual_explanations . metadata [ 'version' ] == version
189+ self . verify_counterfactual_explanations ( counterfactual_explanations , total_CFs ,
190+ sample_custom_query_1 . shape [ 0 ], version ,
191+ local_importance_available = True )
221192
222- assert len (recovered_counterfactual_explanations .cf_examples_list ) == sample_custom_query_1 .shape [0 ]
223- assert recovered_counterfactual_explanations .cf_examples_list [0 ].final_cfs_df .shape [0 ] == total_CFs
224- assert recovered_counterfactual_explanations .local_importance is not None
225- assert counterfactual_explanations .summary_importance is None
193+ assert recovered_counterfactual_explanations == counterfactual_explanations
226194
227- @pytest .mark .parametrize ("desired_class, total_CFs" , [(0 , 10 )])
228195 @pytest .mark .parametrize ("version" , ['1.0' , '2.0' ])
196+ @pytest .mark .parametrize ("desired_class, total_CFs" , [(0 , 10 )])
229197 def test_random_summary_importance_output (self , desired_class , sample_custom_query_10 ,
230198 total_CFs , version ):
231199 counterfactual_explanations = self .exp .global_feature_importance (
232200 query_instances = sample_custom_query_10 , desired_class = desired_class ,
233201 total_CFs = total_CFs )
234202
235- assert counterfactual_explanations is not None
236- assert len (counterfactual_explanations .cf_examples_list ) == sample_custom_query_10 .shape [0 ]
237- assert counterfactual_explanations .cf_examples_list [0 ].final_cfs_df .shape [0 ] == total_CFs
238- assert counterfactual_explanations .local_importance is not None
239- assert counterfactual_explanations .summary_importance is not None
203+ self .verify_counterfactual_explanations (counterfactual_explanations , total_CFs ,
204+ sample_custom_query_10 .shape [0 ], version ,
205+ local_importance_available = True ,
206+ summary_importance_available = True )
240207
241- counterfactual_explanations .metadata ['version' ] = version
242208 json_output = counterfactual_explanations .to_json ()
243209 assert json_output is not None
244210 assert json .loads (json_output ).get ('metadata' ).get ('version' ) == version
245211
246212 recovered_counterfactual_explanations = CounterfactualExplanations .from_json (json_output )
247- assert recovered_counterfactual_explanations is not None
213+ self .verify_counterfactual_explanations (counterfactual_explanations , total_CFs ,
214+ sample_custom_query_10 .shape [0 ], version ,
215+ local_importance_available = True ,
216+ summary_importance_available = True )
217+
248218 assert recovered_counterfactual_explanations == counterfactual_explanations
249- assert recovered_counterfactual_explanations .metadata ['version' ] == version
250219
251- assert len (recovered_counterfactual_explanations .cf_examples_list ) == sample_custom_query_10 .shape [0 ]
252- assert recovered_counterfactual_explanations .cf_examples_list [0 ].final_cfs_df .shape [0 ] == total_CFs
253- assert recovered_counterfactual_explanations .local_importance is not None
254- assert counterfactual_explanations .summary_importance is not None
220+ @pytest .mark .parametrize ("version" , ['1.0' , '2.0' ])
221+ def test_empty_counterfactual_explanations_object (self , version ):
222+
223+ counterfactual_explanations = CounterfactualExplanations (
224+ cf_examples_list = [],
225+ local_importance = None ,
226+ summary_importance = None ,
227+ version = version )
228+ self .verify_counterfactual_explanations (counterfactual_explanations , None ,
229+ 0 , version )
230+
231+ counterfactual_explanations_as_json = counterfactual_explanations .to_json ()
232+ assert counterfactual_explanations_as_json is not None
233+
234+ recovered_counterfactual_explanations = CounterfactualExplanations .from_json (
235+ counterfactual_explanations_as_json )
236+
237+ self .verify_counterfactual_explanations (recovered_counterfactual_explanations , None ,
238+ 0 , version )
239+
240+ assert counterfactual_explanations == recovered_counterfactual_explanations
241+
242+ @pytest .mark .parametrize ('unsupported_version' , ['3.0' , '' ])
243+ def test_unsupported_versions_from_json (self , unsupported_version ):
244+ json_str = json .dumps ({'metadata' : {'version' : unsupported_version }})
245+ with pytest .raises (UserConfigValidationException ) as ucve :
246+ CounterfactualExplanations .from_json (json_str )
247+
248+ assert "Incompatible version {} found in json input" .format (unsupported_version ) in str (ucve )
249+
250+ json_str = json .dumps ({'metadata' : {'versio' : unsupported_version }})
251+ with pytest .raises (UserConfigValidationException ) as ucve :
252+ CounterfactualExplanations .from_json (json_str )
253+
254+ assert "No version field in the json input" in str (ucve )
255+
256+ @pytest .mark .parametrize ('unsupported_version' , ['3.0' , '' ])
257+ def test_unsupported_versions_to_json (self , unsupported_version ):
258+ counterfactual_explanations = CounterfactualExplanations (
259+ cf_examples_list = [],
260+ local_importance = None ,
261+ summary_importance = None ,
262+ version = unsupported_version )
263+
264+ with pytest .raises (UserConfigValidationException ) as ucve :
265+ counterfactual_explanations .to_json ()
266+
267+ assert "Unsupported serialization version {}" .format (unsupported_version ) in str (ucve )
0 commit comments