@@ -47,7 +47,7 @@ def _initiate_exp_object(self, KD_binary_classification_exp_object):
4747 self .data_df_copy = self .exp .data_interface .data_df .copy ()
4848
4949 # When a query's feature value is not within the permitted range and the feature is not allowed to vary
50- @pytest .mark .parametrize ("desired_range, desired_class, total_CFs, features_to_vary, permitted_range" ,
50+ @pytest .mark .parametrize (( "desired_range" , " desired_class" , " total_CFs" , " features_to_vary" , " permitted_range") ,
5151 [(None , 0 , 4 , ['Numerical' ], {'Categorical' : ['b' , 'c' ]})])
5252 def test_invalid_query_instance (self , desired_range , desired_class , sample_custom_query_1 , total_CFs ,
5353 features_to_vary , permitted_range ):
@@ -59,20 +59,20 @@ def test_invalid_query_instance(self, desired_range, desired_class, sample_custo
5959 features_to_vary = features_to_vary , permitted_range = permitted_range )
6060
6161 # Verifying the output of the KD tree
62- @pytest .mark .parametrize ("desired_class, total_CFs" , [(0 , 1 )])
63- @pytest .mark .parametrize (' posthoc_sparsity_algorithm' , ['linear' , 'binary' , None ])
62+ @pytest .mark .parametrize (( "desired_class" , " total_CFs") , [(0 , 1 )])
63+ @pytest .mark .parametrize (( " posthoc_sparsity_algorithm" ) , ['linear' , 'binary' , None ])
6464 def test_KD_tree_output (self , desired_class , sample_custom_query_1 , total_CFs , posthoc_sparsity_algorithm ):
6565 self .exp ._generate_counterfactuals (query_instance = sample_custom_query_1 , desired_class = desired_class ,
6666 total_CFs = total_CFs ,
6767 posthoc_sparsity_algorithm = posthoc_sparsity_algorithm )
6868 self .exp .final_cfs_df .Numerical = self .exp .final_cfs_df .Numerical .astype (int )
6969 expected_output = self .exp .data_interface .data_df
7070
71- assert all (self .exp .final_cfs_df .Numerical == expected_output .Numerical [0 ]) and \
72- all (self .exp .final_cfs_df .Categorical == expected_output .Categorical [0 ])
71+ assert all (self .exp .final_cfs_df .Numerical == expected_output .Numerical [0 ])
72+ assert all (self .exp .final_cfs_df .Categorical == expected_output .Categorical [0 ])
7373
7474 # Verifying the output of the KD tree
75- @pytest .mark .parametrize ("desired_class, total_CFs" , [(0 , 1 )])
75+ @pytest .mark .parametrize (( "desired_class" , " total_CFs") , [(0 , 1 )])
7676 def test_KD_tree_counterfactual_explanations_output (self , desired_class , sample_custom_query_1 , total_CFs ):
7777 counterfactual_explanations = self .exp .generate_counterfactuals (
7878 query_instances = sample_custom_query_1 , desired_class = desired_class ,
@@ -81,35 +81,35 @@ def test_KD_tree_counterfactual_explanations_output(self, desired_class, sample_
8181 assert counterfactual_explanations is not None
8282
8383 # Testing that the features_to_vary argument actually varies only the features that you wish to vary
84- @pytest .mark .parametrize ("desired_class, total_CFs, features_to_vary" , [(0 , 1 , ["Numerical" ])])
84+ @pytest .mark .parametrize (( "desired_class" , " total_CFs" , " features_to_vary") , [(0 , 1 , ["Numerical" ])])
8585 def test_features_to_vary (self , desired_class , sample_custom_query_2 , total_CFs , features_to_vary ):
8686 self .exp ._generate_counterfactuals (query_instance = sample_custom_query_2 , desired_class = desired_class ,
8787 total_CFs = total_CFs , features_to_vary = features_to_vary )
8888 self .exp .final_cfs_df .Numerical = self .exp .final_cfs_df .Numerical .astype (int )
8989 expected_output = self .exp .data_interface .data_df
9090
91- assert all (self .exp .final_cfs_df .Numerical == expected_output .Numerical [1 ]) and \
92- all (self .exp .final_cfs_df .Categorical == expected_output .Categorical [1 ])
91+ assert all (self .exp .final_cfs_df .Numerical == expected_output .Numerical [1 ])
92+ assert all (self .exp .final_cfs_df .Categorical == expected_output .Categorical [1 ])
9393
9494 # Testing that the permitted_range argument actually varies the features only within the permitted_range
95- @pytest .mark .parametrize ("desired_class, total_CFs, permitted_range" , [(0 , 1 , {'Numerical' : [1000 , 10000 ]})])
95+ @pytest .mark .parametrize (( "desired_class" , " total_CFs" , " permitted_range") , [(0 , 1 , {'Numerical' : [1000 , 10000 ]})])
9696 def test_permitted_range (self , desired_class , sample_custom_query_2 , total_CFs , permitted_range ):
9797 self .exp ._generate_counterfactuals (query_instance = sample_custom_query_2 , desired_class = desired_class ,
9898 total_CFs = total_CFs , permitted_range = permitted_range )
9999 self .exp .final_cfs_df .Numerical = self .exp .final_cfs_df .Numerical .astype (int )
100100 expected_output = self .exp .data_interface .data_df
101- assert all (self .exp .final_cfs_df .Numerical == expected_output .Numerical [1 ]) and \
102- all (self .exp .final_cfs_df .Categorical == expected_output .Categorical [1 ])
101+ assert all (self .exp .final_cfs_df .Numerical == expected_output .Numerical [1 ])
102+ assert all (self .exp .final_cfs_df .Categorical == expected_output .Categorical [1 ])
103103
104104 # Testing if you can provide permitted_range for categorical variables
105- @pytest .mark .parametrize ("desired_class, total_CFs, permitted_range" , [(0 , 4 , {'Categorical' : ['b' , 'c' ]})])
105+ @pytest .mark .parametrize (( "desired_class" , " total_CFs" , " permitted_range") , [(0 , 4 , {'Categorical' : ['b' , 'c' ]})])
106106 def test_permitted_range_categorical (self , desired_class , sample_custom_query_2 , total_CFs , permitted_range ):
107107 self .exp ._generate_counterfactuals (query_instance = sample_custom_query_2 , desired_class = desired_class ,
108108 total_CFs = total_CFs , permitted_range = permitted_range )
109109 assert all (i in permitted_range ["Categorical" ] for i in self .exp .final_cfs_df .Categorical .values )
110110
111111 # Ensuring that there are no duplicates in the resulting counterfactuals even if the dataset has duplicates
112- @pytest .mark .parametrize ("desired_class, total_CFs" , [(0 , 2 )])
112+ @pytest .mark .parametrize (( "desired_class" , " total_CFs") , [(0 , 2 )])
113113 def test_duplicates (self , desired_class , sample_custom_query_4 , total_CFs ):
114114 self .exp ._generate_counterfactuals (query_instance = sample_custom_query_4 , total_CFs = total_CFs ,
115115 desired_class = desired_class )
@@ -123,8 +123,8 @@ def test_duplicates(self, desired_class, sample_custom_query_4, total_CFs):
123123 assert all (self .exp .final_cfs_df == expected_output )
124124
125125 # Testing for index returned
126- @pytest .mark .parametrize ("desired_class, total_CFs" , [(0 , 1 )])
127- @pytest .mark .parametrize (' posthoc_sparsity_algorithm' , ['linear' , 'binary' , None ])
126+ @pytest .mark .parametrize (( "desired_class" , " total_CFs") , [(0 , 1 )])
127+ @pytest .mark .parametrize (( " posthoc_sparsity_algorithm" ) , ['linear' , 'binary' , None ])
128128 def test_index (self , desired_class , sample_custom_query_index , total_CFs , posthoc_sparsity_algorithm ):
129129 self .exp ._generate_counterfactuals (query_instance = sample_custom_query_index , total_CFs = total_CFs ,
130130 desired_class = desired_class ,
@@ -139,8 +139,8 @@ def _initiate_exp_object(self, KD_multi_classification_exp_object):
139139 self .data_df_copy = self .exp_multi .data_interface .data_df .copy ()
140140
141141 # Testing that the output of multiclass classification lies in the desired_class
142- @pytest .mark .parametrize ("desired_class, total_CFs" , [(2 , 3 )])
143- @pytest .mark .parametrize (' posthoc_sparsity_algorithm' , ['linear' , 'binary' , None ])
142+ @pytest .mark .parametrize (( "desired_class" , " total_CFs") , [(2 , 3 )])
143+ @pytest .mark .parametrize (( " posthoc_sparsity_algorithm" ) , ['linear' , 'binary' , None ])
144144 def test_KD_tree_output (self , desired_class , sample_custom_query_2 , total_CFs ,
145145 posthoc_sparsity_algorithm ):
146146 self .exp_multi ._generate_counterfactuals (query_instance = sample_custom_query_2 , total_CFs = total_CFs ,
@@ -156,9 +156,9 @@ def _initiate_exp_object(self, KD_regression_exp_object):
156156 self .data_df_copy = self .exp_regr .data_interface .data_df .copy ()
157157
158158 # Testing that the output of regression lies in the desired_range
159- @pytest .mark .parametrize ("desired_range, total_CFs" , [([1 , 2.8 ], 6 )])
160- @pytest .mark .parametrize ("version" , ['2.0' , '1.0' ])
161- @pytest .mark .parametrize (' posthoc_sparsity_algorithm' , ['linear' , 'binary' , None ])
159+ @pytest .mark .parametrize (( "desired_range" , " total_CFs") , [([1 , 2.8 ], 6 )])
160+ @pytest .mark .parametrize (( "version" ) , ['2.0' , '1.0' ])
161+ @pytest .mark .parametrize (( " posthoc_sparsity_algorithm" ) , ['linear' , 'binary' , None ])
162162 def test_KD_tree_output (self , desired_range , sample_custom_query_2 , total_CFs , version , posthoc_sparsity_algorithm ):
163163 cf_examples = self .exp_regr ._generate_counterfactuals (query_instance = sample_custom_query_2 , total_CFs = total_CFs ,
164164 desired_range = desired_range ,
@@ -173,7 +173,7 @@ def test_KD_tree_output(self, desired_range, sample_custom_query_2, total_CFs, v
173173 assert recovered_cf_examples is not None
174174 assert cf_examples == recovered_cf_examples
175175
176- @pytest .mark .parametrize ("desired_range, total_CFs" , [([1 , 2.8 ], 6 )])
176+ @pytest .mark .parametrize (( "desired_range" , " total_CFs") , [([1 , 2.8 ], 6 )])
177177 def test_KD_tree_counterfactual_explanations_output (self , desired_range , sample_custom_query_2 ,
178178 total_CFs ):
179179 counterfactual_explanations = self .exp_regr .generate_counterfactuals (
@@ -189,7 +189,7 @@ def test_KD_tree_counterfactual_explanations_output(self, desired_range, sample_
189189 assert counterfactual_explanations == recovered_counterfactual_explanations
190190
191191 # Testing for 0 CFs needed
192- @pytest .mark .parametrize ("desired_class, desired_range, total_CFs" , [(0 , [1 , 2.8 ], 0 )])
192+ @pytest .mark .parametrize (( "desired_class" , " desired_range" , " total_CFs") , [(0 , [1 , 2.8 ], 0 )])
193193 def test_zero_cfs (self , desired_class , desired_range , sample_custom_query_4 , total_CFs ):
194194 self .exp_regr ._generate_counterfactuals (query_instance = sample_custom_query_4 , total_CFs = total_CFs ,
195195 desired_range = desired_range )
0 commit comments