Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions causalml/inference/tree/uplift.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,28 @@ cdef extern from "math.h":
double fabs(double x) nogil
double sqrt(double x) nogil


def _align_tree_predict(tree, X, forest_classes, class_to_forest_idx):
"""Predict with a single tree and align output to the forest's classes.

When a bootstrap sample excludes some treatment groups, the tree's
classes_ will be a subset of the forest's classes_. This function
maps the tree's predictions to the forest-level class ordering.

Args:
class_to_forest_idx: Precomputed {class_label: forest_index} mapping.
"""
raw = tree.predict(X=X)
if len(tree.classes_) == len(forest_classes):
return raw
aligned = np.zeros((raw.shape[0], len(forest_classes)), dtype=raw.dtype)
for tree_idx, cls in enumerate(tree.classes_):
forest_idx = class_to_forest_idx.get(cls)
if forest_idx is not None:
aligned[:, forest_idx] = raw[:, tree_idx]
Comment on lines +77 to +81
Copy link

Copilot AI Mar 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

class_to_forest_idx is rebuilt for every tree prediction call. Consider constructing this mapping once in UpliftRandomForestClassifier.predict() and passing it (or passing precomputed forest indices for tree.classes_) to reduce overhead when n_estimators or number of treatment groups is large.

Copilot uses AI. Check for mistakes.
return aligned


@cython.cfunc
def kl_divergence(pk: cython.float, qk: cython.float) -> cython.float:
'''
Expand Down Expand Up @@ -2692,14 +2714,15 @@ class UpliftRandomForestClassifier:

'''
# Make predictions with all trees and take the average
class_to_forest_idx = {cls: idx for idx, cls in enumerate(self.classes_)}

if self.n_jobs != 1:
y_pred_ensemble = sum(
Parallel(n_jobs=self.n_jobs, prefer=self.joblib_prefer)
(delayed(tree.predict)(X=X) for tree in self.uplift_forest)
(delayed(_align_tree_predict)(tree, X, self.classes_, class_to_forest_idx) for tree in self.uplift_forest)
) / len(self.uplift_forest)
else:
y_pred_ensemble = sum([tree.predict(X=X) for tree in self.uplift_forest]) / len(self.uplift_forest)
y_pred_ensemble = sum([_align_tree_predict(tree, X, self.classes_, class_to_forest_idx) for tree in self.uplift_forest]) / len(self.uplift_forest)

# Summarize results into dataframe
df_res = pd.DataFrame(y_pred_ensemble, columns=self.classes_)
Expand Down
49 changes: 49 additions & 0 deletions tests/test_uplift_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,52 @@ def test_uplift_tree_pvalue_no_nan_with_sparse_groups():
assert not np.any(
np.isnan(preds)
), "Predictions contain NaN (likely from NaN p-values)"


def test_UpliftRandomForestClassifier_predict_shape_with_sparse_groups():
"""Test that UpliftRandomForestClassifier.predict() returns correct shape
when bootstrap sampling causes some trees to miss treatment groups (#569)."""
np.random.seed(RANDOM_SEED)
n = 102
X = np.random.randn(n, 3)
# Only 1 sample per minority treatment group guarantees that bootstrap
# sampling (with replacement, n draws from n) will miss them in some trees.
# P(group included) = 1 - (1 - 1/n)^n ≈ 1 - 1/e ≈ 0.63 per tree,
# so with 10 trees the chance ALL include both groups is ~0.63^20 ≈ 0.01%.
treatment = np.array(
[CONTROL_NAME] * 100 + [TREATMENT_NAMES[1]] * 1 + [TREATMENT_NAMES[2]] * 1
)
y = np.random.randint(0, 2, n)

model = UpliftRandomForestClassifier(
control_name=CONTROL_NAME,
n_estimators=10,
n_jobs=2,
min_samples_leaf=1,
min_samples_treatment=0,
random_state=RANDOM_SEED,
)
model.fit(X, treatment=treatment, y=y)

Comment thread
jeongyoonlee marked this conversation as resolved.
# Verify that at least one tree was fit without some treatment groups
assert any(
len(tree.classes_) < len(model.classes_) for tree in model.uplift_forest
), (
"Test setup failed to produce any trees missing treatment groups; "
"adjust seed or sampling parameters to exercise sparse-group behavior."
)

# Single-threaded
model.n_jobs = 1
preds = model.predict(X)
assert preds.shape == (
n,
len(model.classes_) - 1,
), f"Expected shape ({n}, {len(model.classes_) - 1}), got {preds.shape}"
assert not np.any(np.isnan(preds)), "Predictions contain NaN"

# Parallel
model.n_jobs = 2
preds_par = model.predict(X)
assert preds_par.shape == preds.shape
Comment thread
jeongyoonlee marked this conversation as resolved.
assert np.allclose(preds, preds_par)