Skip to content

[JAX] Improve JAX tutorial documentation#2976

Open
jberchtold-nvidia wants to merge 16 commits into
NVIDIA:mainfrom
jberchtold-nvidia:jberchtold/improve-jax-tutorial
Open

[JAX] Improve JAX tutorial documentation#2976
jberchtold-nvidia wants to merge 16 commits into
NVIDIA:mainfrom
jberchtold-nvidia:jberchtold/improve-jax-tutorial

Conversation

@jberchtold-nvidia
Copy link
Copy Markdown
Collaborator

@jberchtold-nvidia jberchtold-nvidia commented May 11, 2026

Description

Reworks tutorial to focus on individual operations and their usage+performance. This will make it clearer to users the impact of each operation and they can focus on trying them out one-at-a-time depending on which are bottlenecks in their models.

Additionally, this switches from notebook .ipynb files to .rst and separate .py files for easier testing in CI to ensure our docs do not become stale and always work with the latest TE version.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Rework existing tutorial and replace with new Dense-specific tutorial
  • Placeholders for Attention and MoE
  • Refactor .ipynb notebooks to .rst and .py files for similar appearance in docs but better testability in CI by running .py files

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 11, 2026

Greptile Summary

This PR replaces the old te_jax_integration.ipynb notebook with a restructured documentation suite: a new RST hub page (te_jax_integration.rst) linking to per-topic RST pages backed by testable .py source files, with CI hooks in both single-GPU and multi-GPU test runners.

  • Tutorial refactor: Replaces the monolithic notebook with focused, per-operation docs (dense.rst/dense.py) and stub placeholders for Attention, Collective GEMMs, and Expert Parallelism; the .py files are directly testable in CI via pytest.
  • Test infrastructure: Adds test_dense.py with deferred imports inside each test body (guarded by @requires_mxfp8) to prevent collection-time failures on non-Blackwell hardware, and wires both L0 single-GPU and L0 distributed runners to docs/examples/jax/.
  • Utility extension: Adds compare_fwd_bwd to quickstart_jax_utils.py for numeric correctness assertions between a baseline model and its TE-quantized counterpart.

Confidence Score: 5/5

Documentation-only restructuring with no changes to runtime library code; safe to merge.

All changed files are documentation, tutorial scripts, and CI shell scripts. The new test infrastructure correctly defers TE imports to avoid collection errors on non-Blackwell hardware, CI paths now point to the correct docs/examples/jax/ directory, and Sphinx cross-references are well-formed. The only finding is a single truncated sentence in the conventions section of the hub page.

docs/examples/te_jax_integration.rst — one incomplete bullet in the Conventions section.

Important Files Changed

Filename Overview
docs/examples/jax/dense.py New tutorial script for Dense GEMMs; clean module-level setup, correct sys.path (Python adds script directory automatically), well-structured DENSE_*_START/END markers for RST literalinclude.
docs/examples/jax/test_dense.py Pytest test file with correctly deferred from-dense imports inside each test body to avoid module-level MXFP8 init on non-Blackwell hardware; requires_mxfp8 guards are consistently applied.
docs/examples/te_jax_integration.rst New landing page RST; Sphinx label jax_recipe_table_overview defined correctly; one incomplete bullet in the Conventions section (truncated at warmup).
docs/examples/jax/quickstart_jax_utils.py Adds compare_fwd_bwd utility; correctly accesses dvars params which is safe for MXFP8BlockScaling (stateless); numpy.testing.assert_allclose used appropriately.
qa/L0_jax_unittest/test.sh Adds single-GPU pytest invocation against TE_PATH/docs/examples/jax/ (correct path); reuses existing pytest.ini from tests/jax/.
qa/L0_jax_distributed_unittest/test.sh Adds multi-GPU pytest with -k multi_gpu filter against TE_PATH/docs/examples/jax/; correct path and graceful auto-skip when fewer than 4 GPUs are present.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[docs/index.rst] -->|toctree| B[te_jax_integration.rst\nhub page]
    B -->|toctree + link| C[jax/dense.rst\nAvailable]
    B -->|toctree + link| D[jax/collective_gemm.rst\nComing soon]
    B -->|toctree + link| E[jax/attention.rst\nComing soon]
    B -->|toctree + link| F[jax/expert_parallelism.rst\nComing soon]
    C -->|literalinclude| G[jax/dense.py\ntutorial source]
    C -->|literalinclude| H[jax/dense.out\npre-captured output]
    G -->|import| I[jax/quickstart_jax_utils.py]
    J[jax/test_dense.py\npytest entry points] -->|deferred import| G
    J -->|import| I
    K[qa/L0_jax_unittest/test.sh] -->|pytest| J
    L[qa/L0_jax_distributed_unittest/test.sh] -->|pytest -k multi_gpu| J
Loading

Reviews (10): Last reviewed commit: "Update docs/examples/jax/dense.rst" | Re-trigger Greptile

