Skip to content

Commit f94d9d5

Browse files
jeongyoonlee and claude committed
Address review: use dict lookup, preserve dtype, add regression test
- Use dict for O(1) class-to-index mapping instead of repeated list scans
- Preserve dtype with dtype=raw.dtype in aligned array
- Add test_UpliftRandomForestClassifier_predict_shape_with_sparse_groups

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent eb74963 commit f94d9d5

2 files changed

Lines changed: 40 additions & 3 deletions

File tree

causalml/inference/tree/uplift.pyx

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,11 @@ def _align_tree_predict(tree, X, forest_classes):
7070
raw = tree.predict(X=X)
7171
if len(tree.classes_) == len(forest_classes):
7272
return raw
73-
aligned = np.zeros((raw.shape[0], len(forest_classes)))
73+
aligned = np.zeros((raw.shape[0], len(forest_classes)), dtype=raw.dtype)
74+
class_to_forest_idx = {cls: idx for idx, cls in enumerate(forest_classes)}
7475
for tree_idx, cls in enumerate(tree.classes_):
75-
if cls in forest_classes:
76-
forest_idx = forest_classes.index(cls)
76+
forest_idx = class_to_forest_idx.get(cls)
77+
if forest_idx is not None:
7778
aligned[:, forest_idx] = raw[:, tree_idx]
7879
return aligned
7980

tests/test_uplift_trees.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,39 @@ def test_uplift_tree_visualization():
293293
# Plot uplift tree
294294
graph = uplift_tree_plot(uplift_model.fitted_uplift_tree, x_names)
295295
graph.create_png()
296+
297+
298+
def test_UpliftRandomForestClassifier_predict_shape_with_sparse_groups():
299+
"""Test that UpliftRandomForestClassifier.predict() returns correct shape
300+
when bootstrap sampling causes some trees to miss treatment groups (#569)."""
301+
np.random.seed(RANDOM_SEED)
302+
n = 60
303+
X = np.random.randn(n, 3)
304+
# Very few samples in treatment groups so bootstraps are likely to miss some
305+
treatment = np.array(
306+
[CONTROL_NAME] * 50 + [TREATMENT_NAMES[1]] * 5 + [TREATMENT_NAMES[2]] * 5
307+
)
308+
y = np.random.randint(0, 2, n)
309+
310+
model = UpliftRandomForestClassifier(
311+
control_name=CONTROL_NAME,
312+
n_estimators=10,
313+
min_samples_leaf=1,
314+
min_samples_treatment=0,
315+
random_state=RANDOM_SEED,
316+
)
317+
model.fit(X, treatment=treatment, y=y)
318+
319+
# Single-threaded
320+
preds = model.predict(X)
321+
assert preds.shape == (
322+
n,
323+
len(model.classes_) - 1,
324+
), f"Expected shape ({n}, {len(model.classes_) - 1}), got {preds.shape}"
325+
assert not np.any(np.isnan(preds)), "Predictions contain NaN"
326+
327+
# Parallel
328+
with parallel_backend("threading", n_jobs=2):
329+
preds_par = model.predict(X)
330+
assert preds_par.shape == preds.shape
331+
assert np.allclose(preds, preds_par)

0 commit comments

Comments
 (0)