Skip to content

Commit d169b1e

Browse files
jeongyoonlee and claude committed
Address review: precompute class mapping once, improve test robustness
- Build class_to_forest_idx dict once in predict() instead of per tree
- Use model.n_jobs instead of parallel_backend for parallel test
- Assert that sparse-group condition actually occurred in test

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f94d9d5 commit d169b1e

2 files changed

Lines changed: 19 additions & 6 deletions

File tree

causalml/inference/tree/uplift.pyx

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,20 @@ cdef extern from "math.h":
6060
double sqrt(double x) nogil
6161

6262

63-
def _align_tree_predict(tree, X, forest_classes):
63+
def _align_tree_predict(tree, X, forest_classes, class_to_forest_idx):
6464
"""Predict with a single tree and align output to the forest's classes.
6565
6666
When a bootstrap sample excludes some treatment groups, the tree's
6767
classes_ will be a subset of the forest's classes_. This function
6868
maps the tree's predictions to the forest-level class ordering.
69+
70+
Args:
71+
class_to_forest_idx: Precomputed {class_label: forest_index} mapping.
6972
"""
7073
raw = tree.predict(X=X)
7174
if len(tree.classes_) == len(forest_classes):
7275
return raw
7376
aligned = np.zeros((raw.shape[0], len(forest_classes)), dtype=raw.dtype)
74-
class_to_forest_idx = {cls: idx for idx, cls in enumerate(forest_classes)}
7577
for tree_idx, cls in enumerate(tree.classes_):
7678
forest_idx = class_to_forest_idx.get(cls)
7779
if forest_idx is not None:
@@ -2565,14 +2567,15 @@ class UpliftRandomForestClassifier:
25652567
25662568
'''
25672569
# Make predictions with all trees and take the average
2570+
class_to_forest_idx = {cls: idx for idx, cls in enumerate(self.classes_)}
25682571

25692572
if self.n_jobs != 1:
25702573
y_pred_ensemble = sum(
25712574
Parallel(n_jobs=self.n_jobs, prefer=self.joblib_prefer)
2572-
(delayed(_align_tree_predict)(tree, X, self.classes_) for tree in self.uplift_forest)
2575+
(delayed(_align_tree_predict)(tree, X, self.classes_, class_to_forest_idx) for tree in self.uplift_forest)
25732576
) / len(self.uplift_forest)
25742577
else:
2575-
y_pred_ensemble = sum([_align_tree_predict(tree, X, self.classes_) for tree in self.uplift_forest]) / len(self.uplift_forest)
2578+
y_pred_ensemble = sum([_align_tree_predict(tree, X, self.classes_, class_to_forest_idx) for tree in self.uplift_forest]) / len(self.uplift_forest)
25762579

25772580
# Summarize results into dataframe
25782581
df_res = pd.DataFrame(y_pred_ensemble, columns=self.classes_)

tests/test_uplift_trees.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,13 +310,23 @@ def test_UpliftRandomForestClassifier_predict_shape_with_sparse_groups():
310310
model = UpliftRandomForestClassifier(
311311
control_name=CONTROL_NAME,
312312
n_estimators=10,
313+
n_jobs=2,
313314
min_samples_leaf=1,
314315
min_samples_treatment=0,
315316
random_state=RANDOM_SEED,
316317
)
317318
model.fit(X, treatment=treatment, y=y)
318319

320+
# Verify that at least one tree was fit without some treatment groups
321+
assert any(
322+
len(tree.classes_) < len(model.classes_) for tree in model.uplift_forest
323+
), (
324+
"Test setup failed to produce any trees missing treatment groups; "
325+
"adjust seed or sampling parameters to exercise sparse-group behavior."
326+
)
327+
319328
# Single-threaded
329+
model.n_jobs = 1
320330
preds = model.predict(X)
321331
assert preds.shape == (
322332
n,
@@ -325,7 +335,7 @@ def test_UpliftRandomForestClassifier_predict_shape_with_sparse_groups():
325335
assert not np.any(np.isnan(preds)), "Predictions contain NaN"
326336

327337
# Parallel
328-
with parallel_backend("threading", n_jobs=2):
329-
preds_par = model.predict(X)
338+
model.n_jobs = 2
339+
preds_par = model.predict(X)
330340
assert preds_par.shape == preds.shape
331341
assert np.allclose(preds, preds_par)

0 commit comments

Comments
 (0)