Skip to content

Commit 3540abb

Browse files
committed
Rewrite delta-delta Bootstrap function
1. Initialize the arrays using np.empty rather than np.repeat(np.nan) 2. Add robust checks for edge cases (zero pooled standard deviation, mismatched paired-group lengths) 3. Refactor the code to follow standard Python style guidelines 4. Remove the use of pandas for resampling in favor of direct NumPy indexing
1 parent 65660da commit 3540abb

2 files changed

Lines changed: 82 additions & 116 deletions

File tree

dabest/_stats_tools/confint_2group_diff.py

Lines changed: 41 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -155,74 +155,57 @@ def compute_delta2_bootstrapped_diff(
155155
"""
156156

157157
rng = RandomState(PCG64(random_seed))
158-
x1_len = len(x1)
159-
x2_len = len(x2)
160-
x3_len = len(x3)
161-
x4_len = len(x4)
162-
out_delta_g = np.repeat(np.nan, resamples)
163-
deltadelta = np.repeat(np.nan, resamples)
164-
165-
n_a1_b1, n_a2_b1, n_a1_b2, n_a2_b2 = x1_len, x2_len, x3_len, x4_len
166-
s_a1_b1, s_a2_b1, s_a1_b2, s_a2_b2 = np.std(x1), np.std(x2), np.std(x3), np.std(x4)
167-
168-
sd_numerator = (
169-
(n_a2_b1 - 1) * s_a2_b1**2
170-
+ (n_a1_b1 - 1) * s_a1_b1**2
171-
+ (n_a2_b2 - 1) * s_a2_b2**2
172-
+ (n_a1_b2 - 1) * s_a1_b2**2
173-
)
174-
sd_denominator = (n_a2_b1 - 1) + (n_a1_b1 - 1) + (n_a2_b2 - 1) + (n_a1_b2 - 1)
158+
159+
x1, x2, x3, x4 = map(np.asarray, [x1, x2, x3, x4])
160+
161+
# Calculating pooled sample standard deviation
162+
stds = [np.std(x) for x in [x1, x2, x3, x4]]
163+
ns = [len(x) for x in [x1, x2, x3, x4]]
164+
165+
sd_numerator = sum((n - 1) * s**2 for n, s in zip(ns, stds))
166+
sd_denominator = sum(n - 1 for n in ns)
167+
168+
# Avoid division by zero
169+
if sd_denominator == 0:
170+
raise ValueError("Insufficient data to compute pooled standard deviation.")
171+
175172
pooled_sample_sd = np.sqrt(sd_numerator / sd_denominator)
176173

177-
for i in range(int(resamples)):
174+
# Ensure pooled_sample_sd is not NaN or zero (to avoid division by zero later)
175+
if np.isnan(pooled_sample_sd) or pooled_sample_sd == 0:
176+
raise ValueError("Pooled sample standard deviation is NaN or zero.")
177+
178+
out_delta_g = np.empty(resamples)
179+
deltadelta = np.empty(resamples)
180+
181+
# Bootstrapping
182+
for i in range(resamples):
183+
# Paired or unpaired resampling
178184
if is_paired:
179-
if (x1_len != x2_len) or (x3_len != x4_len):
180-
raise ValueError("The two arrays do not have the same length.")
181-
df_paired_1 = pd.DataFrame(
182-
{
183-
"value": np.concatenate([x1, x3]),
184-
"array_id": np.repeat(["x1", "x3"], [x1_len, x3_len]),
185-
}
186-
)
187-
df_paired_2 = pd.DataFrame(
188-
{
189-
"value": np.concatenate([x2, x4]),
190-
"array_id": np.repeat(["x2", "x4"], [x1_len, x3_len]),
191-
}
192-
)
193-
x_sample_index = rng.choice(
194-
len(df_paired_1), len(df_paired_1), replace=True
195-
)
196-
x_sample_1 = df_paired_1.loc[x_sample_index]
197-
x_sample_2 = df_paired_2.loc[x_sample_index]
198-
x1_sample = x_sample_1[x_sample_1["array_id"] == "x1"]["value"]
199-
x2_sample = x_sample_2[x_sample_2["array_id"] == "x2"]["value"]
200-
x3_sample = x_sample_1[x_sample_1["array_id"] == "x3"]["value"]
201-
x4_sample = x_sample_2[x_sample_2["array_id"] == "x4"]["value"]
185+
if len(x1) != len(x2) or len(x3) != len(x4):
186+
raise ValueError("Each control group must have the same length as its corresponding test group in paired analysis.")
187+
indices_1 = rng.choice(len(x1), len(x1), replace=True)
188+
indices_2 = rng.choice(len(x3), len(x3), replace=True)
189+
190+
x1_sample, x2_sample = x1[indices_1], x2[indices_1]
191+
x3_sample, x4_sample = x3[indices_2], x4[indices_2]
202192
else:
203-
df = pd.DataFrame(
204-
{
205-
"value": np.concatenate([x1, x2, x3, x4]),
206-
"array_id": np.repeat(
207-
["x1", "x2", "x3", "x4"], [x1_len, x2_len, x3_len, x4_len]
208-
),
209-
}
210-
)
211-
x_sample_index = rng.choice(len(df), len(df), replace=True)
212-
x_sample = df.loc[x_sample_index]
213-
x1_sample = x_sample[x_sample["array_id"] == "x1"]["value"]
214-
x2_sample = x_sample[x_sample["array_id"] == "x2"]["value"]
215-
x3_sample = x_sample[x_sample["array_id"] == "x3"]["value"]
216-
x4_sample = x_sample[x_sample["array_id"] == "x4"]["value"]
193+
x1_sample = rng.choice(x1, len(x1), replace=True)
194+
x2_sample = rng.choice(x2, len(x2), replace=True)
195+
x3_sample = rng.choice(x3, len(x3), replace=True)
196+
x4_sample = rng.choice(x4, len(x4), replace=True)
217197

198+
# Calculating deltas
218199
delta_1 = np.mean(x2_sample) - np.mean(x1_sample)
219200
delta_2 = np.mean(x4_sample) - np.mean(x3_sample)
220201
delta_delta = delta_2 - delta_1
202+
221203
deltadelta[i] = delta_delta
222204
out_delta_g[i] = delta_delta / pooled_sample_sd
223-
delta_g = (
224-
(np.mean(x4) - np.mean(x3)) - (np.mean(x2) - np.mean(x1))
225-
) / pooled_sample_sd
205+
206+
# Empirical delta_g calculation
207+
delta_g = ((np.mean(x4) - np.mean(x3)) - (np.mean(x2) - np.mean(x1))) / pooled_sample_sd
208+
226209
return out_delta_g, delta_g, deltadelta
227210

228211

nbs/API/confint_2group_diff.ipynb

Lines changed: 41 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -212,74 +212,57 @@
212212
" \"\"\"\n",
213213
"\n",
214214
" rng = RandomState(PCG64(random_seed))\n",
215-
" x1_len = len(x1)\n",
216-
" x2_len = len(x2)\n",
217-
" x3_len = len(x3)\n",
218-
" x4_len = len(x4)\n",
219-
" out_delta_g = np.repeat(np.nan, resamples)\n",
220-
" deltadelta = np.repeat(np.nan, resamples)\n",
221-
"\n",
222-
" n_a1_b1, n_a2_b1, n_a1_b2, n_a2_b2 = x1_len, x2_len, x3_len, x4_len\n",
223-
" s_a1_b1, s_a2_b1, s_a1_b2, s_a2_b2 = np.std(x1), np.std(x2), np.std(x3), np.std(x4)\n",
224-
"\n",
225-
" sd_numerator = (\n",
226-
" (n_a2_b1 - 1) * s_a2_b1**2\n",
227-
" + (n_a1_b1 - 1) * s_a1_b1**2\n",
228-
" + (n_a2_b2 - 1) * s_a2_b2**2\n",
229-
" + (n_a1_b2 - 1) * s_a1_b2**2\n",
230-
" )\n",
231-
" sd_denominator = (n_a2_b1 - 1) + (n_a1_b1 - 1) + (n_a2_b2 - 1) + (n_a1_b2 - 1)\n",
215+
"\n",
216+
" x1, x2, x3, x4 = map(np.asarray, [x1, x2, x3, x4])\n",
217+
"\n",
218+
" # Calculating pooled sample standard deviation\n",
219+
" stds = [np.std(x) for x in [x1, x2, x3, x4]]\n",
220+
" ns = [len(x) for x in [x1, x2, x3, x4]]\n",
221+
"\n",
222+
" sd_numerator = sum((n - 1) * s**2 for n, s in zip(ns, stds))\n",
223+
" sd_denominator = sum(n - 1 for n in ns)\n",
224+
"\n",
225+
" # Avoid division by zero\n",
226+
" if sd_denominator == 0:\n",
227+
" raise ValueError(\"Insufficient data to compute pooled standard deviation.\")\n",
228+
"\n",
232229
" pooled_sample_sd = np.sqrt(sd_numerator / sd_denominator)\n",
233230
"\n",
234-
" for i in range(int(resamples)):\n",
231+
" # Ensure pooled_sample_sd is not NaN or zero (to avoid division by zero later)\n",
232+
" if np.isnan(pooled_sample_sd) or pooled_sample_sd == 0:\n",
233+
" raise ValueError(\"Pooled sample standard deviation is NaN or zero.\")\n",
234+
"\n",
235+
" out_delta_g = np.empty(resamples)\n",
236+
" deltadelta = np.empty(resamples)\n",
237+
"\n",
238+
" # Bootstrapping\n",
239+
" for i in range(resamples):\n",
240+
" # Paired or unpaired resampling\n",
235241
" if is_paired:\n",
236-
" if (x1_len != x2_len) or (x3_len != x4_len):\n",
237-
" raise ValueError(\"The two arrays do not have the same length.\")\n",
238-
" df_paired_1 = pd.DataFrame(\n",
239-
" {\n",
240-
" \"value\": np.concatenate([x1, x3]),\n",
241-
" \"array_id\": np.repeat([\"x1\", \"x3\"], [x1_len, x3_len]),\n",
242-
" }\n",
243-
" )\n",
244-
" df_paired_2 = pd.DataFrame(\n",
245-
" {\n",
246-
" \"value\": np.concatenate([x2, x4]),\n",
247-
" \"array_id\": np.repeat([\"x2\", \"x4\"], [x1_len, x3_len]),\n",
248-
" }\n",
249-
" )\n",
250-
" x_sample_index = rng.choice(\n",
251-
" len(df_paired_1), len(df_paired_1), replace=True\n",
252-
" )\n",
253-
" x_sample_1 = df_paired_1.loc[x_sample_index]\n",
254-
" x_sample_2 = df_paired_2.loc[x_sample_index]\n",
255-
" x1_sample = x_sample_1[x_sample_1[\"array_id\"] == \"x1\"][\"value\"]\n",
256-
" x2_sample = x_sample_2[x_sample_2[\"array_id\"] == \"x2\"][\"value\"]\n",
257-
" x3_sample = x_sample_1[x_sample_1[\"array_id\"] == \"x3\"][\"value\"]\n",
258-
" x4_sample = x_sample_2[x_sample_2[\"array_id\"] == \"x4\"][\"value\"]\n",
242+
" if len(x1) != len(x2) or len(x3) != len(x4):\n",
243+
" raise ValueError(\"Each control group must have the same length as its corresponding test group in paired analysis.\")\n",
244+
" indices_1 = rng.choice(len(x1), len(x1), replace=True)\n",
245+
" indices_2 = rng.choice(len(x3), len(x3), replace=True)\n",
246+
"\n",
247+
" x1_sample, x2_sample = x1[indices_1], x2[indices_1]\n",
248+
" x3_sample, x4_sample = x3[indices_2], x4[indices_2]\n",
259249
" else:\n",
260-
" df = pd.DataFrame(\n",
261-
" {\n",
262-
" \"value\": np.concatenate([x1, x2, x3, x4]),\n",
263-
" \"array_id\": np.repeat(\n",
264-
" [\"x1\", \"x2\", \"x3\", \"x4\"], [x1_len, x2_len, x3_len, x4_len]\n",
265-
" ),\n",
266-
" }\n",
267-
" )\n",
268-
" x_sample_index = rng.choice(len(df), len(df), replace=True)\n",
269-
" x_sample = df.loc[x_sample_index]\n",
270-
" x1_sample = x_sample[x_sample[\"array_id\"] == \"x1\"][\"value\"]\n",
271-
" x2_sample = x_sample[x_sample[\"array_id\"] == \"x2\"][\"value\"]\n",
272-
" x3_sample = x_sample[x_sample[\"array_id\"] == \"x3\"][\"value\"]\n",
273-
" x4_sample = x_sample[x_sample[\"array_id\"] == \"x4\"][\"value\"]\n",
250+
" x1_sample = rng.choice(x1, len(x1), replace=True)\n",
251+
" x2_sample = rng.choice(x2, len(x2), replace=True)\n",
252+
" x3_sample = rng.choice(x3, len(x3), replace=True)\n",
253+
" x4_sample = rng.choice(x4, len(x4), replace=True)\n",
274254
"\n",
255+
" # Calculating deltas\n",
275256
" delta_1 = np.mean(x2_sample) - np.mean(x1_sample)\n",
276257
" delta_2 = np.mean(x4_sample) - np.mean(x3_sample)\n",
277258
" delta_delta = delta_2 - delta_1\n",
259+
"\n",
278260
" deltadelta[i] = delta_delta\n",
279261
" out_delta_g[i] = delta_delta / pooled_sample_sd\n",
280-
" delta_g = (\n",
281-
" (np.mean(x4) - np.mean(x3)) - (np.mean(x2) - np.mean(x1))\n",
282-
" ) / pooled_sample_sd\n",
262+
"\n",
263+
" # Empirical delta_g calculation\n",
264+
" delta_g = ((np.mean(x4) - np.mean(x3)) - (np.mean(x2) - np.mean(x1))) / pooled_sample_sd\n",
265+
"\n",
283266
" return out_delta_g, delta_g, deltadelta\n",
284267
"\n",
285268
"\n",

0 commit comments

Comments
 (0)