Commit c043043
* Fix UpliftRandomForest predict shape mismatch with multiple treatments (#569)
Bootstrap sampling can exclude entire treatment groups from a tree's
training data, so individual trees can produce prediction arrays of
different widths. Summing predictions across trees then raises a
ValueError because the array shapes do not match.
Added _align_tree_predict() that maps each tree's predictions to the
forest-level class ordering, filling zeros for missing treatment
groups. This is a module-level function (not a closure) so it works
with joblib's parallel pickling.
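
The failure and the alignment idea can be sketched as follows. This is a minimal illustration, not the code from the commit: `align_tree_predict_sketch`, its signature, and the class names are assumptions made for the example, and the real `_align_tree_predict()` body is not shown here.

```python
import numpy as np


def align_tree_predict_sketch(raw, tree_classes, forest_classes):
    """Map one tree's prediction columns onto the forest-level class order.

    Hypothetical stand-in for the commit's _align_tree_predict(); the real
    signature is not shown in the commit message.
    """
    # Columns for groups the tree never saw stay at zero.
    aligned = np.zeros((raw.shape[0], len(forest_classes)))
    for tree_col, name in enumerate(tree_classes):
        # Linear scan per class; acceptable for a small sketch.
        aligned[:, forest_classes.index(name)] = raw[:, tree_col]
    return aligned


forest_classes = ["control", "treatment1", "treatment2"]

# Tree A saw every group; tree B's bootstrap sample missed "treatment1".
pred_a = np.array([[0.2, 0.5, 0.4], [0.1, 0.3, 0.6]])
pred_b = np.array([[0.3, 0.7], [0.2, 0.5]])
classes_a = ["control", "treatment1", "treatment2"]
classes_b = ["control", "treatment2"]

# Adding the raw arrays fails because the shapes are (2, 3) and (2, 2).
try:
    pred_a + pred_b
    shape_error = None
except ValueError as exc:
    shape_error = exc

# After alignment both arrays are (2, 3) and can be summed.
aligned = [
    align_tree_predict_sketch(pred_a, classes_a, forest_classes),
    align_tree_predict_sketch(pred_b, classes_b, forest_classes),
]
total = np.sum(aligned, axis=0)
```

Because the helper is defined at module level rather than inside `predict()`, it can be referenced by name when work is dispatched to worker processes.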
* Address review: use dict lookup, preserve dtype, add regression test
- Use dict for O(1) class-to-index mapping instead of repeated list scans
- Preserve dtype with dtype=raw.dtype in aligned array
- Add test_UpliftRandomForestClassifier_predict_shape_with_sparse_groups
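
A rough sketch of the first two review points, under the assumption that per-tree predictions are NumPy arrays. The function name `align_with_dict_sketch` and the class labels are hypothetical and chosen only for illustration.

```python
import numpy as np


def align_with_dict_sketch(raw, tree_classes, forest_classes):
    """Hypothetical alignment helper using a dict lookup and raw's dtype."""
    # One dict build, then O(1) lookups instead of repeated list scans.
    class_to_idx = {name: i for i, name in enumerate(forest_classes)}
    # dtype=raw.dtype keeps, for example, float32 input from becoming float64.
    aligned = np.zeros((raw.shape[0], len(forest_classes)), dtype=raw.dtype)
    for tree_col, name in enumerate(tree_classes):
        aligned[:, class_to_idx[name]] = raw[:, tree_col]
    return aligned


forest_classes = ["control", "treatment1", "treatment2"]
raw = np.array([[0.25, 0.75]], dtype=np.float32)

aligned = align_with_dict_sketch(raw, ["control", "treatment2"], forest_classes)

# For comparison: np.zeros without an explicit dtype defaults to float64.
upcast = np.zeros((1, 3))
```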
* Address review: precompute class mapping once, improve test robustness
- Build class_to_forest_idx dict once in predict() instead of per tree
- Use model.n_jobs instead of parallel_backend for parallel test
- Assert that sparse-group condition actually occurred in test
* Make sparse-group test deterministic with 1-sample minority groups
With only 1 sample per minority treatment group out of 102 total,
bootstrap sampling will omit at least one of them from many trees, so
the sparse-group condition occurs reliably regardless of seed or CI
environment.
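
The arithmetic behind this can be checked with a short calculation. It assumes each tree draws a bootstrap sample of the same size as the data, with replacement; the commit message does not state the draw size or the number of trees, so `n_draws` and `n_trees` below are illustrative assumptions.

```python
def miss_probability(n_total, n_group, n_draws):
    """P(a with-replacement sample of n_draws rows contains no row of the group)."""
    return (1 - n_group / n_total) ** n_draws


n_total = 102

# Chance that a single tree's bootstrap sample misses a 1-sample group.
p_single_tree = miss_probability(n_total, n_group=1, n_draws=n_total)

# Chance that at least one of n_trees independent bootstrap samples misses it.
n_trees = 10  # assumed forest size, for illustration only
p_any_tree = 1 - (1 - p_single_tree) ** n_trees
```

Under these assumptions a given 1-sample group is absent from a little over a third of the trees, and the chance that no tree at all misses it shrinks quickly as the forest grows.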
---------
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 490e35c commit c043043
2 files changed
Lines changed: 74 additions & 2 deletions