Skip to content

Commit 48e111a

Browse files
committed
explicitly check for duplicate columns
1 parent e704189 commit 48e111a

1 file changed

Lines changed: 22 additions & 3 deletions

File tree

dabest/_classes.py

Lines changed: 22 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -41,12 +41,24 @@ def __init__(self, data, idx, x, y, paired, id_col, ci, resamples,
4141
if all([isinstance(i, str) for i in idx]):
4242
# flatten out idx.
4343
all_plot_groups = pd.unique([t for t in idx]).tolist()
44+
if len(idx) > len(all_plot_groups):
45+
err0 = '`idx` contains duplicated groups. Please remove any duplicates and try again.'
46+
raise ValueError(err0)
47+
4448
# We need to re-wrap this idx inside another tuple so as to
4549
# easily loop thru each pairwise group later on.
4650
self.__idx = (idx,)
4751

4852
elif all([isinstance(i, (tuple, list)) for i in idx]):
4953
all_plot_groups = pd.unique([tt for t in idx for tt in t]).tolist()
54+
55+
actual_groups_given = sum([len(i) for i in idx])
56+
57+
if actual_groups_given > len(all_plot_groups):
58+
err0 = 'Groups are repeated across tuples,'
59+
err1 = ' or a tuple has repeated groups in it.'
60+
err2 = ' Please remove any duplicates and try again.'
61+
raise ValueError(err0 + err1 + err2)
5062

5163
else: # mix of string and tuple?
5264
err = 'There seems to be a problem with the idx you'
@@ -91,9 +103,14 @@ def __init__(self, data, idx, x, y, paired, id_col, ci, resamples,
91103
# check all the idx can be found in data_in[x]
92104
for g in all_plot_groups:
93105
if g not in data_in[x].unique():
94-
raise IndexError('{0} is not a group in `{1}`.'.format(g, x))
106+
err0 = '"{0}" is not a group in the column `{1}`.'.format(g, x)
107+
err1 = " Please check `idx` and try again."
108+
raise IndexError(err0 + err1)
95109

110+
# Select only rows where the value in the `x` column
111+
# is found in `idx`.
96112
plot_data = data_in[data_in.loc[:, x].isin(all_plot_groups)].copy()
113+
97114
# plot_data.drop("index", inplace=True, axis=1)
98115

99116
# Assign attributes
@@ -113,8 +130,10 @@ def __init__(self, data, idx, x, y, paired, id_col, ci, resamples,
113130
# First, check we have all columns in the dataset.
114131
for g in all_plot_groups:
115132
if g not in data_in.columns:
116-
raise IndexError('{0} is not a column in `data`.'.format(g))
117-
133+
err0 = '"{0}" is not a column in `data`.'.format(g)
134+
err1 = " Please check `idx` and try again."
135+
raise IndexError(err0 + err1)
136+
118137
set_all_columns = set(data_in.columns.tolist())
119138
set_all_plot_groups = set(all_plot_groups)
120139
id_vars = set_all_columns.difference(set_all_plot_groups)

0 commit comments

Comments
 (0)