Skip to content

feat(ws1): NativeSiLUOp + NativeSwiGLUOp pure-PyTorch ground-truth references + numerical contract tests#166

Open
maxiaosong1124 wants to merge 2 commits into
RL-Align:mainfrom
maxiaosong1124:feat/ws1-silu-swiglu-pytorch-op
Open

feat(ws1): NativeSiLUOp + NativeSwiGLUOp pure-PyTorch ground-truth references + numerical contract tests#166
maxiaosong1124 wants to merge 2 commits into
RL-Align:mainfrom
maxiaosong1124:feat/ws1-silu-swiglu-pytorch-op

Conversation

@maxiaosong1124

@maxiaosong1124 maxiaosong1124 commented Jun 21, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds the pure-PyTorch ground-truth reference ops for the gated MLP activation
of the WS1 batch-invariant forward chain: SiLU (Swish) and SwiGLU, built on
top of the numerical contract defined in #108. Ships the two ops, their registry
wiring, docs, and a 16-case test suite that pins down both alignment axes
(Axis-A bitwise batch invariance, Axis-B per-dtype path).

Refs #108

Terminology

This PR uses the WS1 alignment vocabulary from #108:

  • Axis-A — batch invariance (reproducibility). A row's output must not depend on
    how many rows share the batch (batch size, slicing, padding). Asserted bitwise
    (torch.equal). This is what keeps train-time (large batch) and sample-time
    (small batch / dynamic padding) numerics identical so the policy ratio doesn't drift.
  • Axis-B — accuracy. The low-precision (bf16 / fp16) forward path. These activations
    compute in fp32 and cast the result back to the input dtype, so the dtype output is
    bitwise equal to the fp32 formula cast down — asserted with torch.equal against
    an independent fp32 reference (no tolerance window needed for an element-wise op).

Motivation / Context

#108 establishes the ground-truth harness and numerical contract for the WS1
batch-invariant forward chain. The Qwen3-8B dense MLP is a gated (SwiGLU) MLP:

down_proj( silu(gate_proj(x)) * up_proj(x) )

This PR covers the activation stage in the middle:

  • silu — element-wise x * sigmoid(x) (hidden_act="silu"), shape-agnostic.
  • swiglusilu(gate) * up, where gate / up are the gate_proj / up_proj
    outputs at the intermediate dim (Qwen3-8B: 12288). The trailing down_proj is a
    plain matmul and lives in a separate op.

This PR provides the deterministic fp32 reference path those downstream kernels
(Triton / CUDA / ROCm fused activation) will be validated against.

Changes

  • rl_engine/kernels/ops/pytorch/activation/swiglu.pyNativeSiLUOp, NativeSwiGLUOp
    • forward() — accumulate in fp32, cast result back to input dtype (Axis-B path)
    • forward_fp32() — fp32 accumulation, forced fp32 output (ground-truth / backward golden source)
    • Formulas: silu(x) = x * sigmoid(x); swiglu(gate, up) = gate * sigmoid(gate) * up
    • Pure functions — inputs never mutated in place
    • SwiGLU shape guard: gate and up must share shape
  • rl_engine/kernels/registry.py — register PYTORCH_NATIVE_SILU /
    PYTORCH_NATIVE_SWIGLU in OpBackend and add silu / swiglu dispatch to the
    cuda / rocm / cpu priority maps
  • tests/test_swiglu.py — 16 tests (details below)
  • docs/operators/activation.md + nav / index wiring

How this satisfies the #108 contract

#108 requirement How it's met here
Deterministic reference path forward_fp32() computes element-wise in fp32; tests use fixed-seed torch.Generator so outputs are reproducible
Per-dtype tolerance policy (bitwise vs tight-tolerance) Axis-A asserted bitwise (torch.equal); Axis-B dtype output is fp32-compute-then-cast, asserted bitwise against the independent fp32 formula cast to dtype (element-wise op needs no tolerance window)
Batch-config sweep / validation helper Batch-invariance checks compute on the full batch, then assert sliced/padded rows are bitwise identical to their full-batch counterparts
Realistic shapes covered Batch-invariance tests run at the Qwen3-8B intermediate dim (12288)

Test Environment

OS Ubuntu 22.04.5 LTS (kernel 5.15.0-122-generic)
Python 3.12.3
PyTorch 2.8.0+cu128
CUDA / cuDNN 12.8 / 9.10.02 (driver 580.65.06)

──────────────────────────
padding variants, asserted bitwise

  • purity (inputs not mutated in place)
  • gradient flow (fp32 autograd = backward golden source)
  • SwiGLU shape guard fires on mismatched gate / up shapes
  • registry dispatch resolves silu → NativeSiLUOp, swiglu → NativeSwiGLUOp

Checklist

  • Pure-PyTorch reference, no custom extension required
  • SwiGLU covered at the Qwen3-8B intermediate dim (12288)
  • Axis-A bitwise batch invariance enforced
  • Axis-B fp32-compute-then-cast dtype path tested
  • Registered in OpBackend + cuda/rocm/cpu priority maps
  • All 16 tests pass locally