Comment thread docs/examples/jax_examples/dense.ipynb Outdated
Comment thread docs/examples/jax_examples/dense.ipynb Outdated
Comment thread docs/examples/jax_examples/attention.ipynb Outdated
Comment thread docs/examples/jax_examples/dense.ipynb Outdated
Comment thread docs/examples/jax_examples/moe.ipynb Outdated
Comment thread docs/examples/jax_examples/dense.ipynb Outdated
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Comment thread docs/examples/jax/dense.py
@jberchtold-nvidia
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 L0

Copy link
Copy Markdown
Collaborator

@KshitijLakhani KshitijLakhani left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this skeleton.
I like the modular approach, concise explanation and benchmarking.

In general it looks good there might be some working around needed on item placements but I think that's going to be an evolving process.

Comment on lines +1 to +11
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

See LICENSE for license information.

JAX: Attention with TransformerEngine
=====================================

**TODO — Coming soon.**

`← Back to the JAX integration overview <../te_jax_integration.html>`_
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to attention but looks like you are renaming the dir to examples/jax_examples whereas I think the pytorch side is examples/pytorch ?
I think we could stick with examples/jax - thoughts ?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, updated to examples/jax

`Haiku/Flax interop
<https://dm-haiku.readthedocs.io/en/latest/notebooks/flax.html>`_ if you're on
a different stack.)
* **Baseline dtype.** bf16 for inputs and parameters.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add GB200 (arch) details here rather than adding it in the example module or is that by choice ?
I think there's value in having all examples run on the same arch for consistency.

Comment thread docs/examples/jax/attention.rst
Comment thread docs/examples/jax/conftest.py Outdated
#
# See LICENSE for license information.

"""Pytest conftest for docs/examples/jax_examples.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with the usage of pytest in general, however I think currently the examples/mnist uses the in built Python UT module for the test example.
@phu0ngng and @tdophung it might be good to standardize and use pytest in there too - thoughts ?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, in our main tests in tests/jax/ we use pytest. In our examples/jax we do use unittest instead, but then run those tests in CI with pytest examples/jax/.... because pytest can also run unittest tests.

I'm ok with standardizing and using pytest everywhere. We already have requirements.txt files for running the examples/jax/mnist or encoder tests, so we could add the pytest dependency there too.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with standardizing the use of pytest

Comment thread docs/examples/jax_examples/dense.out Outdated
Comment thread docs/examples/jax_examples/dense.rst Outdated
and your performance comparison will not be accurate.


6. Multi-GPU: DP=2 / TP=2 on a single Dense
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Single GPU performane
    4,5 ?
  2. Multi-GPU: DP=2 / TP=2 on a single Dense

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, I had it broken into more sections and forgot to update the latest section numbers. Fixed now

Comment thread qa/L0_jax_unittest/test.sh Outdated
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/improve-jax-tutorial branch from 7464325 to 5432ec6 Compare May 15, 2026 16:43
Comment thread qa/L0_jax_unittest/test.sh Outdated
Comment thread qa/L1_jax_distributed_unittest/test.sh Outdated
jberchtold-nvidia and others added 2 commits May 15, 2026 09:48
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Comment thread docs/examples/jax/dense.rst Outdated
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Comment thread docs/examples/jax/test_dense.py Outdated
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/improve-jax-tutorial branch from 48884cd to 168cc63 Compare May 15, 2026 18:36
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/improve-jax-tutorial branch from 54b1a9c to 4c1fec9 Compare May 15, 2026 19:02
@jberchtold-nvidia
Copy link
Copy Markdown
Collaborator Author

/te-ci

Comment thread docs/examples/jax/test_dense.py
Comment thread docs/examples/jax/dense.rst Outdated
Comment thread qa/L0_jax_unittest/test.sh
Comment thread qa/L1_jax_distributed_unittest/test.sh Outdated
Comment thread docs/examples/jax/dense.out
Comment thread docs/examples/jax/dense.rst Outdated
Comment thread docs/examples/te_jax_integration.rst Outdated
Comment thread docs/examples/jax/dense.rst Outdated
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/improve-jax-tutorial branch from 92d9e29 to c2c8444 Compare May 19, 2026 22:27
Dense — and the only code change was passing ``dot_general=te_dot_general_cls()``
into ``nn.Dense``.

The speedup depends on shape: large GEMMs benefit most. Very small GEMMs may
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we know the threshold to this. At which size do we start getting benefit

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it depends on the GPU type, if it's square or narrow, and probably the cuBLAS version. But I'll search through cuBLAS docs just in case they have anything I can link to that will stay up-to-date with their latest version's perf improvements

Comment thread docs/examples/jax/dense.rst Outdated
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
@jberchtold-nvidia
Copy link
Copy Markdown
Collaborator Author

/te-ci

@tdophung
Copy link
Copy Markdown
Collaborator

LGTM pending CI

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants