Commit eb74963
Fix UpliftRandomForest predict shape mismatch with multiple treatments (#569)
Bootstrap sampling can exclude entire treatment groups from a tree's
training data, causing individual trees to produce prediction arrays
of different widths. When summing predictions across trees, this
causes a ValueError for shape mismatch.
Added _align_tree_predict() that maps each tree's predictions to the
forest-level class ordering, filling zeros for missing treatment
groups. This is a module-level function (not a closure) so it works
with joblib's parallel pickling.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>1 parent d688549 commit eb74963
1 file changed
Lines changed: 21 additions & 2 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
59 | 59 | | |
60 | 60 | | |
61 | 61 | | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
62 | 81 | | |
63 | 82 | | |
64 | 83 | | |
| |||
2549 | 2568 | | |
2550 | 2569 | | |
2551 | 2570 | | |
2552 | | - | |
| 2571 | + | |
2553 | 2572 | | |
2554 | 2573 | | |
2555 | | - | |
| 2574 | + | |
2556 | 2575 | | |
2557 | 2576 | | |
2558 | 2577 | | |
| |||
0 commit comments