Skip to content

Commit 855ce7e

Browse files
authored
Enable flake8 linter in gates (#131)
* Enable flake8 linter in gates Signed-off-by: gaugup <gaugup@microsoft.com> * Fix flake8 errors in dice_ml/data.py Signed-off-by: gaugup <gaugup@microsoft.com> * Fix flake8 errors dice_ml/dice.py Signed-off-by: gaugup <gaugup@microsoft.com> * Fix flake8 errors in data_interfaces module Signed-off-by: gaugup <gaugup@microsoft.com> * Fix flake8 errors in data_interfaces module Signed-off-by: gaugup <gaugup@microsoft.com> * fix flake8 errors for dice_ml/model.py Signed-off-by: gaugup <gaugup@microsoft.com> * Fix flake8 errors in tests/ Signed-off-by: gaugup <gaugup@microsoft.com> * Fixed flake8 errors in model_interfaces module Signed-off-by: gaugup <gaugup@microsoft.com> * Fix flake8 errors in explainer_interfaces Signed-off-by: gaugup <gaugup@microsoft.com> * Fix flake8 errors in tests/ Signed-off-by: gaugup <gaugup@microsoft.com> * Fixed flake8 errors in helpers.py Signed-off-by: gaugup <gaugup@microsoft.com> * Fixed flake8 errors in docs/source/conf.py Signed-off-by: gaugup <gaugup@microsoft.com> * Fixed flake8 errors in dice_ml/utils/sample_architecture/vae_model.py Signed-off-by: gaugup <gaugup@microsoft.com> * Fixed flake8 errors in tests/ Signed-off-by: gaugup <gaugup@microsoft.com> * Reduce cyclometric complexity in dice_ml/counterfactual_explanations.py Signed-off-by: gaugup <gaugup@microsoft.com> * Fix flake8 errors in dice_ml/utils/sample_architecture/vae_model.py Signed-off-by: gaugup <gaugup@microsoft.com> * Fix flake8 in feasible_base_vae.py feasible_model_approx.py Signed-off-by: gaugup <gaugup@microsoft.com> * Fixed flake8 errors in dice_tensorflow1.py and dice_tensorflow2.py Signed-off-by: gaugup <gaugup@microsoft.com> * Fix flake8 errors in dice_pytorch.py and explainer_base.py Signed-off-by: gaugup <gaugup@microsoft.com> * Fix flake8 errors in diverse_counterfactuals.py, explainer_interfaces/dice_KD.py and explainer_interfaces/dice_genetic.py Signed-off-by: gaugup <gaugup@microsoft.com> * Fix flake8 errors in tests Signed-off-by: gaugup <gaugup@microsoft.com> * Fix remaining flake8 errors Signed-off-by: gaugup <gaugup@microsoft.com>
1 parent 90f0975 commit 855ce7e

39 files changed

Lines changed: 1475 additions & 1039 deletions

.github/workflows/pythonpackage.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
# stop the build if there are Python syntax errors or undefined names
3636
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
3737
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
38-
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
38+
flake8 . --count --max-complexity=30 --max-line-length=127 --statistics
3939
- name: Test with pytest
4040
run: |
4141
pytest

dice_ml/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
11
from .data import Data
22
from .model import Model
33
from .dice import Dice
4+
5+
__all__ = ["Data",
6+
"Model",
7+
"Dice"]

dice_ml/counterfactual_explanations.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,17 +216,21 @@ def to_json(self):
216216
raise UserConfigValidationException(
217217
"Unsupported serialization version {}".format(serialization_version))
218218

219+
@staticmethod
220+
def _validate_serialization_version(version):
221+
if version is None:
222+
raise UserConfigValidationException("No version field in the json input")
223+
elif not _check_supported_json_output_versions(version):
224+
raise UserConfigValidationException("Incompatible version {} found in json input".format(version))
225+
219226
@staticmethod
220227
def from_json(json_str):
221228
""" Deserialize json string to a CounterfactualExplanations object.
222229
"""
223230
json_dict = json.loads(json_str)
224231
if _CommonSchemaConstants.METADATA in json_dict:
225232
version = json_dict[_CommonSchemaConstants.METADATA].get('version')
226-
if version is None:
227-
raise UserConfigValidationException("No version field in the json input")
228-
elif not _check_supported_json_output_versions(version):
229-
raise UserConfigValidationException("Incompatible version {} found in json input".format(version))
233+
CounterfactualExplanations._validate_serialization_version(version)
230234

231235
if version == _SchemaVersions.V1:
232236
CounterfactualExplanations._check_cf_exp_output_against_json_schema(
@@ -240,7 +244,7 @@ def from_json(json_str):
240244
local_importance=json_dict[_CounterfactualExpV1SchemaConstants.LOCAL_IMPORTANCE],
241245
summary_importance=json_dict[_CounterfactualExpV1SchemaConstants.SUMMARY_IMPORTANCE],
242246
version=version)
243-
elif version == _SchemaVersions.V2:
247+
else:
244248
CounterfactualExplanations._check_cf_exp_output_against_json_schema(
245249
json_dict, version=version)
246250
cf_examples_list = []

dice_ml/data.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
"""Module pointing to different implementations of Data class
22
3-
DiCE requires only few parameters about the data such as the range of continuous features and the levels of categorical features. Hence, DiCE can be used for a private data whose meta data are only available (such as the feature names and range/levels of different features) by specifying appropriate parameters.
3+
DiCE requires only few parameters about the data such as the range of continuous
4+
features and the levels of categorical features. Hence, DiCE can be used for a
5+
private data whose meta data are only available (such as the feature names and
6+
range/levels of different features) by specifying appropriate parameters.
47
"""
58

69

@@ -12,23 +15,26 @@ def __init__(self, **params):
1215
1316
:param **params: a dictionary of required parameters.
1417
"""
15-
1618
self.decide_implementation_type(params)
1719

1820
def decide_implementation_type(self, params):
1921
"""Decides if the Data class is for public or private data."""
20-
21-
self.__class__ = decide(params)
22+
self.__class__ = decide(params)
2223
self.__init__(params)
2324

24-
# To add new implementations of Data, add the class in data_interfaces subpackage and import-and-return the class in an elif loop as shown in the below method.
2525

2626
def decide(params):
27-
"""Decides if the Data class is for public or private data."""
28-
29-
if 'dataframe' in params: # if params contain a Pandas dataframe, then use PublicData class
27+
"""Decides if the Data class is for public or private data.
28+
29+
To add new implementations of Data, add the class in data_interfaces
30+
subpackage and import-and-return the class in an elif loop as shown
31+
in the below method.
32+
"""
33+
if 'dataframe' in params:
34+
# if params contain a Pandas dataframe, then use PublicData class
3035
from dice_ml.data_interfaces.public_data_interface import PublicData
3136
return PublicData
32-
else: # use PrivateData if only meta data is provided
37+
else:
38+
# use PrivateData if only meta data is provided
3339
from dice_ml.data_interfaces.private_data_interface import PrivateData
3440
return PrivateData

dice_ml/data_interfaces/private_data_interface.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,44 @@
33
import sys
44
import pandas as pd
55
import numpy as np
6-
from sklearn.model_selection import train_test_split
76
import collections
8-
from collections import OrderedDict
97
import logging
8+
9+
1010
logging.basicConfig(level=logging.NOTSET)
11-
from sklearn.preprocessing import LabelEncoder
11+
1212

1313
class PrivateData:
1414
"""A data interface for private data with meta information."""
1515

1616
def __init__(self, params):
1717
"""Init method
1818
19-
:param features: Dictionary or OrderedDict with feature names as keys and range in int/float (for continuous features) or categories in string (for categorical features) as values. For python version <=3.6, should provide only an OrderedDict.
19+
:param features: Dictionary or OrderedDict with feature names as keys and range in int/float
20+
(for continuous features) or categories in string (for categorical features)
21+
as values. For python version <=3.6, should provide only an OrderedDict.
2022
:param outcome_name: Outcome feature name.
21-
:param type_and_precision (optional): Dictionary with continuous feature names as keys. If the feature is of type int, just string 'int' should be provided, if the feature is of type float, a list of type and precision should be provided. For instance, type_and_precision: {cont_f1: 'int', cont_f2: ['float', 2]} for continuous features cont_f1 and cont_f2 of type int and float (and precision up to 2 decimal places) respectively. Default value is None and all features are treated as int.
22-
:param mad (optional): Dictionary with feature names as keys and corresponding Median Absolute Deviations (MAD) as values. Default MAD value is 1 for all features.
23+
:param type_and_precision (optional): Dictionary with continuous feature names as keys.
24+
If the feature is of type int, just string 'int' should be provided,
25+
if the feature is of type float, a list of type and precision should be
26+
provided. For instance, type_and_precision: {cont_f1: 'int',
27+
cont_f2: ['float', 2]} for continuous features cont_f1 and cont_f2 of
28+
type int and float (and precision up to 2 decimal places) respectively.
29+
Default value is None and all features are treated as int.
30+
:param mad (optional): Dictionary with feature names as keys and corresponding Median Absolute Deviations (MAD)
31+
as values.
32+
Default MAD value is 1 for all features.
2333
:param data_name (optional): Dataset name
24-
2534
"""
26-
27-
if sys.version_info > (3,6,0) and type(params['features']) in [dict, collections.OrderedDict]:
35+
if sys.version_info > (3, 6, 0) and type(params['features']) in [dict, collections.OrderedDict]:
2836
features_dict = params['features']
29-
elif sys.version_info <= (3,6,0) and type(params['features']) is collections.OrderedDict:
37+
elif sys.version_info <= (3, 6, 0) and type(params['features']) is collections.OrderedDict:
3038
features_dict = params['features']
3139
else:
3240
raise ValueError(
33-
"should provide dictionary with feature names as keys and range (for continuous features) or categories (for categorical features) as values. For python version <3.6, should provide an OrderedDict")
41+
"should provide dictionary with feature names as keys and range"
42+
"(for continuous features) or categories (for categorical features) as values. "
43+
"For python version <3.6, should provide an OrderedDict")
3444

3545
if type(params['outcome_name']) is str:
3646
self.outcome_name = params['outcome_name']
@@ -80,7 +90,8 @@ def __init__(self, params):
8090
#
8191
# for column in self.categorical_feature_names:
8292
# self.labelencoder[column] = LabelEncoder()
83-
# self.label_encoded_data[column] = self.labelencoder[column].fit_transform(self.categorical_levels[column])
93+
# self.label_encoded_data[column] = \
94+
# self.labelencoder[column].fit_transform(self.categorical_levels[column])
8495

8596
# self.max_range = -np.inf
8697
# for feature in self.continuous_feature_names:
@@ -178,8 +189,10 @@ def create_ohe_params(self):
178189
# one-hot-encoded data is same as original data if there is no categorical features.
179190
self.ohe_encoded_feature_names = [feat for feat in self.feature_names]
180191

181-
self.ohe_base_df = self.prepare_df_for_ohe_encoding() # base dataframe for doing one-hot-encoding
182-
# ohe_encoded_feature_names and ohe_base_df are created (and stored as data class's parameters) when get_data_params_for_gradient_dice() is called from gradient-based DiCE explainers
192+
# base dataframe for doing one-hot-encoding
193+
# ohe_encoded_feature_names and ohe_base_df are created (and stored as data class's parameters)
194+
# when get_data_params_for_gradient_dice() is called from gradient-based DiCE explainers
195+
self.ohe_base_df = self.prepare_df_for_ohe_encoding()
183196

184197
def get_data_params_for_gradient_dice(self):
185198
"""Gets all data related params for DiCE."""
@@ -200,8 +213,8 @@ def get_data_params_for_gradient_dice(self):
200213
# decimal precisions for continuous features
201214
cont_precisions = [self.get_decimal_precisions()[ix] for ix in range(len(self.continuous_feature_names))]
202215

203-
return minx, maxx, encoded_categorical_feature_indexes, encoded_continuous_feature_indexes, cont_minx, cont_maxx, cont_precisions
204-
216+
return minx, maxx, encoded_categorical_feature_indexes, encoded_continuous_feature_indexes, \
217+
cont_minx, cont_maxx, cont_precisions
205218

206219
def get_encoded_categorical_feature_indexes(self):
207220
"""Gets the column indexes categorical features after one-hot-encoding."""
@@ -243,10 +256,10 @@ def from_label(self, data):
243256
def from_dummies(self, data, prefix_sep='_'):
244257
"""Gets the original data from dummy encoded data with k levels."""
245258
out = data.copy()
246-
for l in self.categorical_feature_names:
259+
for feature_name in self.categorical_feature_names:
247260
cols, labs = [[c.replace(
248-
x, "") for c in data.columns if l+prefix_sep in c] for x in ["", l+prefix_sep]]
249-
out[l] = pd.Categorical(
261+
x, "") for c in data.columns if feature_name+prefix_sep in c] for x in ["", feature_name+prefix_sep]]
262+
out[feature_name] = pd.Categorical(
250263
np.array(labs)[np.argmax(data[cols].values, axis=1)])
251264
out.drop(cols, axis=1, inplace=True)
252265
return out
@@ -330,19 +343,23 @@ def prepare_query_instance(self, query_instance):
330343
return test
331344

332345
def get_ohe_min_max_normalized_data(self, query_instance):
333-
"""Transforms query_instance into one-hot-encoded and min-max normalized data. query_instance should be a dict, a dataframe, a list, or a list of dicts"""
346+
"""Transforms query_instance into one-hot-encoded and min-max normalized data. query_instance should be a dict,
347+
a dataframe, a list, or a list of dicts"""
334348
query_instance = self.prepare_query_instance(query_instance)
335349
temp = self.ohe_base_df.append(query_instance, ignore_index=True, sort=False)
336350
temp = self.one_hot_encode_data(temp)
337351
temp = temp.tail(query_instance.shape[0]).reset_index(drop=True)
338-
return self.normalize_data(temp) # returns a pandas dataframe
352+
# returns a pandas dataframe
353+
return self.normalize_data(temp)
339354

340355
def get_inverse_ohe_min_max_normalized_data(self, transformed_data):
341-
"""Transforms one-hot-encoded and min-max normalized data into raw user-fed data format. transformed_data should be a dataframe or an array"""
356+
"""Transforms one-hot-encoded and min-max normalized data into raw user-fed data format. transformed_data
357+
should be a dataframe or an array"""
342358
raw_data = self.get_decoded_data(transformed_data, encoding='one-hot')
343359
raw_data = self.de_normalize_data(raw_data)
344360
precisions = self.get_decimal_precisions()
345361
for ix, feature in enumerate(self.continuous_feature_names):
346362
raw_data[feature] = raw_data[feature].astype(float).round(precisions[ix])
347363
raw_data = raw_data[self.feature_names]
348-
return raw_data # returns a pandas dataframe
364+
# returns a pandas dataframe
365+
return raw_data

0 commit comments

Comments
 (0)