Skip to content

Commit eb74963

Browse files
jeongyoonleeclaude
andcommitted
Fix UpliftRandomForest predict shape mismatch with multiple treatments (#569)
Bootstrap sampling can exclude entire treatment groups from a tree's training data, causing individual trees to produce prediction arrays of different widths. When summing predictions across trees, this causes a ValueError for shape mismatch. Added _align_tree_predict() that maps each tree's predictions to the forest-level class ordering, filling zeros for missing treatment groups. This is a module-level function (not a closure) so it works with joblib's parallel pickling. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent d688549 commit eb74963

1 file changed

Lines changed: 21 additions & 2 deletions

File tree

causalml/inference/tree/uplift.pyx

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,25 @@ cdef extern from "math.h":
5959
double fabs(double x) nogil
6060
double sqrt(double x) nogil
6161

62+
63+
def _align_tree_predict(tree, X, forest_classes):
64+
"""Predict with a single tree and align output to the forest's classes.
65+
66+
When a bootstrap sample excludes some treatment groups, the tree's
67+
classes_ will be a subset of the forest's classes_. This function
68+
maps the tree's predictions to the forest-level class ordering.
69+
"""
70+
raw = tree.predict(X=X)
71+
if len(tree.classes_) == len(forest_classes):
72+
return raw
73+
aligned = np.zeros((raw.shape[0], len(forest_classes)))
74+
for tree_idx, cls in enumerate(tree.classes_):
75+
if cls in forest_classes:
76+
forest_idx = forest_classes.index(cls)
77+
aligned[:, forest_idx] = raw[:, tree_idx]
78+
return aligned
79+
80+
6281
@cython.cfunc
6382
def kl_divergence(pk: cython.float, qk: cython.float) -> cython.float:
6483
'''
@@ -2549,10 +2568,10 @@ class UpliftRandomForestClassifier:
25492568
if self.n_jobs != 1:
25502569
y_pred_ensemble = sum(
25512570
Parallel(n_jobs=self.n_jobs, prefer=self.joblib_prefer)
2552-
(delayed(tree.predict)(X=X) for tree in self.uplift_forest)
2571+
(delayed(_align_tree_predict)(tree, X, self.classes_) for tree in self.uplift_forest)
25532572
) / len(self.uplift_forest)
25542573
else:
2555-
y_pred_ensemble = sum([tree.predict(X=X) for tree in self.uplift_forest]) / len(self.uplift_forest)
2574+
y_pred_ensemble = sum([_align_tree_predict(tree, X, self.classes_) for tree in self.uplift_forest]) / len(self.uplift_forest)
25562575

25572576
# Summarize results into dataframe
25582577
df_res = pd.DataFrame(y_pred_ensemble, columns=self.classes_)

0 commit comments

Comments
 (0)