Commit d4d51aa

Merge branch 'master' into gaugup/AddValidationsToGenerateCounterfactuals
2 parents b46310c + c234226 commit d4d51aa

98 files changed

Lines changed: 6495 additions & 3847 deletions

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+# This workflow will lint python code with flake8 and flake8-nb.
+# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
+
+name: Python linting
+
+on:
+  push:
+    branches: [ master ]
+  pull_request:
+    branches: [ master ]
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python 3.7
+      uses: actions/setup-python@v2
+      with:
+        python-version: 3.7
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        python -m pip install flake8==3.9.2 flake8-nb==0.3.0 isort
+    - name: Check sorted python imports using isort
+      run: |
+        isort . -c
+    - name: Lint code with flake8
+      run: |
+        # stop the build if there are Python syntax errors or undefined names
+        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+        # The GitHub editor is 127 chars wide.
+        flake8 . --count --max-complexity=30 --max-line-length=127 --statistics
+        # Check for cyclometric complexity for specific files where this metric has been
+        # reduced to ten and below
+        flake8 dice_ml/data_interfaces/ --count --max-complexity=10 --max-line-length=127
+    - name: Lint notebooks with flake8_nb
+      run: |
+        # stop the build if there are flake8 errors in notebooks
+        flake8_nb docs/source/notebooks/ --statistics --max-line-length=127

.github/workflows/pythonpackage.yml

Lines changed: 8 additions & 15 deletions
@@ -1,7 +1,7 @@
 # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
 
-name: Python package
+name: Python package test
 
 on:
   push:
@@ -24,25 +24,18 @@ jobs:
       uses: actions/setup-python@v1
       with:
         python-version: ${{ matrix.python-version }}
-    - name: Install dependencies
+    - name: Upgrade pip
       run: |
         python -m pip install --upgrade pip
-        pip install -r requirements-test.txt
+    - name: Install core dependencies
+      run: |
         pip install -r requirements.txt
-        pip install -r requirements-deeplearning.txt
-    - name: Lint code with flake8
+    - name: Install deep learning dependencies
       run: |
-        # stop the build if there are Python syntax errors or undefined names
-        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-        # The GitHub editor is 127 chars wide.
-        flake8 . --count --max-complexity=30 --max-line-length=127 --statistics
-        # Check for cyclometric complexity for specific files where this metric has been
-        # reduced to ten and below
-        flake8 dice_ml/data_interfaces/ --count --max-complexity=10 --max-line-length=127
-    - name: Lint notebooks with flake8_nb
+        pip install -r requirements-deeplearning.txt
+    - name: Install test dependencies
       run: |
-        # stop the build if there are flake8 errors in notebooks
-        flake8_nb docs/source/notebooks/ --statistics --max-line-length=127
+        pip install -r requirements-test.txt
     - name: Test with pytest
       run: |
         # pytest

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -113,5 +113,6 @@ docs/notebooks/DiCE_getting_started.ipynb
 docs/notebooks/DiCE_getting_started_feasible.ipynb
 docs/notebooks/DiCE_with_advanced_options.ipynb
 docs/notebooks/DiCE_with_private_data.ipynb
+docs/notebooks/*.ipynb
 
 

CODEOWNERS

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+# dice-ml package
+/dice_ml @gaugup @amit-sharma

MANIFEST.in

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 include requirements.txt
 include requirements-deeplearning.txt
+include requirements-test.txt
 include LICENSE
 recursive-include docs *
 recursive-include tests *.py

README.rst

Lines changed: 13 additions & 0 deletions
@@ -66,6 +66,7 @@ If you face any problems, try installing dependencies manually.
 
 DiCE requires the following packages:
 
+* jsonschema
 * numpy
 * scikit-learn
 * pandas
@@ -119,6 +120,18 @@ Using DiCE, we can now generate examples that would have been classified as clas
     :width: 400
     :alt: List of counterfactual examples
 
+You can save the generated counterfactual examples in the following way:-
+
+.. code:: python
+
+    # Generate counterfactual examples
+    dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=4, desired_class="opposite")
+    # Visualize counterfactual explanation
+    dice_exp.visualize_as_dataframe()
+    # Save generated counterfactual examples to disk
+    dice_exp.cf_examples_list[0].final_cfs_df.to_csv(path_or_buf='counterfactuals.csv', index=False)
+
+
 For more details, check out the `docs/source/notebooks <https://github.com/interpretml/DiCE/tree/master/docs/source/notebooks>`_ folder. Here are some example notebooks:
 
 * `Getting Started <https://github.com/interpretml/DiCE/blob/master/docs/source/notebooks/DiCE_getting_started.ipynb>`_: Generate CF examples for a `sklearn`, `tensorflow` or `pytorch` binary classifier and compute feature importance scores.
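
The README addition above writes the counterfactuals with `to_csv(..., index=False)`, so the saved file is a plain CSV with a header row and no index column. A minimal sketch of reading such a file back, using only the standard library: the column names and rows below are hypothetical stand-ins, and `load_counterfactuals` is an illustrative helper, not part of DiCE.

```python
import csv
import os
import tempfile

# Hypothetical stand-in for the file produced by
# final_cfs_df.to_csv(path_or_buf='counterfactuals.csv', index=False).
# The column names and values are assumptions for illustration only.
fieldnames = ['age', 'workclass', 'income']
rows = [
    {'age': '34', 'workclass': 'Private', 'income': '1'},
    {'age': '41', 'workclass': 'Government', 'income': '1'},
]

path = os.path.join(tempfile.mkdtemp(), 'counterfactuals.csv')
with open(path, 'w', newline='') as handle:
    writer = csv.DictWriter(handle, fieldnames=fieldnames)
    writer.writeheader()      # header row, as to_csv writes by default
    writer.writerows(rows)    # no index column, matching index=False


def load_counterfactuals(csv_path):
    """Return the saved counterfactuals as a list of dicts keyed by column name."""
    with open(csv_path, newline='') as handle:
        return list(csv.DictReader(handle))


loaded = load_counterfactuals(path)
print(len(loaded), sorted(loaded[0]))
```

The `csv` module returns every value as a string, so numeric columns need converting after loading.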

dice_ml/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 from .data import Data
-from .model import Model
 from .dice import Dice
+from .model import Model
 
 __all__ = ["Data",
            "Model",

dice_ml/counterfactual_explanations.py

Lines changed: 9 additions & 7 deletions
@@ -1,11 +1,12 @@
 import json
-import jsonschema
 import os
 
-from dice_ml.diverse_counterfactuals import CounterfactualExamples
-from dice_ml.utils.exception import UserConfigValidationException
-from dice_ml.diverse_counterfactuals import _DiverseCFV2SchemaConstants
+import jsonschema
+
 from dice_ml.constants import _SchemaVersions
+from dice_ml.diverse_counterfactuals import (CounterfactualExamples,
+                                             _DiverseCFV2SchemaConstants)
+from dice_ml.utils.exception import UserConfigValidationException
 
 
 class _CommonSchemaConstants:
@@ -45,10 +46,10 @@ class CounterfactualExplanations:
 
     :param cf_examples_list: A list of CounterfactualExamples instances
     :param local_importance: List of estimated local importance scores. The
-        size of the list is the number of input instances, each containing feature
-        importance scores for that input.
+        size of the list is the number of input instances, each containing
+        feature importance scores for that input.
     :param summary_importance: Estimated global feature importance scores
-        based on the input set of CounterfactualExamples instances
+        based on the input set of CounterfactualExamples instances
 
     """
     def __init__(self, cf_examples_list,
@@ -118,6 +119,7 @@ def _check_cf_exp_output_against_json_schema(
 
         :param cf_dict: Serialized version of the counterfactual explanations.
         :type cf_dict: Dict
+
         """
         schema_file_name = 'counterfactual_explanations_v{0}.json'.format(version)
         schema_path = os.path.join(os.path.dirname(__file__),

dice_ml/data_interfaces/private_data_interface.py

Lines changed: 4 additions & 6 deletions
@@ -1,15 +1,13 @@
 """Module containing meta data information about private data."""
 
-import sys
-import pandas as pd
-import numpy as np
 import collections
 import logging
+import sys
 
-from dice_ml.data_interfaces.base_data_interface import _BaseData
-
+import numpy as np
+import pandas as pd
 
-logging.basicConfig(level=logging.NOTSET)
+from dice_ml.data_interfaces.base_data_interface import _BaseData
 
 
 class PrivateData(_BaseData):

dice_ml/data_interfaces/public_data_interface.py

Lines changed: 36 additions & 3 deletions
@@ -1,13 +1,15 @@
 """Module containing all required information about the interface between raw (or transformed)
 public data and DiCE explainers."""
 
-import pandas as pd
-import numpy as np
 import logging
 from collections import defaultdict
 
+import numpy as np
+import pandas as pd
+
 from dice_ml.data_interfaces.base_data_interface import _BaseData
-from dice_ml.utils.exception import SystemException, UserConfigValidationException
+from dice_ml.utils.exception import (SystemException,
+                                     UserConfigValidationException)
 
 
 class PublicData(_BaseData):
@@ -147,6 +149,37 @@ def _validate_and_set_permitted_range(self, params):
                     )
         self.permitted_range, _ = self.get_features_range(input_permitted_range)
 
+    def check_features_to_vary(self, features_to_vary):
+        if features_to_vary is not None and features_to_vary != 'all':
+            not_training_features = set(features_to_vary) - set(self.feature_names)
+            if len(not_training_features) > 0:
+                raise UserConfigValidationException("Got features {0} which are not present in training data".format(
+                    not_training_features))
+
+    def check_permitted_range(self, permitted_range):
+        if permitted_range is not None:
+            permitted_range_features = list(permitted_range)
+            not_training_features = set(permitted_range_features) - set(self.feature_names)
+            if len(not_training_features) > 0:
+                raise UserConfigValidationException("Got features {0} which are not present in training data".format(
+                    not_training_features))
+
+            for feature in permitted_range_features:
+                if feature in self.categorical_feature_names:
+                    train_categories = self.permitted_range[feature]
+                    for test_category in permitted_range[feature]:
+                        if test_category not in train_categories:
+                            raise UserConfigValidationException(
+                                'The category {0} does not occur in the training data for feature {1}.'
+                                ' Allowed categories are {2}'.format(test_category, feature, train_categories))
+
+    def check_mad_validity(self, feature_weights):
+        """checks feature MAD validity and throw warnings.
+        TODO: add comments as to where this is used if this function is necessary, else remove.
+        """
+        if feature_weights == "inverse_mad":
+            self.get_valid_mads(display_warnings=True, return_mads=False)
+
     def get_features_range(self, permitted_range_input=None):
         ranges = {}
         # Getting default ranges based on the dataset
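
The two validators added to `PublicData` above both compare user input with the training feature names, and `check_permitted_range` also checks categorical values against the categories stored on the instance. A minimal standalone sketch of that logic follows, assuming the instance attributes are passed in as plain arguments: the exception class is a local stand-in for `dice_ml.utils.exception.UserConfigValidationException`, and `train_ranges` is a hypothetical name for what the diff reads from `self.permitted_range`.

```python
class UserConfigValidationException(Exception):
    """Local stand-in for dice_ml.utils.exception.UserConfigValidationException."""


def check_features_to_vary(features_to_vary, feature_names):
    """Raise if any requested feature is absent from the training features."""
    if features_to_vary is None or features_to_vary == 'all':
        return
    not_training_features = set(features_to_vary) - set(feature_names)
    if not_training_features:
        raise UserConfigValidationException(
            "Got features {0} which are not present in training data".format(
                not_training_features))


def check_permitted_range(permitted_range, feature_names,
                          categorical_feature_names, train_ranges):
    """Raise on unknown features or on categories missing from train_ranges."""
    if permitted_range is None:
        return
    not_training_features = set(permitted_range) - set(feature_names)
    if not_training_features:
        raise UserConfigValidationException(
            "Got features {0} which are not present in training data".format(
                not_training_features))
    for feature, requested in permitted_range.items():
        # Only categorical features are checked value by value.
        if feature in categorical_feature_names:
            train_categories = train_ranges[feature]
            for test_category in requested:
                if test_category not in train_categories:
                    raise UserConfigValidationException(
                        'The category {0} does not occur in the training data for feature {1}.'
                        ' Allowed categories are {2}'.format(
                            test_category, feature, train_categories))


# Hypothetical training metadata, for illustration only.
feature_names = ['age', 'hours_per_week', 'workclass']
categorical_feature_names = ['workclass']
train_ranges = {'age': [17, 90], 'hours_per_week': [1, 99],
                'workclass': ['Private', 'Government']}

check_features_to_vary(['age', 'workclass'], feature_names)  # passes silently
try:
    check_permitted_range({'workclass': ['Self-Employed']}, feature_names,
                          categorical_feature_names, train_ranges)
except UserConfigValidationException as err:
    print(err)
```

In this sketch, as in the diff, `set(features_to_vary)` iterates over its argument, so a single name passed as a bare string such as `'age'` is split into characters; callers should pass a list.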
