Skip to content

Commit fa968bc

Browse files
authored
Merge pull request #25 from fidelity/test_fix
update
2 parents 4e59100 + ffb37a9 commit fa968bc

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

jurity/fairness/for_difference.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010

1111
from jurity.fairness.base import _BaseBinaryFairness
1212
from jurity.utils import check_and_convert_list_types
13-
from jurity.utils import check_inputs_validity
13+
from jurity.utils import check_inputs
1414
from jurity.utils import performance_measures
1515
from jurity.utils import split_array_based_on_membership_label
1616

17+
1718
class FORDifference(_BaseBinaryFairness):
1819

1920
def __init__(self):
@@ -26,7 +27,7 @@ def __init__(self):
2627
@staticmethod
2728
def get_score(labels: Union[List, np.ndarray, pd.Series],
2829
predictions: Union[List, np.ndarray, pd.Series],
29-
is_member: Union[List, np.ndarray, pd.Series],
30+
memberships: Union[List, np.ndarray, pd.Series],
3031
membership_label: Union[str, float, int] = 1) -> float:
3132
"""
3233
The equality (or lack thereof) of the false omission rates across groups is an important fairness metric.
@@ -43,7 +44,7 @@ def get_score(labels: Union[List, np.ndarray, pd.Series],
4344
Binary ground truth labels for the provided dataset (0/1).
4445
predictions: Union[List, np.ndarray, pd.Series]
4546
Binary predictions from some black-box classifier (0/1).
46-
is_member: Union[List, np.ndarray, pd.Series]
47+
memberships: Union[List, np.ndarray, pd.Series]
4748
Binary membership labels (0/1).
4849
membership_label: Union[str, float, int]
4950
Value indicating group membership.
@@ -54,10 +55,11 @@ def get_score(labels: Union[List, np.ndarray, pd.Series],
5455
False Omission Rate difference between groups.
5556
"""
5657
# Logic to check input types.
57-
check_inputs_validity(labels=labels, predictions=predictions, is_member=is_member, optional_labels=False)
58+
check_inputs(predictions=predictions, memberships=memberships, membership_labels=membership_label,
59+
must_have_labels=True, labels=labels)
5860

5961
# List needs to be converted to np for indexing
60-
is_member = check_and_convert_list_types(is_member)
62+
is_member = check_and_convert_list_types(memberships)
6163
predictions = check_and_convert_list_types(predictions)
6264
labels = check_and_convert_list_types(labels)
6365

0 commit comments

Comments
 (0)