|
1 | 1 | """Module containing all required information about the interface between raw (or transformed) |
2 | 2 | public data and DiCE explainers.""" |
3 | 3 |
|
4 | | -import pandas as pd |
5 | | -import numpy as np |
6 | 4 | import logging |
7 | 5 | from collections import defaultdict |
8 | 6 |
|
| 7 | +import numpy as np |
| 8 | +import pandas as pd |
| 9 | + |
9 | 10 | from dice_ml.data_interfaces.base_data_interface import _BaseData |
10 | | -from dice_ml.utils.exception import SystemException, UserConfigValidationException |
| 11 | +from dice_ml.utils.exception import (SystemException, |
| 12 | + UserConfigValidationException) |
11 | 13 |
|
12 | 14 |
|
13 | 15 | class PublicData(_BaseData): |
@@ -147,6 +149,37 @@ def _validate_and_set_permitted_range(self, params): |
147 | 149 | ) |
148 | 150 | self.permitted_range, _ = self.get_features_range(input_permitted_range) |
149 | 151 |
|
def check_features_to_vary(self, features_to_vary):
    """Validate the feature names supplied in ``features_to_vary``.

    Nothing is checked when ``features_to_vary`` is ``None`` or the
    string ``'all'``.

    :param features_to_vary: ``None``, ``'all'``, or an iterable of feature names.
    :raises UserConfigValidationException: if any supplied name is absent
        from ``self.feature_names``.
    """
    if features_to_vary is None:
        return
    if features_to_vary != 'all':
        # Names requested by the caller that this data interface does not know.
        unknown_features = set(features_to_vary).difference(self.feature_names)
        if unknown_features:
            raise UserConfigValidationException(
                "Got features {0} which are not present in training data".format(unknown_features))
| 158 | + |
def check_permitted_range(self, permitted_range):
    """Validate a user-supplied ``permitted_range`` against this data interface.

    Two checks are performed when ``permitted_range`` is not ``None``:

    1. every key must appear in ``self.feature_names``;
    2. for keys listed in ``self.categorical_feature_names``, every value in
       ``permitted_range[feature]`` must be contained in
       ``self.permitted_range[feature]``.

    Features that are not categorical get no value-level check here.

    :param permitted_range: ``None`` or a container keyed by feature name
        (presumably a dict of feature -> allowed values; confirm with callers).
    :raises UserConfigValidationException: on an unknown feature name or a
        category not found in ``self.permitted_range`` for that feature.
    """
    if permitted_range is None:
        return

    requested_features = list(permitted_range)
    unknown_features = set(requested_features) - set(self.feature_names)
    if unknown_features:
        raise UserConfigValidationException(
            "Got features {0} which are not present in training data".format(unknown_features))

    for feature_name in requested_features:
        if feature_name not in self.categorical_feature_names:
            continue
        # Categories currently allowed for this feature on the data interface.
        allowed_categories = self.permitted_range[feature_name]
        for candidate in permitted_range[feature_name]:
            if candidate not in allowed_categories:
                raise UserConfigValidationException(
                    'The category {0} does not occur in the training data for feature {1}.'
                    ' Allowed categories are {2}'.format(candidate, feature_name, allowed_categories))
| 175 | + |
def check_mad_validity(self, feature_weights):
    """Trigger the MAD validity check when inverse-MAD weighting is requested.

    When ``feature_weights`` equals the string ``"inverse_mad"``, calls
    ``self.get_valid_mads`` with ``display_warnings=True`` and
    ``return_mads=False`` (so any warnings that helper produces are shown and
    its result is not used). Any other value is a no-op.

    TODO: add comments as to where this is used if this function is necessary, else remove.
    """
    uses_inverse_mad = feature_weights == "inverse_mad"
    if uses_inverse_mad:
        self.get_valid_mads(return_mads=False, display_warnings=True)
| 182 | + |
150 | 183 | def get_features_range(self, permitted_range_input=None): |
151 | 184 | ranges = {} |
152 | 185 | # Getting default ranges based on the dataset |
|
0 commit comments