feat(kernels): add fused masking + variable-length pack-and-pad op #182
Chen-BUPT wants to merge 6 commits into
…AM) (RL-Align#42) Measure TritonPackOp vs a PyTorch boolean-index baseline for pack latency, and the end-to-end peak VRAM of dense logp vs pack->logp to quantify the memory saving on sparse masks (the motivation behind RL-Align#42). Follows the existing benchmark_ratio_kl.py conventions (CUDA-event median timing, max_memory_allocated, CSV output, --smoke).
Pack hidden states before the vocab projection so the dense [B,S,V] logits are never materialized for masked-out tokens, which is the actual RL-Align#42 saving. Drop the fp32 upcast in selected-logp (use logits - logsumexp) so the dense path's peak memory reflects the logits, not an fp32 copy. Add --hidden-dim.
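The `logits - logsumexp` identity mentioned in this commit can be sketched in plain Python. This is a minimal illustration on nested lists, not the benchmark's `_selected_logp` (whose body is not shown in this thread); the function names and sample values below are assumptions made for the example.

```python
import math


def logsumexp(row):
    """Numerically stable log(sum(exp(v))) for one row of logits."""
    m = max(row)
    return m + math.log(sum(math.exp(v - m) for v in row))


def selected_logp(logits, ids):
    """Log-probability of the chosen token per row: logits[i][ids[i]] - logsumexp(logits[i])."""
    return [row[tok] - logsumexp(row) for row, tok in zip(logits, ids)]


# Two rows over a vocabulary of three tokens (illustrative values).
logits = [[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]]
ids = [1, 2]
logp = selected_logp(logits, ids)
```

Only one value per row is kept, so no second full-size copy of the logits is needed for the result; on tensors the same subtraction can stay in the input dtype, which is what lets the dense path's peak reflect the logits themselves.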
No actionable comments were generated in the recent review. 🎉
ℹ️ Recent review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID:
📒 Files selected for processing (2)
🚧 Files skipped from review as they are similar to previous changes (2)
📝 Walkthrough
Adds a fused masking and variable-length pack-and-pad operation implementing both a PyTorch-native autograd reference (…
Changes
Pack-and-pad fused op
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 2
🧹 Nitpick comments (2)
rl_engine/kernels/ops/pytorch/packing/pack.py (1)
27-37: 📐 Maintainability & Code Quality | 🔵 Trivial | 💤 Low value
Remove the duplicate dimension check.
The `mask.dim() < 1` guard at Lines 35-36 is dead code: it repeats the identical check at Lines 28-29, and the shape comparison in between never alters `mask.dim()`.
♻️ Proposed cleanup

     def _validate(x: torch.Tensor, mask: torch.Tensor) -> None:
         if mask.dim() < 1:
             raise ValueError("mask must have at least one dimension.")
         if mask.shape != x.shape[: mask.dim()]:
             raise ValueError(
                 f"mask shape {tuple(mask.shape)} must match the leading dims of "
                 f"x.shape {tuple(x.shape)} (expected {tuple(x.shape[: mask.dim()])})."
             )
    -    if mask.dim() < 1:
    -        raise ValueError("mask must have at least one dimension.")

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@rl_engine/kernels/ops/pytorch/packing/pack.py` around lines 27 - 37, The _validate function contains a duplicate dimension check for mask.dim() < 1 that appears after the shape validation logic. Remove the second occurrence of the identical check (the one appearing after the shape comparison) since the intermediate validation logic does not modify mask.dim() and therefore makes the repeated check dead code. Keep only the initial dimension check and remove the redundant check that follows the shape validation.

rl_engine/kernels/ops/triton/packing/pack.py (1)
99-102: 🚀 Performance & Scalability | 🔵 Trivial | ⚖️ Poor tradeoff
Gather grid launches a program for every source row, including inactive ones.
The grid is sized over `n_rows = B*S`, and each program loads `dest` and early-exits when the row is inactive (`dest < 0`). For the low-density packing this op targets (e.g. 0.05), the large majority of launched programs do no work. Launching over the `n_active` rows via a packed→source inverse index would avoid the wasted launches on the hot path. This is the design tradeoff already acknowledged in the PR, so feel free to defer.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@rl_engine/kernels/ops/triton/packing/pack.py` around lines 99 - 102, The grid in the `_pack_gather_kernel` launch is currently sized over `n_rows` which includes all source rows, causing programs to be launched for inactive rows that do no work (they just early-exit when dest < 0). To optimize this for sparse packing scenarios, resize the grid to be sized over `n_active` instead of `n_rows`, and introduce a packed-to-source inverse index mapping that the kernel can use to look up which source rows are actually active. This eliminates wasted kernel launches on inactive rows while keeping the same kernel logic.
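To make the two grid choices concrete, here is a minimal pure-Python sketch. It assumes a flattened mask and models each Triton program as one loop iteration; the helper names are hypothetical and none of this is the PR's kernel code.

```python
def build_dest(mask_flat):
    """Source-indexed map: dest[row] is the packed position, or -1 if the row is inactive."""
    dest, next_slot = [], 0
    for active in mask_flat:
        dest.append(next_slot if active else -1)
        next_slot += 1 if active else 0
    return dest


def build_inverse(mask_flat):
    """Packed-indexed map: src[p] is the source row that lands at packed position p."""
    return [row for row, active in enumerate(mask_flat) if active]


def gather_over_sources(x, dest, n_active):
    """One 'program' per source row; inactive rows exit early without writing."""
    out, launched = [None] * n_active, 0
    for row, d in enumerate(dest):
        launched += 1
        if d < 0:
            continue
        out[d] = x[row]
    return out, launched


def gather_over_active(x, src):
    """One 'program' per active row; nothing is launched for inactive rows."""
    return [x[row] for row in src], len(src)


mask_flat = [0, 1, 0, 0, 1, 0, 0, 0]
x = list("abcdefgh")
dest = build_dest(mask_flat)
src = build_inverse(mask_flat)
out_a, launched_a = gather_over_sources(x, dest, n_active=len(src))
out_b, launched_b = gather_over_active(x, src)
```

Both variants produce the same packed rows; the second launches 2 iterations instead of 8 here, at the cost of building and storing the extra `src` index, which is the tradeoff the comment leaves open.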
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@benchmarks/benchmark_pack.py`:
- Around line 214-252: The broad except Exception block that catches all
exceptions and sets status to "blocked" masks real execution errors
(kernel/runtime/math failures) as if the candidate backend is simply
unavailable. Replace the generic except Exception handler with specific
exception handling that only catches exceptions that genuinely indicate
unavailability (such as ImportError for missing kernel_registry or RuntimeError
for unavailable backends), and allow other exceptions like runtime/math/kernel
errors to propagate or be handled separately so they surface as actual failures
rather than being silently masked as blocked candidates.
In `@tests/test_pack.py`:
- Around line 18-31: The TritonPackOp import on line 18 happens unconditionally
before the Triton availability check, causing test collection to fail in
environments without Triton. Move the TritonPackOp import into the try block
alongside the triton import, and add a fallback assignment in the except block
(such as TritonPackOp = None) so the name can be safely referenced elsewhere in
the code. This ensures tests are properly skipped by the requires_triton_cuda
marker instead of failing during collection.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: b134abab-ef3b-4adc-8ca3-c50dd21ef7ae
📒 Files selected for processing (7)
- benchmarks/benchmark_pack.py
- rl_engine/kernels/ops/pytorch/packing/__init__.py
- rl_engine/kernels/ops/pytorch/packing/pack.py
- rl_engine/kernels/ops/triton/packing/__init__.py
- rl_engine/kernels/ops/triton/packing/pack.py
- rl_engine/kernels/registry.py
- tests/test_pack.py
    try:
        from rl_engine.kernels.registry import kernel_registry

        candidate_op = kernel_registry.get_op("pack")
        if candidate_op.__class__.__name__ != candidate_name:
            raise RuntimeError(f"{candidate_name} backend is unavailable")

        (cand_packed, _), candidate_ms = _time_ms(
            lambda: candidate_op(hidden, mask),
            config.device,
            warmup=config.warmup,
            repeat=config.repeat,
        )
        speedup = baseline_ms / candidate_ms if candidate_ms else float("inf")
        pack_drift = (cand_packed.float() - base_packed.float()).abs().max().item()

        # (2) end-to-end peak VRAM: dense (full logits) vs pack-then-project.
        flat_ids = ids.reshape(-1)
        _reset_peak(config.device)
        dense_logits = (hidden.reshape(-1, hidden_dim) @ lm_head)
        _ = _selected_logp(dense_logits, flat_ids)
        del dense_logits
        _sync(config.device)
        dense_logp_mem_gb = _peak_memory_gb(config.device)

        _reset_peak(config.device)
        packed_hidden, _ = candidate_op(hidden, mask)
        packed_ids, _ = candidate_op(ids.unsqueeze(-1), mask)
        packed_logits = packed_hidden @ lm_head
        _ = _selected_logp(packed_logits, packed_ids.squeeze(-1))
        del packed_logits, packed_hidden
        _sync(config.device)
        packed_logp_mem_gb = _peak_memory_gb(config.device)

        if dense_logp_mem_gb > 0:
            mem_saving_pct = 100.0 * (1.0 - packed_logp_mem_gb / dense_logp_mem_gb)
    except Exception as exc:
        status = "blocked"
        notes = f"candidate unavailable: {str(exc).splitlines()[0]}"
🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
Don’t swallow all benchmark failures as “candidate unavailable”.
The broad except Exception on Line 250 masks real execution regressions (kernel/runtime/math errors) as "blocked", which can silently produce misleading benchmark results.
Suggested fix

    -    else:
    -        try:
    -            from rl_engine.kernels.registry import kernel_registry
    -
    -            candidate_op = kernel_registry.get_op("pack")
    -            if candidate_op.__class__.__name__ != candidate_name:
    -                raise RuntimeError(f"{candidate_name} backend is unavailable")
    +    else:
    +        try:
    +            from rl_engine.kernels.registry import kernel_registry
    +            candidate_op = kernel_registry.get_op("pack")
    +        except (ImportError, RuntimeError) as exc:
    +            status = "blocked"
    +            notes = f"candidate unavailable: {str(exc).splitlines()[0]}"
    +            candidate_op = None
    -            (cand_packed, _), candidate_ms = _time_ms(
    -                lambda: candidate_op(hidden, mask),
    -                config.device,
    -                warmup=config.warmup,
    -                repeat=config.repeat,
    -            )
    -            speedup = baseline_ms / candidate_ms if candidate_ms else float("inf")
    -            pack_drift = (cand_packed.float() - base_packed.float()).abs().max().item()
    +        if candidate_op is not None:
    +            (cand_packed, _), candidate_ms = _time_ms(
    +                lambda: candidate_op(hidden, mask),
    +                config.device,
    +                warmup=config.warmup,
    +                repeat=config.repeat,
    +            )
    +            speedup = baseline_ms / candidate_ms if candidate_ms else float("inf")
    +            pack_drift = (cand_packed.float() - base_packed.float()).abs().max().item()
                 # (2) end-to-end peak VRAM: dense (full logits) vs pack-then-project.
                 flat_ids = ids.reshape(-1)
                 _reset_peak(config.device)
                 dense_logits = (hidden.reshape(-1, hidden_dim) @ lm_head)
     @@
    -            if dense_logp_mem_gb > 0:
    -                mem_saving_pct = 100.0 * (1.0 - packed_logp_mem_gb / dense_logp_mem_gb)
    -        except Exception as exc:
    -            status = "blocked"
    -            notes = f"candidate unavailable: {str(exc).splitlines()[0]}"
    +            if dense_logp_mem_gb > 0:
    +                mem_saving_pct = 100.0 * (1.0 - packed_logp_mem_gb / dense_logp_mem_gb)

🧰 Tools
🪛 Ruff (0.15.18)
[warning] 250-250: Do not catch blind exception: Exception
(BLE001)
Source: Linters/SAST tools
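The separation this comment asks for can be sketched without any of the benchmark's helpers. In the sketch below, `BackendUnavailable`, `first_line`, and `run_candidate` are hypothetical names introduced for illustration; a dedicated exception type is used on the assumption that a plain `RuntimeError` may also come from a genuine kernel failure.

```python
class BackendUnavailable(RuntimeError):
    """Hypothetical marker: the requested backend cannot be used on this machine."""


def first_line(exc):
    """First line of the exception message, or the type name if the message is empty."""
    lines = str(exc).splitlines()
    return lines[0] if lines else type(exc).__name__


def run_candidate(load_op, run_op):
    """Only the loading step may mark a candidate as blocked; execution errors propagate."""
    try:
        op = load_op()
    except (ImportError, BackendUnavailable) as exc:
        return {"status": "blocked", "notes": f"candidate unavailable: {first_line(exc)}"}
    return {"status": "ok", "result": run_op(op)}


def _missing_backend():
    raise ImportError("No module named 'triton'\n(extra detail)")


def _broken_kernel(op):
    raise ZeroDivisionError("math error inside the kernel")


blocked = run_candidate(_missing_backend, lambda op: None)
ok = run_candidate(lambda: "op", lambda op: 42)
try:
    run_candidate(lambda: "op", _broken_kernel)
    surfaced = False
except ZeroDivisionError:
    surfaced = True
```

Keeping the `try` around the load step alone means a timing or numerical failure shows up as a traceback rather than as a "blocked" row in the CSV.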
    from rl_engine.kernels.ops.triton.packing.pack import TritonPackOp
    from rl_engine.testing import make_synthetic_rl_kernel_batch

    try:
        import triton  # noqa: F401

        _HAS_TRITON = True
    except ImportError:  # pragma: no cover
        _HAS_TRITON = False

    requires_triton_cuda = pytest.mark.skipif(
        not (_HAS_TRITON and torch.cuda.is_available()),
        reason="Triton pack op requires a CUDA device and Triton.",
    )
🩺 Stability & Availability | 🟠 Major
🧩 Analysis chain
🏁 Script executed:

    #!/bin/bash
    # Verify unconditional Triton operator import occurs before optional availability guard.
    rg -n -C2 'from rl_engine\.kernels\.ops\.triton\.packing\.pack import TritonPackOp|try:|import triton|except ImportError' tests/test_pack.py

Repository: RL-Align/RL-Kernel
Length of output: 516
🏁 Script executed:

    cat -n tests/test_pack.py

Repository: RL-Align/RL-Kernel
Length of output: 10503
Move Triton operator import into try block to prevent test collection failures without Triton.
Line 18 imports TritonPackOp unconditionally, so test collection will fail in environments without Triton before the @skipif marker is evaluated. Move the operator import into the Triton availability try block and add a fallback assignment so it can be safely referenced in conditional code.
Proposed fix

     import pytest
     import torch
     from rl_engine.kernels.ops.pytorch.packing.pack import NativePackOp
    -from rl_engine.kernels.ops.triton.packing.pack import TritonPackOp
     from rl_engine.testing import make_synthetic_rl_kernel_batch
     try:
         import triton  # noqa: F401
    +    from rl_engine.kernels.ops.triton.packing.pack import TritonPackOp
         _HAS_TRITON = True
     except ImportError:  # pragma: no cover
         _HAS_TRITON = False
    +    TritonPackOp = None  # type: ignore[assignment]
     @@
     def test_registry_dispatches_pack():
         from rl_engine.kernels.registry import kernel_registry
         op = kernel_registry.get_op("pack")
         if _HAS_TRITON and torch.cuda.is_available():
    +        assert TritonPackOp is not None
             assert isinstance(op, TritonPackOp)
         else:
             assert isinstance(op, NativePackOp)

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Original:

    from rl_engine.kernels.ops.triton.packing.pack import TritonPackOp
    from rl_engine.testing import make_synthetic_rl_kernel_batch
    try:
        import triton  # noqa: F401
        _HAS_TRITON = True
    except ImportError:  # pragma: no cover
        _HAS_TRITON = False
    requires_triton_cuda = pytest.mark.skipif(
        not (_HAS_TRITON and torch.cuda.is_available()),
        reason="Triton pack op requires a CUDA device and Triton.",
    )

Suggested:

    from rl_engine.kernels.ops.pytorch.packing.pack import NativePackOp
    from rl_engine.testing import make_synthetic_rl_kernel_batch
    try:
        import triton  # noqa: F401
        from rl_engine.kernels.ops.triton.packing.pack import TritonPackOp
        _HAS_TRITON = True
    except ImportError:  # pragma: no cover
        _HAS_TRITON = False
        TritonPackOp = None  # type: ignore[assignment]
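The guarded-import pattern from this comment can be tried out with the standard library alone. The sketch below is an illustration under assumptions: `optional_import` is a hypothetical helper, and `hypothetical_missing_backend_xyz` is a deliberately nonexistent module name standing in for an absent Triton install.

```python
import importlib


def optional_import(module_name, attr):
    """Return getattr(module, attr), or None when the module cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:  # also covers ModuleNotFoundError
        return None
    return getattr(module, attr)


# Missing dependency: the name is still bound, so later references do not raise NameError.
MissingOp = optional_import("hypothetical_missing_backend_xyz", "TritonPackOp")
HAS_BACKEND = MissingOp is not None

# Present dependency: the attribute is returned unchanged.
sqrt = optional_import("math", "sqrt")
```

Because the import failure is absorbed at module load time, test collection succeeds and a skip marker keyed on `HAS_BACKEND` gets the chance to run.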
The second mask.dim() < 1 guard was dead code: the intervening shape check does not alter mask.dim(). Addresses CodeRabbit review on RL-Align#182.
    packed = flat_x.index_select(0, index)

    # cu_seqlens: prefix-sum of per-row active counts, for varlen consumers.
    per_row_active = mask.reshape(mask.shape[0], -1).to(torch.int64).sum(dim=1)
packed uses mask.to(bool), but cu_seqlens sums the raw mask. Non-bool masks can produce wrong prefix sums. Please either require bool masks or compute counts from the bool mask.
Packing selects rows via mask.to(bool) (nonzero == active), but cu_seqlens
summed the raw mask, so a non-bool mask (e.g. values in {0, 2}) inflated the
prefix sum beyond the number of rows actually packed. Count from the same
bool mask so cu_seqlens always matches the packed row count. Adds a
regression test. Addresses review feedback on RL-Align#182.
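The mismatch described above is easy to reproduce on plain lists. This sketch assumes a mask with values in {0, 2}, as in the commit message; `cu_seqlens` here is a hypothetical stand-in for the op's prefix-sum logic, not the PR's code.

```python
from itertools import accumulate


def cu_seqlens(mask, as_bool):
    """Prefix sum of per-row counts with a leading 0; counts raw values or truthy entries."""
    counts = [sum((1 if v else 0) if as_bool else v for v in row) for row in mask]
    return [0] + list(accumulate(counts))


mask = [[0, 2, 2], [2, 0, 0]]                          # non-bool mask, nonzero == active
packed_rows = sum(1 for row in mask for v in row if v)  # rows actually packed
raw = cu_seqlens(mask, as_bool=False)                   # sums the raw values
fixed = cu_seqlens(mask, as_bool=True)                  # counts from the bool view
```

With the raw sum the final entry is 6 although only 3 rows are packed; counting from the bool view makes the last entry equal the packed row count.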
Fused masking + variable-length pack-and-pad op (#42)
What this adds
A `pack` operator that compacts the active rows of a dense `[B, S, *tail]` tensor (selected by a `[B, S]` mask) into a contiguous `[Total_Active, *tail]` tensor, returns per-row `cu_seqlens`, and scatters gradients back to the dense layout on the backward pass.
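The contract described above can be sketched on nested lists. This is a minimal reference under assumptions, not `NativePackOp` itself: `pack` and `scatter_back` are hypothetical names, and `scatter_back` only models the layout of the backward pass (packed values placed at active positions, zeros elsewhere).

```python
def pack(x, mask):
    """x: B rows of S items; mask: B rows of S flags. Returns (packed, cu_seqlens)."""
    packed, cu = [], [0]
    for x_row, m_row in zip(x, mask):
        for item, active in zip(x_row, m_row):
            if active:
                packed.append(item)
        cu.append(len(packed))  # running total of active items after each batch row
    return packed, cu


def scatter_back(grad_packed, mask, zero=0.0):
    """Place packed gradients at the active positions of the dense layout, zero elsewhere."""
    it = iter(grad_packed)
    return [[next(it) if active else zero for active in m_row] for m_row in mask]


x = [[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]]]   # B=2, S=3, tail=(1,)
mask = [[1, 0, 1], [0, 1, 0]]
packed, cu = pack(x, mask)
dense_grad = scatter_back([10.0, 20.0, 30.0], mask)
```

Here `cu[b]` and `cu[b + 1]` delimit the packed slice that belongs to batch row `b`, which is the form variable-length consumers typically expect.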
- `NativePackOp` … `SyntheticRLKernelBatch.compact_completion_values`.
- `TritonPackOp` …
- Registry dispatch for `"pack"`: Triton on GPU, PyTorch fallback on CPU.

Correctness
- … `tests/test_pack.py` pass on an NVIDIA H20 (SM90, CUDA 13.0).
- … `gradcheck`.
- drift = `0.000e+00` across all benchmarked shapes.

Why it matters: end-to-end VRAM
In RL training only the response / non-padding tokens contribute to the loss.
Packing the hidden states before the vocab projection means the full
`[B, S, V]` logits are never materialized for masked-out tokens, which is exactly the
saving #42 targets ("saved memory can be used for larger batches or longer
CoT").
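A back-of-the-envelope estimate shows the scale involved. The sketch assumes the benchmark shapes quoted in this description (`B=32`, `S=1024`, `vocab=131072`, bf16 at 2 bytes per element) and counts the logits tensor only, so it does not reproduce the measured peak figures, which also include hidden states and temporaries.

```python
def logits_bytes(rows, vocab, bytes_per_elem=2):
    """Size of a [rows, vocab] logits tensor in bytes (bf16 by default)."""
    return rows * vocab * bytes_per_elem


B, S, V = 32, 1024, 131072
density = 0.05

dense = logits_bytes(B * S, V)            # every token is projected
active_rows = round(B * S * density)      # tokens kept by the mask
packed = logits_bytes(active_rows, V)     # only active tokens are projected

dense_gib = dense / 2**30
packed_gib = packed / 2**30
logits_saving = 1.0 - packed / dense
```

Under these assumptions the dense logits alone occupy 8 GiB, while the packed logits at density 0.05 occupy roughly 0.4 GiB, a reduction of about 95% for that one tensor.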
Benchmark: `hidden=4096 -> lm_head -> logits -> selected logp`, bf16, `B=32, S=1024`, comparing dense (full logits) vs pack-then-project.
vocab = 131072
vocab = 32768
(Full data: `pack_h20.csv`.)
How to read the numbers
- … response (the common RL case), packing before the projection cuts peak logp memory by up to ~88% (density 0.05). The `17 GB -> 2 GB` drop lets the same GPU fit a much larger batch or longer CoT.
- … pack overhead), as expected: there is nothing to compact, so nothing is saved. This confirms the measurement is honest.
- … `speedup` < 1) is reported as-is. The pack op itself is memory-bound and a touch slower than PyTorch's boolean indexing (`index_select` is already highly tuned). Its absolute cost is 0.06–0.35 ms, negligible against the multi-GB memory it saves. This PR targets the VRAM win (#42's stated motivation);
Reproduce

    PYTHONPATH=. python -m pytest tests/test_pack.py -v   # correctness (needs CUDA for Triton cases)
    PYTHONPATH=. python benchmarks/benchmark_pack.py \
        --num-prompts 4 --g-sizes 8 --hidden-dim 4096 \
        --mask-densities 0.05,0.1,0.3,0.5,1.0 \
        --completion-lens 1024 --vocab-sizes 32768,131072 \
        --output benchmarks/results/pack_h20.csv

Notes for reviewers
- … `NativePackOp` (CPU path); Triton tests skip cleanly.
- … they run on ROCm via Triton without CUDA-specific intrinsics.
- … `needs-gpu-ci` label so the GPU CI exercises the Triton path.
Summary by CodeRabbit
New Features
- … `cu_seqlens`, with both PyTorch and Triton GPU backends.
- … `"pack"` op dispatches to the best available backend per platform.
Tests
- … `cu_seqlens`, unpack round-trips, gradient behavior, edge cases, and backend dispatch.
Chores