
feat(grpo): context-parallel (CP) loss alignment and reduction #66

Open
RUFFY-369 wants to merge 14 commits into NousResearch:dev-updated-again from RUFFY-369:infra/grpo-cp-alignment


@RUFFY-369

Summary

This PR introduces context-parallel (CP) loss alignment for GRPO. When training with context parallelism, it synchronizes the GRPO loss and advantage reductions across the workers of the CP mesh so that gradients in the backward pass remain mathematically correct.

Technical Context

In a context-parallel configuration, a single sequence is split across multiple GPUs to reduce per-GPU memory use. A loss reduction that runs independently on each rank sees only its local segment of the sequence, so it normalizes by local rather than global token counts, which leads to inconsistent gradients.
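To make the failure mode concrete, here is a minimal sketch that simulates CP degree 2 in a single process with plain Python lists. The loss values and mask are made up for illustration and do not come from the PR. It compares the masked mean of the unsplit sequence with a naive average of per-shard masked means.

```python
# Toy per-token losses for one sequence, plus a mask that selects the
# completion tokens. All numbers are illustrative assumptions.
token_loss = [0.5, 1.0, 2.0, 4.0, 8.0, 16.0]
mask = [0, 0, 1, 1, 1, 1]


def masked_mean(values, mask):
    """Mean of `values` over positions where `mask` is 1."""
    total = sum(v * m for v, m in zip(values, mask))
    return total / sum(mask)


# Reference: the unsplit (non-CP) sequence.
global_mean = masked_mean(token_loss, mask)

# CP degree 2: first half on rank 0, second half on rank 1.
shards = [(token_loss[:3], mask[:3]), (token_loss[3:], mask[3:])]

# Naive reduction: each rank takes its own masked mean, then ranks are averaged.
naive = sum(masked_mean(v, m) for v, m in shards) / len(shards)

# Aligned reduction: combine per-rank sums and counts first, then divide once.
total = sum(sum(x * y for x, y in zip(v, m)) for v, m in shards)
count = sum(sum(m) for _, m in shards)
aligned = total / count

print(global_mean, naive, aligned)
```

Here the unsplit mean and the aligned reduction are both 7.5, while the naive average is about 5.67, because rank 0 holds one unmasked token and rank 1 holds three.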

This implementation introduces a specialized CP-alignment layer that:

  • Hooks into the ContextParallel dispatcher to correctly scale loss values by the CP degree.
  • Synchronizes advantage and KL-divergence terms across the CP mesh before they are used for weight updates.
  • Ensures bit-exact gradient parity between CP and non-CP training runs.
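One possible reading of "scale loss values by the CP degree" is sketched below. It assumes that each rank normalizes its local masked sum by the global token count and that gradients are then averaged across CP ranks, as data-parallel gradient reduction usually does. Both assumptions, and every name in the sketch, are illustrative; the PR's actual hook may work differently. The toy per-token loss is `w * x`, so the gradient with respect to `w` is simply `x`.

```python
CP_DEGREE = 2

# Per-token inputs; toy loss per token is w * x, so d(loss)/dw = x.
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
mask = [0, 1, 1, 0, 1, 0, 0, 1]


def masked_sum(values, mask):
    return sum(v * m for v, m in zip(values, mask))


global_count = sum(mask)

# Gradient of the unsplit masked-mean loss with respect to w.
reference_grad = masked_sum(x, mask) / global_count

# Split the sequence into contiguous shards, one per simulated CP rank.
size = len(x) // CP_DEGREE
shards = [
    (x[r * size:(r + 1) * size], mask[r * size:(r + 1) * size])
    for r in range(CP_DEGREE)
]


def rank_grad(values, m, scale):
    """Local gradient when the local sum is normalized by the global count."""
    return scale * masked_sum(values, m) / global_count


def mean_over_ranks(grads):
    """Single-process stand-in for gradient averaging across ranks."""
    return sum(grads) / len(grads)


# Without scaling, averaging shrinks the gradient by a factor of CP_DEGREE.
unscaled = mean_over_ranks([rank_grad(v, m, 1.0) for v, m in shards])

# Scaling each rank's loss by CP_DEGREE cancels the averaging.
scaled = mean_over_ranks([rank_grad(v, m, float(CP_DEGREE)) for v, m in shards])

print(reference_grad, unscaled, scaled)
```

With these values the scaled result matches the reference and the unscaled result is half of it. If the surrounding framework summed gradients instead of averaging them, no scaling would be needed.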

Key Changes

  • torchtitan/distributed/context_parallel.py: Added _enable_context_parallel_dispatcher and loss reduction hooks for GRPO.
  • torchtitan/grpo/grpo_step.py: Updated the loss computation logic to be sequence-mesh aware.
  • torchtitan/grpo/utils.py: Implemented masked_mean and masked_sum primitives that correctly handle CP-mesh boundaries.
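The PR description does not show the signatures of masked_mean and masked_sum, so the following is only a sketch of how such primitives could handle a shard boundary. The function names mirror the PR, but the signatures, the `local_masked_stats` helper, and the list-based `all_reduce_sum` stand-in for a real collective are hypothetical. The case it illustrates is a shard whose mask is entirely zero (for example, a rank that holds only prompt tokens), where a purely local mean would divide by zero.

```python
from typing import List, Sequence, Tuple

# One (values, mask) pair per simulated CP rank.
Shard = Tuple[Sequence[float], Sequence[int]]


def local_masked_stats(values: Sequence[float], mask: Sequence[int]) -> Tuple[float, float]:
    """Return this rank's (masked sum, unmasked token count)."""
    total = sum(v * m for v, m in zip(values, mask))
    return total, float(sum(mask))


def all_reduce_sum(per_rank: List[Tuple[float, float]]) -> Tuple[float, float]:
    """Single-process stand-in for a sum all-reduce over the CP group."""
    return sum(t for t, _ in per_rank), sum(c for _, c in per_rank)


def masked_sum(shards: List[Shard]) -> float:
    total, _ = all_reduce_sum([local_masked_stats(v, m) for v, m in shards])
    return total


def masked_mean(shards: List[Shard], eps: float = 1e-8) -> float:
    # Divide only after sums and counts have been combined across ranks,
    # so an all-masked shard contributes zero instead of a 0/0.
    total, count = all_reduce_sum([local_masked_stats(v, m) for v, m in shards])
    return total / max(count, eps)


# Rank 0 holds only masked (prompt) tokens; rank 1 holds the completion.
shards = [([0.3, 0.7], [0, 0]), ([1.0, 3.0], [1, 1])]
print(masked_sum(shards), masked_mean(shards))
```

Reducing the numerator and denominator together keeps the collective count at one per call in this sketch; a real implementation could pack both into a single tensor for the same reason.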

Modernization & Compatibility

To support current hardware and recent PyTorch releases, this PR also updates the code for PyTorch 2.5.1+.

  • Backward Compatible: Uses try/except import fallbacks and version guards to remain fully compatible with the existing PyTorch 2.3/2.4 baseline in the dev-updated-again branch.
  • Experimental Namespace Support: Aligns with the refactored torch.distributed.tensor.experimental._attention namespace in PyTorch 2.5+.
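A generic version of the try/except and version-guard pattern described above might look like the sketch below. The helper names `parse_version`, `import_first_available`, and `load_cp_attention` are hypothetical and are not taken from the PR. The only module path used is the one named in this description, and which symbols it exports on a given PyTorch release is not assumed here.

```python
import importlib
import re
from types import ModuleType
from typing import Optional, Sequence, Tuple


def parse_version(version: str) -> Tuple[int, int, int]:
    """Parse strings such as '2.5.1+cu121' into (2, 5, 1)."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = match.groups()
    return int(major), int(minor), int(patch or 0)


def import_first_available(candidates: Sequence[str]) -> Optional[ModuleType]:
    """Return the first importable module from `candidates`, or None."""
    for name in candidates:
        try:
            return importlib.import_module(name)
        except ImportError:
            continue
    return None


def load_cp_attention() -> Optional[ModuleType]:
    """Try the PyTorch 2.5+ namespace; return None on older installs."""
    return import_first_available(
        ["torch.distributed.tensor.experimental._attention"]
    )


# Example of a version guard on a version string.
if parse_version("2.4.0") < (2, 5, 0):
    print("older baseline: take the pre-2.5 code path")
```

Callers can then branch on whether `load_cp_attention()` returned a module, which keeps the fallback logic in one place instead of scattering try/except blocks across files.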

Verification Results (vast.ai)

  • Hardware Profile: Verified on a vast.ai cluster with 2x RTX 3090 GPUs (24GB VRAM).
  • Scale: Tested with CP degree 2, verifying that sequence-split training produces weights identical to those from unsplit (non-CP) training.
  • Tests: Successfully ran scripts/verify_grpo_2gpu.sh, confirming that CP-mesh synchronization occurs without deadlocks.
  • Cluster Stability: Verified that the CP collectives used for loss alignment do not introduce memory fragmentation on 24GB GPUs.
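The contents of scripts/verify_grpo_2gpu.sh are not shown here. As a rough illustration of what an "identical weights" check can mean, the sketch below compares two parameter dictionaries bit for bit. The `first_mismatch` helper and the plain lists of floats standing in for tensors are hypothetical.

```python
import struct
from typing import Dict, List, Optional


def float_bits(x: float) -> bytes:
    """Raw IEEE-754 double bytes, so 0.0 vs -0.0 and NaN payloads are distinguished."""
    return struct.pack("<d", x)


def first_mismatch(ref: Dict[str, List[float]],
                   other: Dict[str, List[float]]) -> Optional[str]:
    """Return the first parameter name whose values differ bitwise, else None."""
    if ref.keys() != other.keys():
        raise KeyError("parameter names differ between the two runs")
    for name in sorted(ref):
        a, b = ref[name], other[name]
        if len(a) != len(b):
            return name
        if any(float_bits(x) != float_bits(y) for x, y in zip(a, b)):
            return name
    return None


# Values that look equal when printed can still differ in the last bit.
baseline = {"w": [0.3], "b": [1.0]}
cp_run = {"w": [0.1 + 0.2], "b": [1.0]}
print(first_mismatch(baseline, cp_run))
```

A bitwise check is stricter than a tolerance-based comparison: changing the order of floating-point additions across ranks is enough to make it fail, so it is a strong test of the bit-exact parity claim.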

