Skip to content

Commit 009a109

Browse files
jeongyoonlee and claude committed
Address review: precompute class mapping once, improve test robustness
- Build class_to_forest_idx dict once in predict() instead of per tree
- Use model.n_jobs instead of parallel_backend for parallel test
- Assert that sparse-group condition actually occurred in test

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 298e7d9 commit 009a109

2 files changed

Lines changed: 19 additions & 6 deletions

File tree

causalml/inference/tree/uplift.pyx

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -61,18 +61,20 @@ cdef extern from "math.h":
6161
double sqrt(double x) nogil
6262

6363

64-
def _align_tree_predict(tree, X, forest_classes):
64+
def _align_tree_predict(tree, X, forest_classes, class_to_forest_idx):
6565
"""Predict with a single tree and align output to the forest's classes.
6666
6767
When a bootstrap sample excludes some treatment groups, the tree's
6868
classes_ will be a subset of the forest's classes_. This function
6969
maps the tree's predictions to the forest-level class ordering.
70+
71+
Args:
72+
class_to_forest_idx: Precomputed {class_label: forest_index} mapping.
7073
"""
7174
raw = tree.predict(X=X)
7275
if len(tree.classes_) == len(forest_classes):
7376
return raw
7477
aligned = np.zeros((raw.shape[0], len(forest_classes)), dtype=raw.dtype)
75-
class_to_forest_idx = {cls: idx for idx, cls in enumerate(forest_classes)}
7678
for tree_idx, cls in enumerate(tree.classes_):
7779
forest_idx = class_to_forest_idx.get(cls)
7880
if forest_idx is not None:
@@ -2705,14 +2707,15 @@ class UpliftRandomForestClassifier:
27052707
27062708
'''
27072709
# Make predictions with all trees and take the average
2710+
class_to_forest_idx = {cls: idx for idx, cls in enumerate(self.classes_)}
27082711

27092712
if self.n_jobs != 1:
27102713
y_pred_ensemble = sum(
27112714
Parallel(n_jobs=self.n_jobs, prefer=self.joblib_prefer)
2712-
(delayed(_align_tree_predict)(tree, X, self.classes_) for tree in self.uplift_forest)
2715+
(delayed(_align_tree_predict)(tree, X, self.classes_, class_to_forest_idx) for tree in self.uplift_forest)
27132716
) / len(self.uplift_forest)
27142717
else:
2715-
y_pred_ensemble = sum([_align_tree_predict(tree, X, self.classes_) for tree in self.uplift_forest]) / len(self.uplift_forest)
2718+
y_pred_ensemble = sum([_align_tree_predict(tree, X, self.classes_, class_to_forest_idx) for tree in self.uplift_forest]) / len(self.uplift_forest)
27162719

27172720
# Summarize results into dataframe
27182721
df_res = pd.DataFrame(y_pred_ensemble, columns=self.classes_)

tests/test_uplift_trees.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -382,13 +382,23 @@ def test_UpliftRandomForestClassifier_predict_shape_with_sparse_groups():
382382
model = UpliftRandomForestClassifier(
383383
control_name=CONTROL_NAME,
384384
n_estimators=10,
385+
n_jobs=2,
385386
min_samples_leaf=1,
386387
min_samples_treatment=0,
387388
random_state=RANDOM_SEED,
388389
)
389390
model.fit(X, treatment=treatment, y=y)
390391

392+
# Verify that at least one tree was fit without some treatment groups
393+
assert any(
394+
len(tree.classes_) < len(model.classes_) for tree in model.uplift_forest
395+
), (
396+
"Test setup failed to produce any trees missing treatment groups; "
397+
"adjust seed or sampling parameters to exercise sparse-group behavior."
398+
)
399+
391400
# Single-threaded
401+
model.n_jobs = 1
392402
preds = model.predict(X)
393403
assert preds.shape == (
394404
n,
@@ -397,7 +407,7 @@ def test_UpliftRandomForestClassifier_predict_shape_with_sparse_groups():
397407
assert not np.any(np.isnan(preds)), "Predictions contain NaN"
398408

399409
# Parallel
400-
with parallel_backend("threading", n_jobs=2):
401-
preds_par = model.predict(X)
410+
model.n_jobs = 2
411+
preds_par = model.predict(X)
402412
assert preds_par.shape == preds.shape
403413
assert np.allclose(preds, preds_par)

0 commit comments

Comments
 (0)