feat(ws1): NativeSiLUOp + NativeSwiGLUOp pure-PyTorch ground-truth references + numerical contract tests#166
Conversation
WS1 ground-truth activation ops for issue RL-Align#108 (Qwen3-8B gated MLP): - NativeSiLUOp: silu(x) = x * sigmoid(x) - NativeSwiGLUOp: silu(gate) * up (gate/up at intermediate dim) Both expose the forward / forward_fp32 dual-path contract (fp32 ground truth + dtype-behavior path), pure functions, fp32 accumulation. - register PYTORCH_NATIVE_SILU / PYTORCH_NATIVE_SWIGLU in OpBackend and the cuda/rocm/cpu priority maps - tests/test_swiglu.py: correctness vs fp32 formula, dtype paths, Axis-A batch invariance (slice + padding), purity, gradient flow, shape guard, registry dispatch - docs/operators/activation.md + nav/index wiring
📝 WalkthroughWalkthroughAdds ChangesSiLU / SwiGLU activation operators
Estimated code review effort🎯 2 (Simple) | ⏱️ ~12 minutes Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@docs/operators/activation.md`:
- Around line 12-16: The fenced code block containing the ASCII diagram (showing
hidden, gate_proj, gate, swiglu, down_proj, and up_proj) is missing a language
identifier, which violates the MD040 markdownlint rule. Add "text" as the
language identifier to the opening fence by changing the opening triple
backticks to include the language specifier, making it ```text instead of just
```.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 52c72f53-6442-4854-927a-13addf467820
📒 Files selected for processing (7)
docs/.nav.ymldocs/operators/README.mddocs/operators/activation.mdrl_engine/kernels/ops/pytorch/activation/__init__.pyrl_engine/kernels/ops/pytorch/activation/swiglu.pyrl_engine/kernels/registry.pytests/test_swiglu.py
| ``` | ||
| hidden --gate_proj--> gate --\ | ||
| swiglu --> down_proj --> hidden | ||
| hidden --up_proj----> up ----/ | ||
| ``` |
There was a problem hiding this comment.
Add a language identifier to the fenced code block (MD040).
The diagram fence is unlabeled, which will fail markdownlint in strict docs CI.
Proposed fix
-```
+```text
hidden --gate_proj--> gate --\
swiglu --> down_proj --> hidden
hidden --up_proj----> up ----/</details>
<!-- suggestion_start -->
<details>
<summary>📝 Committable suggestion</summary>
> ‼️ **IMPORTANT**
> Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```suggestion
🧰 Tools
🪛 markdownlint-cli2 (0.22.1)
[warning] 12-12: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@docs/operators/activation.md` around lines 12 - 16, The fenced code block
containing the ASCII diagram (showing hidden, gate_proj, gate, swiglu,
down_proj, and up_proj) is missing a language identifier, which violates the
MD040 markdownlint rule. Add "text" as the language identifier to the opening
fence by changing the opening triple backticks to include the language
specifier, making it ```text instead of just ```.
Source: Linters/SAST tools
Summary
Adds the pure-PyTorch ground-truth reference ops for the gated MLP activation
of the WS1 batch-invariant forward chain: SiLU (Swish) and SwiGLU, built on
top of the numerical contract defined in #108. Ships the two ops, their registry
wiring, docs, and a 16-case test suite that pins down both alignment axes
(Axis-A bitwise batch invariance, Axis-B per-dtype path).
Refs #108
Terminology
This PR uses the WS1 alignment vocabulary from #108:
how many rows share the batch (batch size, slicing, padding). Asserted bitwise
(
torch.equal). This is what keeps train-time (large batch) and sample-time(small batch / dynamic padding) numerics identical so the policy ratio doesn't drift.
compute in fp32 and cast the result back to the input dtype, so the dtype output is
bitwise equal to the fp32 formula cast down — asserted with
torch.equalagainstan independent fp32 reference (no tolerance window needed for an element-wise op).
Motivation / Context
#108 establishes the ground-truth harness and numerical contract for the WS1
batch-invariant forward chain. The Qwen3-8B dense MLP is a gated (SwiGLU) MLP:
down_proj( silu(gate_proj(x)) * up_proj(x) )
This PR covers the activation stage in the middle:
silu— element-wisex * sigmoid(x)(hidden_act="silu"), shape-agnostic.swiglu—silu(gate) * up, wheregate/upare the gate_proj / up_projoutputs at the intermediate dim (Qwen3-8B: 12288). The trailing
down_projis aplain matmul and lives in a separate op.
This PR provides the deterministic fp32 reference path those downstream kernels
(Triton / CUDA / ROCm fused activation) will be validated against.
Changes
rl_engine/kernels/ops/pytorch/activation/swiglu.py—NativeSiLUOp,NativeSwiGLUOpforward()— accumulate in fp32, cast result back to input dtype (Axis-B path)forward_fp32()— fp32 accumulation, forced fp32 output (ground-truth / backward golden source)silu(x) = x * sigmoid(x);swiglu(gate, up) = gate * sigmoid(gate) * upgateandupmust share shaperl_engine/kernels/registry.py— registerPYTORCH_NATIVE_SILU/PYTORCH_NATIVE_SWIGLUinOpBackendand addsilu/swigludispatch to thecuda / rocm / cpu priority maps
tests/test_swiglu.py— 16 tests (details below)docs/operators/activation.md+ nav / index wiringHow this satisfies the #108 contract
forward_fp32()computes element-wise in fp32; tests use fixed-seedtorch.Generatorso outputs are reproducibletorch.equal); Axis-B dtype output is fp32-compute-then-cast, asserted bitwise against the independent fp32 formula cast to dtype (element-wise op needs no tolerance window)12288)Test Environment
──────────────────────────
padding variants, asserted bitwise
Checklist
Summary by CodeRabbit
New Features
Documentation