Skip to content

Commit 2e30bb3

Browse files
committed
Merge branch 'master' into gaugup/ReplaceBostonHousingDataset
2 parents 102c6cc + 428d125 commit 2e30bb3

7 files changed

Lines changed: 47 additions & 13 deletions

File tree

.github/workflows/pythonpackage.yml

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
runs-on: ${{ matrix.os }}
1616
strategy:
1717
matrix:
18-
python-version: [3.6, 3.7, 3.8, 3.9]
18+
python-version: [3.6, 3.7, 3.8, 3.9, "3.10"]
1919
os: [ubuntu-latest, macos-latest]
2020
exclude:
2121
- os: macos-latest
@@ -47,13 +47,18 @@ jobs:
4747
if: always()
4848
with:
4949
files: junit/test-results.xml
50-
- name: Upload code coverage results
51-
uses: actions/upload-artifact@v2
50+
- name: Upload coverage to Codecov
51+
uses: codecov/codecov-action@v2
5252
with:
53-
name: code-coverage-results
54-
path: htmlcov
55-
# Use always() to always run this step to publish test results when there are test failures
56-
if: ${{ always() }}
53+
token: ${{ secrets.CODECOV_TOKEN }}
54+
directory: .
55+
env_vars: OS,PYTHON
56+
fail_ci_if_error: true
57+
files: ./coverage.xml
58+
flags: unittests
59+
name: codecov-umbrella
60+
path_to_write_report: ./coverage/codecov_report.txt
61+
verbose: true
5762
- name: Check package consistency with twine
5863
run: |
5964
python setup.py check sdist bdist_wheel

dice_ml/explainer_interfaces/dice_KD.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,6 @@ def _generate_counterfactuals(self, query_instance, total_CFs, desired_range=Non
8989

9090
query_instance[self.data_interface.outcome_name] = test_pred
9191
desired_class = self.misc_init(stopping_threshold, desired_class, desired_range, test_pred)
92-
if desired_range is not None:
93-
if desired_range[0] > desired_range[1]:
94-
raise ValueError("Invalid Range!")
9592

9693
if desired_class == "opposite" and self.model.model_type == ModelTypes.Classifier:
9794
if self.num_output_nodes == 2:

dice_ml/explainer_interfaces/explainer_base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,14 @@ def _validate_counterfactual_configuration(
8181
raise UserConfigValidationException(
8282
'The desired_range parameter should be set for regression task')
8383

84+
if desired_range is not None:
85+
if len(desired_range) != 2:
86+
raise UserConfigValidationException(
87+
"The parameter desired_range needs to have two numbers in ascending order.")
88+
if desired_range[0] > desired_range[1]:
89+
raise UserConfigValidationException(
90+
"The range provided in desired_range should be in ascending order.")
91+
8492
def generate_counterfactuals(self, query_instances, total_CFs,
8593
desired_class="opposite", desired_range=None,
8694
permitted_range=None, features_to_vary="all",
@@ -96,7 +104,8 @@ def generate_counterfactuals(self, query_instances, total_CFs,
96104
:param desired_class: Desired counterfactual class - can take 0 or 1. Default value
97105
is "opposite" to the outcome class of query_instance for binary classification.
98106
:param desired_range: For regression problems. Contains the outcome range to
99-
generate counterfactuals in.
107+
generate counterfactuals in. This should be a list of two numbers in
108+
ascending order.
100109
:param permitted_range: Dictionary with feature names as keys and permitted range in list as values.
101110
Defaults to the range inferred from training data.
102111
If None, uses the parameters initialized in data_interface.

requirements-deeplearning.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
tensorflow>=1.13.0-rc1
2-
torch
2+
torch; python_version < '3.10'

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"Programming Language :: Python :: 3.7",
3939
"Programming Language :: Python :: 3.8",
4040
"Programming Language :: Python :: 3.9",
41+
"Programming Language :: Python :: 3.10",
4142
"License :: OSI Approved :: MIT License",
4243
"Operating System :: OS Independent",
4344
],

tests/test_dice_interface/test_explainer_base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,18 @@ def test_generate_counterfactuals_user_config_validations_regression(
510510
explainer_function(query_instances=sample_custom_query_1,
511511
total_CFs=10)
512512

513+
with pytest.raises(
514+
UserConfigValidationException,
515+
match=r'The parameter desired_range needs to have two numbers in ascending order.'):
516+
explainer_function(query_instances=sample_custom_query_1,
517+
total_CFs=10, desired_range=[1, 3, 4])
518+
519+
with pytest.raises(
520+
UserConfigValidationException,
521+
match=r'The range provided in desired_range should be in ascending order.'):
522+
explainer_function(query_instances=sample_custom_query_1,
523+
total_CFs=10, desired_range=[4, 3])
524+
513525
def test_global_feature_importance_error_conditions_with_insufficient_query_points(
514526
self, method,
515527
sample_custom_query_1,

tests/test_notebooks.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import os
66
import subprocess
7+
import sys
78
import tempfile
89

910
import nbformat
@@ -16,7 +17,11 @@
1617
"DiCE_with_advanced_options.ipynb", # requires tensorflow 1.x
1718
"DiCE_getting_started_feasible.ipynb", # needs changes after latest refactor
1819
"Benchmarking_different_CF_explanation_methods.ipynb"
19-
]
20+
]
21+
# notebooks that don't need to run on python 3.10
22+
torch_notebooks_not_3_10 = [
23+
"DiCE_getting_started.ipynb"
24+
]
2025

2126
# Adding the dice root folder to the python path so that jupyter notebooks
2227
if 'PYTHONPATH' not in os.environ:
@@ -77,6 +82,11 @@ def _notebook_run(filepath):
7782
nb,
7883
marks=[pytest.mark.skip, pytest.mark.advanced],
7984
id=nb)
85+
elif sys.version_info >= (3, 10) and nb in torch_notebooks_not_3_10:
86+
param = pytest.param(
87+
nb,
88+
marks=[pytest.mark.skip, pytest.mark.advanced],
89+
id=nb)
8090
else:
8191
param = pytest.param(nb, id=nb)
8292
parameter_list.append(param)

0 commit comments

Comments
 (0)