Skip to content

Commit da3ad79

Browse files
committed
Add tests
Signed-off-by: Gaurav Gupta <gaugup@microsoft.com>
1 parent e7bf06b commit da3ad79

2 files changed

Lines changed: 208 additions & 121 deletions

File tree

dice_ml/explainer_interfaces/explainer_base.py

Lines changed: 31 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,7 @@ def _validate_counterfactual_configuration(
6060
if posthoc_sparsity_algorithm not in _PostHocSparsityTypes.ALL:
6161
raise UserConfigValidationException(
6262
'The posthoc_sparsity_algorithm should be {0} and not {1}'.format(
63-
','.join(_PostHocSparsityTypes.ALL), posthoc_sparsity_algorithm)
63+
' or '.join(_PostHocSparsityTypes.ALL), posthoc_sparsity_algorithm)
6464
)
6565

6666
if stopping_threshold < 0.0 or stopping_threshold > 1.0:
@@ -250,6 +250,16 @@ def local_feature_importance(self, query_instances, cf_examples_list=None,
250250
the list of counterfactuals per input, local feature importances per
251251
input, and the global feature importance summarized over all inputs.
252252
"""
253+
self._validate_counterfactual_configuration(
254+
query_instances=query_instances,
255+
total_CFs=total_CFs,
256+
desired_class=desired_class,
257+
desired_range=desired_range,
258+
permitted_range=permitted_range, features_to_vary=features_to_vary,
259+
stopping_threshold=stopping_threshold, posthoc_sparsity_param=posthoc_sparsity_param,
260+
posthoc_sparsity_algorithm=posthoc_sparsity_algorithm,
261+
kwargs=kwargs
262+
)
253263
if cf_examples_list is not None:
254264
if any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]):
255265
raise UserConfigValidationException(
@@ -299,6 +309,16 @@ def global_feature_importance(self, query_instances, cf_examples_list=None,
299309
the list of counterfactuals per input, local feature importances per
300310
input, and the global feature importance summarized over all inputs.
301311
"""
312+
self._validate_counterfactual_configuration(
313+
query_instances=query_instances,
314+
total_CFs=total_CFs,
315+
desired_class=desired_class,
316+
desired_range=desired_range,
317+
permitted_range=permitted_range, features_to_vary=features_to_vary,
318+
stopping_threshold=stopping_threshold, posthoc_sparsity_param=posthoc_sparsity_param,
319+
posthoc_sparsity_algorithm=posthoc_sparsity_algorithm,
320+
kwargs=kwargs
321+
)
302322
if query_instances is not None and len(query_instances) < 10:
303323
raise UserConfigValidationException(
304324
"The number of query instances should be greater than or equal to 10 "
@@ -355,6 +375,16 @@ def feature_importance(self, query_instances, cf_examples_list=None,
355375
the list of counterfactuals per input, local feature importances per
356376
input, and the global feature importance summarized over all inputs.
357377
"""
378+
self._validate_counterfactual_configuration(
379+
query_instances=query_instances,
380+
total_CFs=total_CFs,
381+
desired_class=desired_class,
382+
desired_range=desired_range,
383+
permitted_range=permitted_range, features_to_vary=features_to_vary,
384+
stopping_threshold=stopping_threshold, posthoc_sparsity_param=posthoc_sparsity_param,
385+
posthoc_sparsity_algorithm=posthoc_sparsity_algorithm,
386+
kwargs=kwargs
387+
)
358388
if cf_examples_list is None:
359389
cf_examples_list = self.generate_counterfactuals(
360390
query_instances, total_CFs,

tests/test_dice_interface/test_explainer_base.py

Lines changed: 177 additions & 120 deletions
Original file line number | Diff line number | Diff line change
@@ -47,23 +47,6 @@ def test_check_any_counterfactuals_computed(
4747
cf_examples_arr = [cf_example_has_cf, cf_example_no_cf]
4848
exp._check_any_counterfactuals_computed(cf_examples_arr=cf_examples_arr)
4949

50-
@pytest.mark.parametrize("desired_class", [1])
51-
def test_zero_totalcfs(
52-
self, desired_class, method, sample_custom_query_1,
53-
custom_public_data_interface,
54-
sklearn_binary_classification_model_interface
55-
):
56-
exp = dice_ml.Dice(
57-
custom_public_data_interface,
58-
sklearn_binary_classification_model_interface,
59-
method=method)
60-
61-
with pytest.raises(UserConfigValidationException):
62-
exp.generate_counterfactuals(
63-
query_instances=[sample_custom_query_1],
64-
total_CFs=0,
65-
desired_class=desired_class)
66-
6750
@pytest.mark.parametrize("desired_class", [1])
6851
def test_local_feature_importance(
6952
self, desired_class, method,
@@ -128,109 +111,6 @@ def test_global_feature_importance(
128111

129112
self._verify_feature_importance(global_importance.summary_importance)
130113

131-
@pytest.mark.parametrize("desired_class", [1])
132-
def test_global_feature_importance_error_conditions_with_insufficient_query_points(
133-
self, desired_class, method,
134-
sample_custom_query_1,
135-
custom_public_data_interface,
136-
sklearn_binary_classification_model_interface):
137-
exp = dice_ml.Dice(
138-
custom_public_data_interface,
139-
sklearn_binary_classification_model_interface,
140-
method=method)
141-
142-
cf_explanations = exp.generate_counterfactuals(
143-
query_instances=sample_custom_query_1,
144-
total_CFs=15,
145-
desired_class=desired_class)
146-
147-
with pytest.raises(
148-
UserConfigValidationException,
149-
match="The number of points for which counterfactuals generated should be "
150-
"greater than or equal to 10 "
151-
"to compute global feature importance"):
152-
exp.global_feature_importance(
153-
query_instances=None,
154-
cf_examples_list=cf_explanations.cf_examples_list)
155-
156-
with pytest.raises(
157-
UserConfigValidationException,
158-
match="The number of query instances should be greater than or equal to 10 "
159-
"to compute global feature importance over all query points"):
160-
exp.global_feature_importance(
161-
query_instances=sample_custom_query_1,
162-
total_CFs=15,
163-
desired_class=desired_class)
164-
165-
@pytest.mark.parametrize("desired_class", [1])
166-
def test_global_feature_importance_error_conditions_with_insufficient_cfs_per_query_point(
167-
self, desired_class, method,
168-
sample_custom_query_10,
169-
custom_public_data_interface,
170-
sklearn_binary_classification_model_interface):
171-
exp = dice_ml.Dice(
172-
custom_public_data_interface,
173-
sklearn_binary_classification_model_interface,
174-
method=method)
175-
176-
cf_explanations = exp.generate_counterfactuals(
177-
query_instances=sample_custom_query_10,
178-
total_CFs=1,
179-
desired_class=desired_class)
180-
181-
with pytest.raises(
182-
UserConfigValidationException,
183-
match="The number of counterfactuals generated per query instance should be "
184-
"greater than or equal to 10 "
185-
"to compute global feature importance over all query points"):
186-
exp.global_feature_importance(
187-
query_instances=None,
188-
cf_examples_list=cf_explanations.cf_examples_list)
189-
190-
with pytest.raises(
191-
UserConfigValidationException,
192-
match="The number of counterfactuals requested per query instance should be greater "
193-
"than or equal to 10 "
194-
"to compute global feature importance over all query points"):
195-
exp.global_feature_importance(
196-
query_instances=sample_custom_query_10,
197-
total_CFs=1,
198-
desired_class=desired_class)
199-
200-
@pytest.mark.parametrize("desired_class", [1])
201-
def test_local_feature_importance_error_conditions_with_insufficient_cfs_per_query_point(
202-
self, desired_class, method,
203-
sample_custom_query_1,
204-
custom_public_data_interface,
205-
sklearn_binary_classification_model_interface):
206-
exp = dice_ml.Dice(
207-
custom_public_data_interface,
208-
sklearn_binary_classification_model_interface,
209-
method=method)
210-
211-
cf_explanations = exp.generate_counterfactuals(
212-
query_instances=sample_custom_query_1,
213-
total_CFs=1,
214-
desired_class=desired_class)
215-
216-
with pytest.raises(
217-
UserConfigValidationException,
218-
match="The number of counterfactuals generated per query instance should be "
219-
"greater than or equal to 10 to compute feature importance for all query points"):
220-
exp.local_feature_importance(
221-
query_instances=None,
222-
cf_examples_list=cf_explanations.cf_examples_list)
223-
224-
with pytest.raises(
225-
UserConfigValidationException,
226-
match="The number of counterfactuals requested per "
227-
"query instance should be greater than or equal to 10 "
228-
"to compute feature importance for all query points"):
229-
exp.local_feature_importance(
230-
query_instances=sample_custom_query_1,
231-
total_CFs=1,
232-
desired_class=desired_class)
233-
234114
# @pytest.mark.parametrize("desired_class, binary_classification_exp_object_out_of_order",
235115
# [(1, 'random'), (1, 'genetic'), (1, 'kdtree')],
236116
# indirect=['binary_classification_exp_object_out_of_order'])
@@ -545,3 +425,180 @@ class TestExplainerBase:
545425
def test_instantiating_explainer_base(self, public_data_object):
546426
with pytest.raises(TypeError):
547427
ExplainerBase(data_interface=public_data_object)
428+
429+
430+
@pytest.mark.parametrize("method", ['random', 'genetic', 'kdtree'])
431+
class TestExplainerBaseUserConfigValidations:
432+
433+
@pytest.mark.parametrize('explainer_function',
434+
['generate_counterfactuals', 'local_feature_importance',
435+
'feature_importance', 'global_feature_importance'])
436+
def test_generate_counterfactuals_user_config_validations(
437+
self, method, sample_custom_query_2,
438+
custom_public_data_interface,
439+
sklearn_binary_classification_model_interface,
440+
explainer_function):
441+
exp = dice_ml.Dice(
442+
custom_public_data_interface,
443+
sklearn_binary_classification_model_interface,
444+
method=method)
445+
446+
explainer_function = getattr(exp, explainer_function)
447+
with pytest.raises(
448+
UserConfigValidationException,
449+
match=r"The number of counterfactuals generated per query instance \(total_CFs\) "
450+
"should be a positive integer."):
451+
explainer_function(query_instances=sample_custom_query_2,
452+
total_CFs=-10, desired_class='opposite')
453+
454+
with pytest.raises(
455+
UserConfigValidationException,
456+
match=r"The number of counterfactuals generated per query instance \(total_CFs\) "
457+
"should be a positive integer."):
458+
explainer_function(
459+
query_instances=sample_custom_query_2,
460+
total_CFs=0,
461+
desired_class="opposite")
462+
463+
with pytest.raises(
464+
UserConfigValidationException,
465+
match=r"The posthoc_sparsity_algorithm should be linear or binary and not random"):
466+
explainer_function(query_instances=sample_custom_query_2,
467+
total_CFs=10,
468+
posthoc_sparsity_algorithm='random')
469+
470+
with pytest.raises(
471+
UserConfigValidationException,
472+
match=r"The posthoc_sparsity_algorithm should be linear or binary and not random"):
473+
explainer_function(query_instances=sample_custom_query_2,
474+
total_CFs=10,
475+
posthoc_sparsity_algorithm='random')
476+
477+
with pytest.raises(
478+
UserConfigValidationException,
479+
match=r'The stopping_threshold should lie between 0.0 and 1.0'):
480+
explainer_function(query_instances=sample_custom_query_2,
481+
total_CFs=10,
482+
stopping_threshold=-10.0)
483+
484+
with pytest.raises(
485+
UserConfigValidationException,
486+
match=r'The posthoc_sparsity_param should lie between 0.0 and 1.0'):
487+
explainer_function(query_instances=sample_custom_query_2,
488+
total_CFs=10,
489+
posthoc_sparsity_param=-10.0)
490+
491+
with pytest.raises(
492+
UserConfigValidationException,
493+
match=r'The desired_range parameter should not be set for classification task'):
494+
explainer_function(query_instances=sample_custom_query_2,
495+
total_CFs=10, desired_range=[0, 10])
496+
497+
@pytest.mark.parametrize('explainer_function',
498+
['generate_counterfactuals', 'local_feature_importance',
499+
'feature_importance', 'global_feature_importance'])
500+
def test_generate_counterfactuals_user_config_validations_regression(
501+
self, regression_exp_object, sample_custom_query_1,
502+
method, explainer_function):
503+
explainer_function = getattr(regression_exp_object, explainer_function)
504+
with pytest.raises(
505+
UserConfigValidationException,
506+
match=r'The desired_range parameter should be set for regression task'):
507+
explainer_function(query_instances=sample_custom_query_1,
508+
total_CFs=10)
509+
510+
def test_global_feature_importance_error_conditions_with_insufficient_query_points(
511+
self, method,
512+
sample_custom_query_1,
513+
custom_public_data_interface,
514+
sklearn_binary_classification_model_interface):
515+
exp = dice_ml.Dice(
516+
custom_public_data_interface,
517+
sklearn_binary_classification_model_interface,
518+
method=method)
519+
520+
cf_explanations = exp.generate_counterfactuals(
521+
query_instances=sample_custom_query_1,
522+
total_CFs=15)
523+
524+
with pytest.raises(
525+
UserConfigValidationException,
526+
match="The number of points for which counterfactuals generated should be "
527+
"greater than or equal to 10 "
528+
"to compute global feature importance"):
529+
exp.global_feature_importance(
530+
query_instances=None,
531+
cf_examples_list=cf_explanations.cf_examples_list)
532+
533+
with pytest.raises(
534+
UserConfigValidationException,
535+
match="The number of query instances should be greater than or equal to 10 "
536+
"to compute global feature importance over all query points"):
537+
exp.global_feature_importance(
538+
query_instances=sample_custom_query_1,
539+
total_CFs=15)
540+
541+
def test_global_feature_importance_error_conditions_with_insufficient_cfs_per_query_point(
542+
self, method,
543+
sample_custom_query_10,
544+
custom_public_data_interface,
545+
sklearn_binary_classification_model_interface):
546+
exp = dice_ml.Dice(
547+
custom_public_data_interface,
548+
sklearn_binary_classification_model_interface,
549+
method=method)
550+
551+
cf_explanations = exp.generate_counterfactuals(
552+
query_instances=sample_custom_query_10,
553+
total_CFs=1)
554+
555+
with pytest.raises(
556+
UserConfigValidationException,
557+
match="The number of counterfactuals generated per query instance should be "
558+
"greater than or equal to 10 "
559+
"to compute global feature importance over all query points"):
560+
exp.global_feature_importance(
561+
query_instances=None,
562+
cf_examples_list=cf_explanations.cf_examples_list)
563+
564+
with pytest.raises(
565+
UserConfigValidationException,
566+
match="The number of counterfactuals requested per query instance should be greater "
567+
"than or equal to 10 "
568+
"to compute global feature importance over all query points"):
569+
exp.global_feature_importance(
570+
query_instances=sample_custom_query_10,
571+
total_CFs=1)
572+
573+
def test_local_feature_importance_error_conditions_with_insufficient_cfs_per_query_point(
574+
self, method,
575+
sample_custom_query_1,
576+
custom_public_data_interface,
577+
sklearn_binary_classification_model_interface):
578+
exp = dice_ml.Dice(
579+
custom_public_data_interface,
580+
sklearn_binary_classification_model_interface,
581+
method=method)
582+
583+
cf_explanations = exp.generate_counterfactuals(
584+
query_instances=sample_custom_query_1,
585+
total_CFs=1)
586+
587+
with pytest.raises(
588+
UserConfigValidationException,
589+
match="The number of counterfactuals generated per query instance should be "
590+
"greater than or equal to 10 to compute feature importance for all query points"):
591+
exp.local_feature_importance(
592+
query_instances=None,
593+
cf_examples_list=cf_explanations.cf_examples_list)
594+
595+
with pytest.raises(
596+
UserConfigValidationException,
597+
match="The number of counterfactuals requested per "
598+
"query instance should be greater than or equal to 10 "
599+
"to compute feature importance for all query points"):
600+
exp.local_feature_importance(
601+
query_instances=sample_custom_query_1,
602+
total_CFs=1)
603+
604+
# class TestExplainerBaseDataValidations:

0 commit comments

Comments
 (0)