Skip to content

Commit 0b826ae

Browse files
committed
Fix #216: Allow unequal sample sizes in multi-group paired wide-format data
Remove overly aggressive NaN filtering in `_check_errors()` that was causing data truncation when using wide-format paired data with different group sizes.

Problem: When loading wide-format paired data created by concatenating DataFrames of different lengths (e.g., 20, 10, and 40 samples), the package was removing ALL rows with ANY NaN value across ALL columns. This truncated all groups to the size of the smallest group.

Root Cause: In the `_check_errors()` method, the code had:

    elif x is None and y is None:
        self.__output_data.dropna(inplace=True)

This removed entire rows if they had NaN in ANY column, affecting all groups even though the NaN values were structural (from DataFrame concatenation) and not actual missing data points.

Solution: Removed the problematic `elif` block from `_check_errors()`. The downstream code in `_get_plot_data()` already handles NaN values correctly by:
1. Using `pd.melt()`, which preserves all non-NaN values.
2. Calling `dropna(subset=[self.__yvar])`, which only removes rows with NaN in the measurement column, not across all columns.

Testing:
- Added `test_33_multi_paired_different_sizes()` to verify that groups with 20, 10, and 40 samples are preserved correctly.
1 parent ee406ce commit 0b826ae

6 files changed

Lines changed: 30 additions & 6 deletions

dabest/_dabest_object.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -559,14 +559,12 @@ def _check_errors(self, x, y, idx, experiment, experiment_label, x1_level):
559559
self.__x1_level = x1_level
560560

561561
if self.__is_paired and self.__output_data.isnull().values.any():
562-
warn1 = f"NaN values detected under paired setting and removed,"
562+
warn1 = f"NaN values detected under paired setting,"
563563
warn2 = f" please check your data."
564564
warnings.warn(warn1 + warn2)
565565
if x is not None and y is not None:
566566
rmname = self.__output_data[self.__output_data[y].isnull()][self.__id_col].tolist()
567567
self.__output_data = self.__output_data[~self.__output_data[self.__id_col].isin(rmname)]
568-
elif x is None and y is None:
569-
self.__output_data.dropna(inplace=True)
570568

571569
# Check if there is a typo on paired
572570
if self.__is_paired and self.__is_paired not in ("baseline", "sequential"):

nbs/API/dabest_object.ipynb

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -664,14 +664,12 @@
664664
" self.__x1_level = x1_level\n",
665665
"\n",
666666
" if self.__is_paired and self.__output_data.isnull().values.any():\n",
667-
" warn1 = f\"NaN values detected under paired setting and removed,\"\n",
667+
" warn1 = f\"NaN values detected under paired setting,\"\n",
668668
" warn2 = f\" please check your data.\"\n",
669669
" warnings.warn(warn1 + warn2)\n",
670670
" if x is not None and y is not None:\n",
671671
" rmname = self.__output_data[self.__output_data[y].isnull()][self.__id_col].tolist()\n",
672672
" self.__output_data = self.__output_data[~self.__output_data[self.__id_col].isin(rmname)]\n",
673-
" elif x is None and y is None:\n",
674-
" self.__output_data.dropna(inplace=True)\n",
675673
"\n",
676674
" # Check if there is a typo on paired\n",
677675
" if self.__is_paired and self.__is_paired not in (\"baseline\", \"sequential\"):\n",
1 Byte
Loading
1 Byte
Loading
67.5 KB
Loading

nbs/tests/mpl_image_tests/test_03_plotting.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,34 @@ def test_32_multigroups_baseline_change_palette():
451451
plt.rcdefaults()
452452
return multi_groups_baseline.mean_diff.plot(custom_palette="Dark2", delta_text=True)
453453

454+
@pytest.mark.mpl_image_compare(tolerance=8)
455+
def test_33_multi_paired_different_sizes():
456+
# Test for GitHub issue #216: multi-group paired data with different sample sizes
457+
plt.rcdefaults()
458+
np.random.seed(9999)
459+
460+
# Create three test pairs with different sample sizes (20, 10, 40)
461+
c1DF = pd.DataFrame({'Test 1_pre': norm.rvs(loc=3, scale=0.4, size=20)})
462+
t1DF = pd.DataFrame({'Test 1_post': norm.rvs(loc=3.5, scale=0.5, size=20)})
463+
t2DF = pd.DataFrame({'Test 2_pre': norm.rvs(loc=2.5, scale=0.6, size=10)})
464+
t3DF = pd.DataFrame({'Test 2_post': norm.rvs(loc=3, scale=0.75, size=10)})
465+
t4DF = pd.DataFrame({'Test 3_pre': norm.rvs(loc=3.5, scale=0.75, size=40)})
466+
t5DF = pd.DataFrame({'Test 3_post': norm.rvs(loc=3.25, scale=0.4, size=40)})
467+
468+
df = pd.concat([c1DF, t1DF, t2DF, t3DF, t4DF, t5DF], axis=1)
469+
df["ID"] = pd.Series(range(1, len(df)+1))
470+
471+
multi_paired_diff_sizes = load(
472+
df,
473+
idx=(("Test 1_pre", "Test 1_post"),
474+
("Test 2_pre", "Test 2_post"),
475+
("Test 3_pre", "Test 3_post")),
476+
paired="baseline",
477+
id_col="ID"
478+
)
479+
480+
return multi_paired_diff_sizes.mean_diff.plot()
481+
454482
@pytest.mark.mpl_image_compare(tolerance=8)
455483
def test_99_style_sheets():
456484
# Perform this test last so we don't have to reset the plot style.

0 commit comments

Comments
 (0)