Skip to content

Commit 3122320

Browse files
authored
Extend categorical value check to numeric values (#212)
* Extend categorical value check to numeric values Signed-off-by: gaugup <gaugup@microsoft.com> * Add unit tests Signed-off-by: gaugup <gaugup@microsoft.com>
1 parent 1f1ea41 commit 3122320

3 files changed

Lines changed: 38 additions & 7 deletions

File tree

dice_ml/explainer_interfaces/explainer_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,8 @@ def check_query_instance_validity(self, features_to_vary, permitted_range, query
168168
raise ValueError("Feature", feature, "not present in training data!")
169169

170170
for feature in self.data_interface.categorical_feature_names:
171-
if query_instance[feature].values[0] not in feature_ranges_orig[feature]:
171+
if query_instance[feature].values[0] not in feature_ranges_orig[feature] and \
172+
str(query_instance[feature].values[0]) not in feature_ranges_orig[feature]:
172173
raise ValueError("Feature", feature, "has a value outside the dataset.")
173174

174175
if feature not in features_to_vary and permitted_range is not None:

tests/conftest.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
import pytest
21
from collections import OrderedDict
32
import pandas as pd
3+
import pytest
44
from sklearn.datasets import load_iris, load_boston
55
from sklearn.model_selection import train_test_split
6+
67
import dice_ml
78
from dice_ml.utils import helpers
89

@@ -194,4 +195,6 @@ def create_boston_data():
194195
x_train, x_test, y_train, y_test = train_test_split(
195196
boston.data, boston.target,
196197
test_size=0.2, random_state=7)
197-
return x_train, x_test, y_train, y_test, boston.feature_names
198+
x_train = pd.DataFrame(data=x_train, columns=boston.feature_names)
199+
x_test = pd.DataFrame(data=x_test, columns=boston.feature_names)
200+
return x_train, x_test, y_train, y_test, boston.feature_names.tolist()

tests/test_dice_interface/test_explainer_base.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import pandas as pd
22
import pytest
3+
from sklearn.ensemble import RandomForestRegressor
34

5+
import dice_ml
46
from dice_ml.utils.exception import UserConfigValidationException
57
from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
68

@@ -157,16 +159,41 @@ def test_zero_totalcfs(self, desired_class, multi_classification_exp_object, sam
157159

158160
class TestExplainerBaseRegression:
159161

160-
@pytest.mark.parametrize("desired_class, regression_exp_object",
161-
[(1, 'random'), (1, 'genetic'), (1, 'kdtree')],
162+
@pytest.mark.parametrize("desired_range, regression_exp_object",
163+
[([10, 100], 'random'), ([10, 100], 'genetic'), ([10, 100], 'kdtree')],
162164
indirect=['regression_exp_object'])
163-
def test_zero_totalcfs(self, desired_class, regression_exp_object, sample_custom_query_1):
165+
def test_zero_totalcfs(self, desired_range, regression_exp_object, sample_custom_query_1):
164166
exp = regression_exp_object # explainer object
165167
with pytest.raises(UserConfigValidationException):
166168
exp.generate_counterfactuals(
167169
query_instances=[sample_custom_query_1],
168170
total_CFs=0,
169-
desired_class=desired_class)
171+
desired_range=desired_range)
172+
173+
@pytest.mark.parametrize("desired_range, method",
174+
[([10, 100], 'random')])
175+
def test_numeric_categories(self, desired_range, method, create_boston_data):
176+
x_train, x_test, y_train, y_test, feature_names = \
177+
create_boston_data
178+
179+
rfc = RandomForestRegressor(n_estimators=10, max_depth=4,
180+
random_state=777)
181+
model = rfc.fit(x_train, y_train)
182+
183+
dataset_train = x_train.copy()
184+
dataset_train['Outcome'] = y_train
185+
feature_names.remove('CHAS')
186+
187+
d = dice_ml.Data(dataframe=dataset_train, continuous_features=feature_names, outcome_name='Outcome')
188+
m = dice_ml.Model(model=model, backend='sklearn', model_type='regressor')
189+
exp = dice_ml.Dice(d, m, method=method)
190+
191+
cf_explanation = exp.generate_counterfactuals(
192+
query_instances=x_test.iloc[0:1],
193+
total_CFs=10,
194+
desired_range=desired_range)
195+
196+
assert cf_explanation is not None
170197

171198

172199
class TestExplainerBase:

0 commit comments

Comments
 (0)