Summary by CodeRabbit

  • New Features

    • Added SiLU and SwiGLU activation operators with PyTorch implementations and registry support.
  • Documentation

    • Added comprehensive documentation for SiLU/SwiGLU operators, including mathematical definitions, tensor shape requirements, and backend dispatch behavior.
    • Added validation tests covering correctness, input validation, batch invariance, and gradient propagation.

WS1 ground-truth activation ops for issue RL-Align#108 (Qwen3-8B gated MLP):
- NativeSiLUOp: silu(x) = x * sigmoid(x)
- NativeSwiGLUOp: silu(gate) * up (gate/up at intermediate dim)
Both expose the forward / forward_fp32 dual-path contract (fp32
ground truth + dtype-behavior path), pure functions, fp32 accumulation.

- register PYTORCH_NATIVE_SILU / PYTORCH_NATIVE_SWIGLU in OpBackend and
  the cuda/rocm/cpu priority maps
- tests/test_swiglu.py: correctness vs fp32 formula, dtype paths,
  Axis-A batch invariance (slice + padding), purity, gradient flow,
  shape guard, registry dispatch
- docs/operators/activation.md + nav/index wiring
@coderabbitai

coderabbitai Bot commented Jun 21, 2026

Copy link
Copy Markdown

Review Change Stack

📝 Walkthrough

Walkthrough

Adds NativeSiLUOp and NativeSwiGLUOp PyTorch reference implementations with fp32-accumulation semantics and dual forward/forward_fp32 paths. Two new OpBackend enum members are registered in KernelRegistry for cuda/rocm/cpu dispatch. A 127-line test module validates correctness, invariance, purity, gradients, and registry integration. Documentation for the operator contract is added under docs/operators/activation.md.

Changes

SiLU / SwiGLU activation operators

Layer / File(s) Summary
Op implementation and registry wiring
rl_engine/kernels/ops/pytorch/activation/__init__.py, rl_engine/kernels/ops/pytorch/activation/swiglu.py, rl_engine/kernels/registry.py
NativeSiLUOp and NativeSwiGLUOp are implemented with forward (fp32 compute, cast to input dtype) and forward_fp32 (fp32 output) paths; NativeSwiGLUOp._swiglu raises ValueError on shape mismatch. OpBackend gains two new enum members, and KernelRegistry._priority_map routes "silu"/"swiglu" on all platforms to those backends.
Test suite
tests/test_swiglu.py
Covers dtype-preserving correctness against fp32 reference, shape-mismatch guard, batch/padding invariance, input purity, finite-gradient backprop, and kernel_registry.get_op dispatch for both ops.
Operator documentation and navigation
docs/operators/activation.md, docs/operators/README.md, docs/.nav.yml
Documents math formulas, tensor contract, dual-path semantics, dispatch behavior, accuracy axes, test coverage, and current limitations (no fused CUDA/Triton backend). Navigation and README index updated.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

Poem

🐇 Hop! The sigmoid blooms at last,
gate meets up with fp32 cast,
SwiGLU purrs through every dtype lane,
no mutation, no broadcast pain.
The registry points, the tests all pass—
this little rabbit ships some class! ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 20.83% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main change: introducing two pure-PyTorch reference implementations (NativeSiLUOp and NativeSwiGLUOp) with ground-truth references and numerical contract tests.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@docs/operators/activation.md`:
- Around line 12-16: The fenced code block containing the ASCII diagram (showing
hidden, gate_proj, gate, swiglu, down_proj, and up_proj) is missing a language
identifier, which violates the MD040 markdownlint rule. Add "text" as the
language identifier to the opening fence by changing the opening triple
backticks to include the language specifier, making it ```text instead of just
```.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 52c72f53-6442-4854-927a-13addf467820

📥 Commits

Reviewing files that changed from the base of the PR and between d6db6bf and 9bcd65b.

📒 Files selected for processing (7)
  • docs/.nav.yml
  • docs/operators/README.md
  • docs/operators/activation.md
  • rl_engine/kernels/ops/pytorch/activation/__init__.py
  • rl_engine/kernels/ops/pytorch/activation/swiglu.py
  • rl_engine/kernels/registry.py
  • tests/test_swiglu.py

Comment on lines +12 to +16
```
hidden --gate_proj--> gate --\
swiglu --> down_proj --> hidden
hidden --up_proj----> up ----/
```

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add a language identifier to the fenced code block (MD040).

The diagram fence is unlabeled, which will fail markdownlint in strict docs CI.

Proposed fix
-```
+```text
 hidden --gate_proj--> gate --\
                               swiglu --> down_proj --> hidden
 hidden --up_proj----> up ----/
</details>

<!-- suggestion_start -->

<details>
<summary>📝 Committable suggestion</summary>

> ‼️ **IMPORTANT**
> Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```suggestion

🧰 Tools
🪛 markdownlint-cli2 (0.22.1)

[warning] 12-12: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@docs/operators/activation.md` around lines 12 - 16, The fenced code block
containing the ASCII diagram (showing hidden, gate_proj, gate, swiglu,
down_proj, and up_proj) is missing a language identifier, which violates the
MD040 markdownlint rule. Add "text" as the language identifier to the opening
fence by changing the opening triple backticks to include the language
specifier, making it ```text instead of just ```.

Source: Linters/SAST tools

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants