|
1 | 1 | import pandas as pd |
2 | 2 | import pytest |
| 3 | +from sklearn.ensemble import RandomForestRegressor |
3 | 4 |
|
| 5 | +import dice_ml |
4 | 6 | from dice_ml.utils.exception import UserConfigValidationException |
5 | 7 | from dice_ml.explainer_interfaces.explainer_base import ExplainerBase |
6 | 8 |
|
@@ -157,16 +159,41 @@ def test_zero_totalcfs(self, desired_class, multi_classification_exp_object, sam |
157 | 159 |
|
158 | 160 | class TestExplainerBaseRegression: |
159 | 161 |
|
160 | | - @pytest.mark.parametrize("desired_class, regression_exp_object", |
161 | | - [(1, 'random'), (1, 'genetic'), (1, 'kdtree')], |
| 162 | + @pytest.mark.parametrize("desired_range, regression_exp_object", |
| 163 | + [([10, 100], 'random'), ([10, 100], 'genetic'), ([10, 100], 'kdtree')], |
162 | 164 | indirect=['regression_exp_object']) |
163 | | - def test_zero_totalcfs(self, desired_class, regression_exp_object, sample_custom_query_1): |
| 165 | + def test_zero_totalcfs(self, desired_range, regression_exp_object, sample_custom_query_1): |
164 | 166 | exp = regression_exp_object # explainer object |
165 | 167 | with pytest.raises(UserConfigValidationException): |
166 | 168 | exp.generate_counterfactuals( |
167 | 169 | query_instances=[sample_custom_query_1], |
168 | 170 | total_CFs=0, |
169 | | - desired_class=desired_class) |
| 171 | + desired_range=desired_range) |
| 172 | + |
| 173 | + @pytest.mark.parametrize("desired_range, method", |
| 174 | + [([10, 100], 'random')]) |
| 175 | + def test_numeric_categories(self, desired_range, method, create_boston_data): |
| 176 | + x_train, x_test, y_train, y_test, feature_names = \ |
| 177 | + create_boston_data |
| 178 | + |
| 179 | + rfc = RandomForestRegressor(n_estimators=10, max_depth=4, |
| 180 | + random_state=777) |
| 181 | + model = rfc.fit(x_train, y_train) |
| 182 | + |
| 183 | + dataset_train = x_train.copy() |
| 184 | + dataset_train['Outcome'] = y_train |
| 185 | + feature_names.remove('CHAS') |
| 186 | + |
| 187 | + d = dice_ml.Data(dataframe=dataset_train, continuous_features=feature_names, outcome_name='Outcome') |
| 188 | + m = dice_ml.Model(model=model, backend='sklearn', model_type='regressor') |
| 189 | + exp = dice_ml.Dice(d, m, method=method) |
| 190 | + |
| 191 | + cf_explanation = exp.generate_counterfactuals( |
| 192 | + query_instances=x_test.iloc[0:1], |
| 193 | + total_CFs=10, |
| 194 | + desired_range=desired_range) |
| 195 | + |
| 196 | + assert cf_explanation is not None |
170 | 197 |
|
171 | 198 |
|
172 | 199 | class TestExplainerBase: |
|
0 commit comments