[JAX] Improve JAX tutorial documentation#2976
Conversation
Greptile SummaryThis PR replaces the old
Confidence Score: 5/5Documentation-only restructuring with no changes to runtime library code; safe to merge. All changed files are documentation, tutorial scripts, and CI shell scripts. The new test infrastructure correctly defers TE imports to avoid collection errors on non-Blackwell hardware, CI paths now point to the correct docs/examples/jax/ directory, and Sphinx cross-references are well-formed. The only finding is a single truncated sentence in the conventions section of the hub page. docs/examples/te_jax_integration.rst — one incomplete bullet in the Conventions section. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[docs/index.rst] -->|toctree| B[te_jax_integration.rst\nhub page]
B -->|toctree + link| C[jax/dense.rst\nAvailable]
B -->|toctree + link| D[jax/collective_gemm.rst\nComing soon]
B -->|toctree + link| E[jax/attention.rst\nComing soon]
B -->|toctree + link| F[jax/expert_parallelism.rst\nComing soon]
C -->|literalinclude| G[jax/dense.py\ntutorial source]
C -->|literalinclude| H[jax/dense.out\npre-captured output]
G -->|import| I[jax/quickstart_jax_utils.py]
J[jax/test_dense.py\npytest entry points] -->|deferred import| G
J -->|import| I
K[qa/L0_jax_unittest/test.sh] -->|pytest| J
L[qa/L0_jax_distributed_unittest/test.sh] -->|pytest -k multi_gpu| J
Reviews (10): Last reviewed commit: "Update docs/examples/jax/dense.rst" | Re-trigger Greptile |
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
|
/te-ci L1 L0 |
KshitijLakhani
left a comment
There was a problem hiding this comment.
Thanks for adding this skeleton.
I like the modular approach, concise explanation and benchmarking.
In general it looks good there might be some working around needed on item placements but I think that's going to be an evolving process.
| .. | ||
| Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
|
|
||
| See LICENSE for license information. | ||
|
|
||
| JAX: Attention with TransformerEngine | ||
| ===================================== | ||
|
|
||
| **TODO — Coming soon.** | ||
|
|
||
| `← Back to the JAX integration overview <../te_jax_integration.html>`_ |
There was a problem hiding this comment.
Unrelated to attention but looks like you are renaming the dir to examples/jax_examples whereas I think the pytorch side is examples/pytorch ?
I think we could stick with examples/jax - thoughts ?
There was a problem hiding this comment.
Good point, updated to examples/jax
| `Haiku/Flax interop | ||
| <https://dm-haiku.readthedocs.io/en/latest/notebooks/flax.html>`_ if you're on | ||
| a different stack.) | ||
| * **Baseline dtype.** bf16 for inputs and parameters. |
There was a problem hiding this comment.
Should we add GB200 (arch) details here rather than adding it in the example module or is that by choice ?
I think there's value in having all examples run on the same arch for consistency.
| # | ||
| # See LICENSE for license information. | ||
|
|
||
| """Pytest conftest for docs/examples/jax_examples. |
There was a problem hiding this comment.
I see, in our main tests in tests/jax/ we use pytest. In our examples/jax we do use unittest instead, but then run those tests in CI with pytest examples/jax/.... because pytest can also run unittest tests.
I'm ok with standardizing and using pytest everywhere. We already have requirements.txt files for running the examples/jax/mnist or encoder tests, so we could add the pytest dependency there too.
There was a problem hiding this comment.
I agree with standardizing the use of pytest
| and your performance comparison will not be accurate. | ||
|
|
||
|
|
||
| 6. Multi-GPU: DP=2 / TP=2 on a single Dense |
There was a problem hiding this comment.
- Single GPU performane
4,5 ? - Multi-GPU: DP=2 / TP=2 on a single Dense
There was a problem hiding this comment.
Good catch, I had it broken into more sections and forgot to update the latest section numbers. Fixed now
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
7464325 to
5432ec6
Compare
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
48884cd to
168cc63
Compare
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
54b1a9c to
4c1fec9
Compare
for more information, see https://pre-commit.ci
|
/te-ci |
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
92d9e29 to
c2c8444
Compare
| Dense — and the only code change was passing ``dot_general=te_dot_general_cls()`` | ||
| into ``nn.Dense``. | ||
|
|
||
| The speedup depends on shape: large GEMMs benefit most. Very small GEMMs may |
There was a problem hiding this comment.
I wonder if we know the threshold to this. At which size do we start getting benefit
There was a problem hiding this comment.
I think it depends on the GPU type, if it's square or narrow, and probably the cuBLAS version. But I'll search through cuBLAS docs just in case they have anything I can link to that will stay up-to-date with their latest version's perf improvements
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
|
/te-ci |
|
LGTM pending CI |
Description
Reworks tutorial to focus on individual operations and their usage+performance. This will make it clearer to users the impact of each operation and they can focus on trying them out one-at-a-time depending on which are bottlenecks in their models.
Additionally, this switches from notebook
.ipynbfiles to.rstand separate.pyfiles for easier testing in CI to ensure our docs do not become stale and always work with the latest TE version.Type of change
Changes
Checklist: