Skip to content

Commit e69f7ca

Browse files
authored
added userconfig exception for totalcfs<1 (#125)
* added userconfig exception for totalcfs<1 * added unit test for explainer base
1 parent f95c777 commit e69f7ca

4 files changed

Lines changed: 74 additions & 2 deletions

File tree

dice_ml/explainer_interfaces/dice_random.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def _generate_counterfactuals(self, query_instance, total_CFs, desired_range, d
163163
m, 'min %02d' % s, 'sec')
164164
else:
165165
if self.total_cfs_found == 0 :
166-
print('No Counterfactuals found for the given configuation, perhaps try with different parameters...', '; total time taken: %02d' % m, 'min %02d' % s, 'sec')
166+
print('No Counterfactuals found for the given configuration, perhaps try with different parameters...', '; total time taken: %02d' % m, 'min %02d' % s, 'sec')
167167
else:
168168
print('Only %d (required %d) Diverse Counterfactuals found for the given configuration, perhaps try with different parameters...' % (self.total_cfs_found, self.total_CFs), '; total time taken: %02d' % m, 'min %02d' % s, 'sec')
169169

dice_ml/explainer_interfaces/explainer_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ def generate_counterfactuals(self, query_instances, total_CFs,
6868
:returns: A CounterfactualExplanations object that contains the list of
6969
counterfactual examples per query_instance as one of its attributes.
7070
"""
71-
71+
if total_CFs <= 0:
72+
raise UserConfigValidationException("The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.")
7273
cf_examples_arr = []
7374
query_instances_list = []
7475
if isinstance(query_instances, pd.DataFrame):

tests/conftest.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,40 @@
44
import dice_ml
55
from dice_ml.utils import helpers
66

7+
@pytest.fixture
8+
def binary_classification_exp_object(method="random"):
9+
backend = 'sklearn'
10+
dataset = helpers.load_custom_testing_dataset_binary()
11+
d = dice_ml.Data(dataframe=dataset, continuous_features=['Numerical'], outcome_name='Outcome')
12+
ML_modelpath = helpers.get_custom_dataset_modelpath_pipeline_binary()
13+
m = dice_ml.Model(model_path=ML_modelpath, backend=backend)
14+
exp = dice_ml.Dice(d, m, method=method)
15+
return exp
16+
17+
18+
@pytest.fixture
19+
def multi_classification_exp_object(method="random"):
20+
backend = 'sklearn'
21+
dataset = helpers.load_custom_testing_dataset_multiclass()
22+
d = dice_ml.Data(dataframe=dataset, continuous_features=['Numerical'], outcome_name='Outcome')
23+
ML_modelpath = helpers.get_custom_dataset_modelpath_pipeline_multiclass()
24+
m = dice_ml.Model(model_path=ML_modelpath, backend=backend)
25+
exp = dice_ml.Dice(d, m, method=method)
26+
return exp
27+
28+
29+
@pytest.fixture
30+
def regression_exp_object(method="random"):
31+
backend = 'sklearn'
32+
dataset = helpers.load_custom_testing_dataset_regression()
33+
d = dice_ml.Data(dataframe=dataset, continuous_features=['Numerical'], outcome_name='Outcome')
34+
ML_modelpath = helpers.get_custom_dataset_modelpath_pipeline_regression()
35+
m = dice_ml.Model(model_path=ML_modelpath, backend=backend, model_type='regressor')
36+
exp = dice_ml.Dice(d, m, method=method)
37+
return exp
38+
39+
40+
741
@pytest.fixture
842
def public_data_object():
943
"""
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import pytest
2+
from dice_ml.utils.exception import UserConfigValidationException
3+
4+
5+
class TestExplainerBaseBinaryClassification:
6+
7+
@pytest.mark.parametrize("desired_class, binary_classification_exp_object", [(1, 'random'),(1,'genetic'),(1,'kdtree')], indirect=['binary_classification_exp_object'])
8+
def test_zero_totalcfs(self, desired_class, binary_classification_exp_object, sample_custom_query_1):
9+
exp = binary_classification_exp_object # explainer object
10+
with pytest.raises(UserConfigValidationException):
11+
exp.generate_counterfactuals(
12+
query_instances=[sample_custom_query_1],
13+
total_CFs=0,
14+
desired_class=desired_class)
15+
16+
class TestExplainerBaseMultiClassClassification:
17+
18+
@pytest.mark.parametrize("desired_class, multi_classification_exp_object", [(1, 'random'),(1,'genetic'),(1,'kdtree')], indirect=['multi_classification_exp_object'])
19+
def test_zero_totalcfs(self, desired_class, multi_classification_exp_object, sample_custom_query_1):
20+
exp = multi_classification_exp_object # explainer object
21+
with pytest.raises(UserConfigValidationException):
22+
exp.generate_counterfactuals(
23+
query_instances=[sample_custom_query_1],
24+
total_CFs=0,
25+
desired_class=desired_class)
26+
27+
28+
class TestExplainerBaseRegression:
29+
30+
@pytest.mark.parametrize("desired_class, regression_exp_object", [(1, 'random'),(1,'genetic'),(1,'kdtree')], indirect=['regression_exp_object'])
31+
def test_zero_totalcfs(self, desired_class, regression_exp_object, sample_custom_query_1):
32+
exp = regression_exp_object # explainer object
33+
with pytest.raises(UserConfigValidationException):
34+
exp.generate_counterfactuals(
35+
query_instances=[sample_custom_query_1],
36+
total_CFs=0,
37+
desired_class=desired_class)

0 commit comments

Comments
 (0)