Skip to content

Commit c043043

Browse files
jeongyoonleeclaude
andauthored
Fix UpliftRandomForest predict shape mismatch with multiple treatments (#569) (#884)
* Fix UpliftRandomForest predict shape mismatch with multiple treatments (#569) Bootstrap sampling can exclude entire treatment groups from a tree's training data, causing individual trees to produce prediction arrays of different widths. When summing predictions across trees, this causes a ValueError for shape mismatch. Added _align_tree_predict() that maps each tree's predictions to the forest-level class ordering, filling zeros for missing treatment groups. This is a module-level function (not a closure) so it works with joblib's parallel pickling. * Address review: use dict lookup, preserve dtype, add regression test - Use dict for O(1) class-to-index mapping instead of repeated list scans - Preserve dtype with dtype=raw.dtype in aligned array - Add test_UpliftRandomForestClassifier_predict_shape_with_sparse_groups * Address review: precompute class mapping once, improve test robustness - Build class_to_forest_idx dict once in predict() instead of per tree - Use model.n_jobs instead of parallel_backend for parallel test - Assert that sparse-group condition actually occurred in test * Make sparse-group test deterministic with 1-sample minority groups With only 1 sample per minority treatment group out of 102 total, bootstrap sampling will miss them in most trees, making the test deterministic regardless of seed or CI environment. --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 490e35c commit c043043

2 files changed

Lines changed: 74 additions & 2 deletions

File tree

causalml/inference/tree/uplift.pyx

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,28 @@ cdef extern from "math.h":
6060
double fabs(double x) nogil
6161
double sqrt(double x) nogil
6262

63+
64+
def _align_tree_predict(tree, X, forest_classes, class_to_forest_idx):
65+
"""Predict with a single tree and align output to the forest's classes.
66+
67+
When a bootstrap sample excludes some treatment groups, the tree's
68+
classes_ will be a subset of the forest's classes_. This function
69+
maps the tree's predictions to the forest-level class ordering.
70+
71+
Args:
72+
class_to_forest_idx: Precomputed {class_label: forest_index} mapping.
73+
"""
74+
raw = tree.predict(X=X)
75+
if len(tree.classes_) == len(forest_classes):
76+
return raw
77+
aligned = np.zeros((raw.shape[0], len(forest_classes)), dtype=raw.dtype)
78+
for tree_idx, cls in enumerate(tree.classes_):
79+
forest_idx = class_to_forest_idx.get(cls)
80+
if forest_idx is not None:
81+
aligned[:, forest_idx] = raw[:, tree_idx]
82+
return aligned
83+
84+
6385
@cython.cfunc
6486
def kl_divergence(pk: cython.float, qk: cython.float) -> cython.float:
6587
'''
@@ -2692,14 +2714,15 @@ class UpliftRandomForestClassifier:
26922714
26932715
'''
26942716
# Make predictions with all trees and take the average
2717+
class_to_forest_idx = {cls: idx for idx, cls in enumerate(self.classes_)}
26952718

26962719
if self.n_jobs != 1:
26972720
y_pred_ensemble = sum(
26982721
Parallel(n_jobs=self.n_jobs, prefer=self.joblib_prefer)
2699-
(delayed(tree.predict)(X=X) for tree in self.uplift_forest)
2722+
(delayed(_align_tree_predict)(tree, X, self.classes_, class_to_forest_idx) for tree in self.uplift_forest)
27002723
) / len(self.uplift_forest)
27012724
else:
2702-
y_pred_ensemble = sum([tree.predict(X=X) for tree in self.uplift_forest]) / len(self.uplift_forest)
2725+
y_pred_ensemble = sum([_align_tree_predict(tree, X, self.classes_, class_to_forest_idx) for tree in self.uplift_forest]) / len(self.uplift_forest)
27032726

27042727
# Summarize results into dataframe
27052728
df_res = pd.DataFrame(y_pred_ensemble, columns=self.classes_)

tests/test_uplift_trees.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,3 +389,52 @@ def test_uplift_tree_pvalue_no_nan_with_sparse_groups():
389389
assert not np.any(
390390
np.isnan(preds)
391391
), "Predictions contain NaN (likely from NaN p-values)"
392+
393+
394+
def test_UpliftRandomForestClassifier_predict_shape_with_sparse_groups():
395+
"""Test that UpliftRandomForestClassifier.predict() returns correct shape
396+
when bootstrap sampling causes some trees to miss treatment groups (#569)."""
397+
np.random.seed(RANDOM_SEED)
398+
n = 102
399+
X = np.random.randn(n, 3)
400+
# Only 1 sample per minority treatment group guarantees that bootstrap
401+
# sampling (with replacement, n draws from n) will miss them in some trees.
402+
# P(group included) = 1 - (1 - 1/n)^n ≈ 1 - 1/e ≈ 0.63 per tree,
403+
# so with 10 trees the chance ALL include both groups is ~0.63^20 ≈ 0.01%.
404+
treatment = np.array(
405+
[CONTROL_NAME] * 100 + [TREATMENT_NAMES[1]] * 1 + [TREATMENT_NAMES[2]] * 1
406+
)
407+
y = np.random.randint(0, 2, n)
408+
409+
model = UpliftRandomForestClassifier(
410+
control_name=CONTROL_NAME,
411+
n_estimators=10,
412+
n_jobs=2,
413+
min_samples_leaf=1,
414+
min_samples_treatment=0,
415+
random_state=RANDOM_SEED,
416+
)
417+
model.fit(X, treatment=treatment, y=y)
418+
419+
# Verify that at least one tree was fit without some treatment groups
420+
assert any(
421+
len(tree.classes_) < len(model.classes_) for tree in model.uplift_forest
422+
), (
423+
"Test setup failed to produce any trees missing treatment groups; "
424+
"adjust seed or sampling parameters to exercise sparse-group behavior."
425+
)
426+
427+
# Single-threaded
428+
model.n_jobs = 1
429+
preds = model.predict(X)
430+
assert preds.shape == (
431+
n,
432+
len(model.classes_) - 1,
433+
), f"Expected shape ({n}, {len(model.classes_) - 1}), got {preds.shape}"
434+
assert not np.any(np.isnan(preds)), "Predictions contain NaN"
435+
436+
# Parallel
437+
model.n_jobs = 2
438+
preds_par = model.predict(X)
439+
assert preds_par.shape == preds.shape
440+
assert np.allclose(preds, preds_par)

0 commit comments

Comments
 (0)