Fix DPT decoder bugs for Scenic parity #31

Open: bingyic wants to merge 3 commits into main from fix-decoders-scenic-parity

bingyic (Collaborator) commented May 11, 2026

Summary

Fixes multiple bugs in pytorch/decoders.py to achieve numerical parity (max diff < 1e-4) with the Scenic/Flax reference implementation.

Changes

  1. DPTHead: Add an output_activation parameter (default False).
    When True, F.relu() is applied after the project conv, matching Scenic.

  2. DepthDecoder: Switch to classification-based depth prediction:

    • nn.Linear(channels, num_depth_bins) head
    • bin_centers buffer via torch.linspace(min_depth, max_depth, num_depth_bins)
    • Forward: relu(logits) + min_depth → normalize → einsum(probs, bin_centers)

  3. ReassembleBlocks: Use F.gelu(x, approximate='tanh') to match the JAX default (jax.nn.gelu uses the tanh approximation unless approximate=False is passed).

  4. ConvTranspose kernel: Apply 180° spatial flip during Flax→PyTorch conversion.

  5. load_decoder_weights(): Unified weight loading from Scenic .zip checkpoints
    with auto-detection and key remapping for all decoder types:
    pixel_segmentation, pixel_depth_classif, pixel_normalshead
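
For reviewers, the depth forward pass described in item 2 can be sketched as follows. This is a minimal illustration, not the code in pytorch/decoders.py: the class name DepthHeadSketch, the channels-last (B, H, W, C) input layout, the default argument values, and the reading of "normalize" as division by the sum over bins are all assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DepthHeadSketch(nn.Module):
    """Hypothetical stand-in for a classification-based depth head."""

    def __init__(self, channels, num_depth_bins=256, min_depth=0.001, max_depth=10.0):
        super().__init__()
        self.min_depth = min_depth
        self.head = nn.Linear(channels, num_depth_bins)
        # Fixed, non-trainable bin centers spanning [min_depth, max_depth].
        self.register_buffer(
            "bin_centers", torch.linspace(min_depth, max_depth, num_depth_bins)
        )

    def forward(self, feats):
        # feats: (B, H, W, C), channels last (assumed layout).
        logits = self.head(feats)
        # relu(...) + min_depth keeps every bin weight strictly positive
        # as long as min_depth > 0.
        weights = F.relu(logits) + self.min_depth
        # Assumed normalization: divide by the sum over the bin dimension.
        probs = weights / weights.sum(dim=-1, keepdim=True)
        # Expected depth: probability-weighted sum of the bin centers.
        return torch.einsum("bhwn,n->bhw", probs, self.bin_centers)
```

Because the output is a convex combination of the bin centers, predictions under these assumptions always fall between min_depth and max_depth.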
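
The kernel conversion in item 4 could look roughly like the sketch below. The helper name is hypothetical, and the sketch assumes the Flax kernel is stored as (kH, kW, in, out) and that the target is nn.ConvTranspose2d, whose weight is laid out as (in, out, kH, kW). The actual conversion inside load_decoder_weights() may differ.

```python
import torch
import torch.nn as nn


def flax_convtranspose_kernel_to_torch(kernel_hwio):
    """Hypothetical helper: (kH, kW, in, out) -> flipped (in, out, kH, kW)."""
    # Move the channel axes in front of the spatial axes.
    w = kernel_hwio.permute(2, 3, 0, 1)
    # 180-degree spatial rotation: reverse both spatial axes.
    return torch.flip(w, dims=(-2, -1)).contiguous()


# Usage with a dummy kernel: 2x3 spatial, 4 input channels, 5 output channels.
flax_kernel = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)
layer = nn.ConvTranspose2d(4, 5, kernel_size=(2, 3), stride=2)
with torch.no_grad():
    layer.weight.copy_(flax_convtranspose_kernel_to_torch(flax_kernel))
```

Without the flip, the converted layer would have the right shapes but would generally produce different values, which is the kind of silent mismatch a parity check is meant to catch.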

Verification

All three decoder types verified for numerical parity against Scenic reference:

  • Normals: max diff < 1e-4 ✅
  • Depth (classification, 256 bins): max diff < 1e-4 ✅
  • Segmentation: max diff < 1e-4 ✅
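
The verification script itself is not part of this description. As a rough illustration of the metric, the sketch below reads "max diff" as the maximum absolute elementwise difference (an assumption) and applies it to the GELU change from item 3. The helper names max_abs_diff and gelu_tanh_reference are hypothetical.

```python
import math

import torch
import torch.nn.functional as F


def max_abs_diff(a, b):
    """Largest absolute elementwise difference between two tensors."""
    return (a - b).abs().max().item()


def gelu_tanh_reference(x):
    # tanh approximation of GELU, written out explicitly.
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + torch.tanh(inner))


x = torch.linspace(-6.0, 6.0, steps=2001, dtype=torch.float64)

# The tanh variant agrees with the explicit formula to within rounding error.
tanh_vs_reference = max_abs_diff(F.gelu(x, approximate="tanh"), gelu_tanh_reference(x))

# The exact (erf-based) variant differs from the tanh variant by more than 1e-4.
tanh_vs_exact = max_abs_diff(F.gelu(x, approximate="tanh"), F.gelu(x))

print(f"tanh vs explicit formula: {tanh_vs_reference:.2e}")
print(f"tanh vs exact GELU:       {tanh_vs_exact:.2e}")
```

At the activation level alone, the two GELU variants can differ by more than the 1e-4 tolerance, which suggests why matching the approximation mode matters for parity.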
