Skip to content

Commit 878e230

Browse files
JAX Support (#172)
* Bare bones of a JAX extension. * Continued progress. * More changes. * Added all relevant parameters. * Should be able to compile a kernel. * Cleaned up code a bit. * Reorganized the repo. * Refactored imports. * Tests are passing after the first refactor. * Nested directory structure by one level. * Temp commit. * Got the editable install working again. * Extension module in progress. * More things working. * Began putting together a test rig. * Things starting to work. * Made LoopUnrollTP generic. * More things are working. * More plumbing. * More progress. * Dispatch complete. * Forward call is working. * Added the backward pass. * Encapsulated the forward call. * Skeleton of rule implemented. * Backward call is working. * Zero'd buffer. * Wrapped the double-backward pass. * Added the forward convolution implementation. * Backward convolution implemented. * Convolution double backward registered. * Finished the double backward VJP registration. * Double backward pass seems to work. * Did some extra testing. * Reorg of LoopUnrollConv.py * Convolution changed. * Finished prototype of TensorProductConv. * Added some type annotations. * Finished the forward call. * Ready to start JAX support. * More plumbing. * Forward call is working. * Registered the VJP rules for backward and double-backward. * Added __call__ functions. * Prepping to add tests. * Ran ruff. * Moved tests back. * 1/3 tests is passing. * Backward test is passing. * Backward convolution is failing, need to figure out why. * Zerod gradient buffer. * Abstracted away reordering. * Added JAX reordering function. * Reordering starting to work... * Forward and backward are working. * Batch test is working. * Ready to modify the double backward correctness function. * Correctness double backward works for existing code, need to extend to JAX. * Wrote double backward function for JAX. * All double backward tests passing. * Added the mixins. * Added double backward CPU function to jax TP conv. * Almost there, need to get TensorProductConv working. * Double backward tests are passing. * Updated documentation. * Modified documentation. * Updated documentation. * More documentation progress. * Renamed. * Renaming + added JAX example. * JAX example. * Added examples. * Updated README. * Updated release file. * Linted. * Updated the build verification. * Merge complete. * Updated README. * Fixed some minor issues. * Added symlinks. * Cleaning up the core. * More core cleanup. * Rename. * Example test is working. * Sanded away some more issues. * Updated changelog. * Pre-commit. * Download XLA directly. * Removed need for build isolation. * Removed need for build isolation. * Updated README. * Updated documentation slightly. * Don't need extension source path anymore. * Removed a spurious import. * Update Python version and XLA Git tag in CMakeLists * Update XLA dir. * Removed dependency. * Went back to version that disables build isolation. * Updated README. * Updated error handling * Last bit of cleanup. * Ruff. * Things working for HIP, just need to branch. * Updated CMakeLists. * Added pyproject.toml define. * Plumbed logic. * Made things compile with HIP. * Updated READMEs. * Highlight AMD support in changelog. * Ruff. * Updated documentation. * Updated installation instructions. * More ruff. * Added option for CI. * Ready to go. * Enabled direct download XLA.
1 parent e723ea7 commit 878e230

99 files changed

Lines changed: 3730 additions & 1729 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/docs.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
- name: Install dependencies
2424
run: |
2525
python -m pip install --upgrade pip
26-
pip install sphinx furo
26+
pip install -r docs/requirements.txt
2727
- name: Build website
2828
run: |
2929
sphinx-build -M dirhtml docs docs/_build

.github/workflows/release.yaml

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ on:
66
# ref: https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/
77

88
jobs:
9-
build:
9+
build-oeq:
1010
name: Build distribution
1111
runs-on: ubuntu-latest
1212
steps:
@@ -16,21 +16,20 @@ jobs:
1616
with:
1717
python-version: '3.10'
1818
- name: install dependencies, then build source tarball
19-
run: |
19+
run: |
20+
cd openequivariance
2021
python3 -m pip install build --user
2122
python3 -m build --sdist
2223
- name: store the distribution packages
2324
uses: actions/upload-artifact@v4
2425
with:
2526
name: python-package-distributions
26-
path: dist/
27+
path: openequivariance/dist/
2728

2829
pypi-publish:
2930
name: Upload release to PyPI
3031
runs-on: ubuntu-latest
31-
# build task to be completed first
32-
needs: build
33-
# Specifying a GitHub environment is optional, but strongly encouraged
32+
needs: build-oeq
3433
environment:
3534
name: pypi
3635
url: https://pypi.org/p/openequivariance
@@ -42,6 +41,47 @@ jobs:
4241
uses: actions/download-artifact@v4
4342
with:
4443
name: python-package-distributions
45-
path: dist/
44+
path: openequivariance/dist/
45+
- name: publish package distributions to PyPI
46+
uses: pypa/gh-action-pypi-publish@release/v1
47+
48+
# ------------------------------------
49+
50+
build-oeq-extjax:
51+
name: Build distribution
52+
runs-on: ubuntu-latest
53+
steps:
54+
- uses: actions/checkout@v4
55+
- name: set up Python
56+
uses: actions/setup-python@v5
57+
with:
58+
python-version: '3.10'
59+
- name: install dependencies, then build source tarball
60+
run: |
61+
cd openequivariance_extjax
62+
python3 -m pip install build --user
63+
python3 -m build --sdist
64+
- name: store the distribution packages
65+
uses: actions/upload-artifact@v4
66+
with:
67+
name: python-package-distributions
68+
path: openequivariance_extjax/dist/
69+
70+
pypi-publish-extjax:
71+
name: Upload release to PyPI
72+
runs-on: ubuntu-latest
73+
needs: build-oeq-extjax
74+
environment:
75+
name: pypi
76+
url: https://pypi.org/p/openequivariance_extjax
77+
permissions:
78+
# IMPORTANT: this permission is mandatory for Trusted Publishing
79+
id-token: write
80+
steps:
81+
- name: download the distributions
82+
uses: actions/download-artifact@v4
83+
with:
84+
name: python-package-distributions
85+
path: openequivariance_extjax/dist/
4686
- name: publish package distributions to PyPI
4787
uses: pypa/gh-action-pypi-publish@release/v1
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
numpy==2.2.5
22
torch==2.7.0 --index-url https://download.pytorch.org/whl/cu128
33
pytest==8.3.5
4-
ninja==1.11.1.4
4+
ninja==1.11.1.4
5+
nanobind==2.10.2
6+
scikit-build-core==0.11.6

.github/workflows/verify_extension_build.yml

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
name: OEQ CUDA C++ Extension Build Verification
1+
name: OEQ C++ Extension Build Verification
22

33
on:
44
push:
@@ -29,10 +29,14 @@ jobs:
2929
sudo apt-get update
3030
sudo apt install nvidia-cuda-toolkit
3131
pip install -r .github/workflows/requirements_cuda_ci.txt
32-
pip install -e .
32+
pip install -e "./openequivariance"
3333
34-
- name: Test extension build via import
34+
- name: Test CUDA extension build via import
3535
run: |
3636
pytest \
3737
tests/import_test.py::test_extension_built \
38-
tests/import_test.py::test_torch_extension_built
38+
tests/import_test.py::test_torch_extension_built
39+
40+
- name: Test JAX extension build
41+
run: |
42+
XLA_DIRECT_DOWNLOAD=1 pip install -e "./openequivariance_extjax" --no-build-isolation

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ triton_autotuning
3838
paper_benchmarks
3939
paper_benchmarks_v2
4040
paper_benchmarks_v3
41-
openequivariance/extlib/*.so
4241

4342
get_node.sh
4443
*.egg-info

CHANGELOG.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
11
## Latest Changes
22

3+
### v0.5.0 (2025-12-25)
4+
JAX support is now available in
5+
OpenEquivariance for BOTH NVIDIA and
6+
AMD GPUs! See the
7+
[documentation](https://passionlab.github.io/OpenEquivariance/)
8+
and README.md for instructions on installation
9+
and usage.
10+
11+
Minor changes:
12+
- Defer error reporting when CUDA is not available
13+
to the first library usage in code, not library load.
14+
315
### v0.4.1 (2025-09-04)
416
Minor update, fixes a bug loading JIT-compiled modules
517
with PyTorch 2.9.

MANIFEST.in

Lines changed: 0 additions & 10 deletions
This file was deleted.

README.md

Lines changed: 66 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# OpenEquivariance
2-
[![OEQ CUDA C++ Extension Build Verification](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml/badge.svg?event=push)](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml)
2+
[![OEQ C++ Extension Build Verification](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml/badge.svg?event=push)](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml)
33
[![License](https://img.shields.io/badge/License-BSD_3--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause)
44

5-
[[Examples]](#show-me-some-examples)
5+
[[PyTorch Examples]](#pytorch-examples)
6+
[[JAX Examples]](#jax-examples)
67
[[Citation and Acknowledgements]](#citation-and-acknowledgements)
78

89
OpenEquivariance is a CUDA and HIP kernel generator for the Clebsch-Gordon tensor product,
@@ -12,8 +13,8 @@ that [e3nn](https://e3nn.org/) supports
1213
commonly found in graph neural networks
1314
(e.g. [Nequip](https://github.com/mir-group/nequip) or
1415
[MACE](https://github.com/ACEsuit/mace)). To get
15-
started, ensure that you have GCC 9+ on your system
16-
and install our package via
16+
started with PyTorch, ensure that you have PyTorch
17+
and GCC 9+ available before installing our package via
1718

1819
```bash
1920
pip install openequivariance
@@ -29,11 +30,26 @@ computation and memory consumption significantly.
2930
For detailed instructions on tests, benchmarks, MACE / Nequip, and our API,
3031
check out the [documentation](https://passionlab.github.io/OpenEquivariance).
3132

32-
📣 📣 OpenEquivariance was accepted to the 2025 SIAM Conference on Applied and
33-
Computational Discrete Algorithms (Proceedings Track)! Catch the talk in
34-
Montréal and check out the [camera-ready copy on Arxiv](https://arxiv.org/abs/2501.13986) (available May 12, 2025).
33+
⭐️ **JAX**: Our latest update brings
34+
support for JAX. For NVIDIA GPUs,
35+
install it (after installing JAX)
36+
with the following two commands strictly in order:
3537

36-
## Show me some examples
38+
``` bash
39+
pip install openequivariance[jax]
40+
pip install openequivariance_extjax --no-build-isolation
41+
```
42+
43+
For AMD GPUs:
44+
``` bash
45+
pip install openequivariance[jax]
46+
JAX_HIP=1 pip install openequivariance_extjax --no-build-isolation
47+
```
48+
49+
See the section below for example usage and
50+
our [API page](https://passionlab.github.io/OpenEquivariance/api/) for more details.
51+
52+
## PyTorch Examples
3753
Here's a CG tensor product implemented by e3nn:
3854

3955
```python
@@ -127,6 +143,48 @@ print(torch.norm(Z))
127143
`deterministic=False`, the `sender` and `receiver` indices can have
128144
arbitrary order.
129145

146+
## JAX Examples
147+
After installation, use the library
148+
as follows. Set `OEQ_NOTORCH=1`
149+
in your environment to avoid the PyTorch import in
150+
the regular `openequivariance` package.
151+
```python
152+
import jax
153+
import os
154+
155+
os.environ["OEQ_NOTORCH"] = "1"
156+
import openequivariance as oeq
157+
158+
seed = 42
159+
key = jax.random.PRNGKey(seed)
160+
161+
batch_size = 1000
162+
X_ir, Y_ir, Z_ir = oeq.Irreps("1x2e"), oeq.Irreps("1x3e"), oeq.Irreps("1x2e")
163+
problem = oeq.TPProblem(X_ir, Y_ir, Z_ir, [(0, 0, 0, "uvu", True)], shared_weights=False, internal_weights=False)
164+
165+
166+
node_ct, nonzero_ct = 3, 4
167+
edge_index = jax.numpy.array(
168+
[
169+
[0, 1, 1, 2],
170+
[1, 0, 2, 1],
171+
],
172+
dtype=jax.numpy.int32, # NOTE: This int32, not int64
173+
)
174+
175+
X = jax.random.uniform(key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32)
176+
Y = jax.random.uniform(key, shape=(nonzero_ct, Y_ir.dim),
177+
minval=0.0, maxval=1.0, dtype=jax.numpy.float32)
178+
W = jax.random.uniform(key, shape=(nonzero_ct, problem.weight_numel),
179+
minval=0.0, maxval=1.0, dtype=jax.numpy.float32)
180+
181+
tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False)
182+
Z = tp_conv.forward(
183+
X, Y, W, edge_index[0], edge_index[1]
184+
)
185+
print(jax.numpy.linalg.norm(Z))
186+
```
187+
130188
## Citation and Acknowledgements
131189
If you find this code useful, please cite our paper:
132190

docs/api.rst

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ OpenEquivariance API
88
OpenEquivariance exposes two key classes: :py:class:`openequivariance.TensorProduct`, which replaces
99
``o3.TensorProduct`` from e3nn, and :py:class:`openequivariance.TensorProductConv`, which fuses
1010
the CG tensor product with a subsequent graph convolution. Initializing either class triggers
11-
JIT compilation of a custom kernel, which can take a few seconds.
11+
JIT compilation of a custom kernel, which can take a few seconds.
1212

1313
Both classes require a configuration object specified
1414
by :py:class:`openequivariance.TPProblem`, which has a constructor
@@ -17,6 +17,9 @@ We recommend reading the `e3nn documentation <https://docs.e3nn.org/en/latest/>`
1717
trying our code. OpenEquivariance cannot accelerate all tensor products; see
1818
:doc:`this page </supported_ops>` for a list of supported configurations.
1919

20+
PyTorch API
21+
------------------------
22+
2023
.. autoclass:: openequivariance.TensorProduct
2124
:members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn, to
2225
:undoc-members:
@@ -27,14 +30,39 @@ trying our code. OpenEquivariance cannot accelerate all tensor products; see
2730
:undoc-members:
2831
:exclude-members: name
2932

30-
.. autoclass:: openequivariance.TPProblem
31-
:members:
32-
:undoc-members:
33-
3433
.. autofunction:: openequivariance.torch_to_oeq_dtype
3534

3635
.. autofunction:: openequivariance.torch_ext_so_path
3736

37+
JAX API
38+
------------------------
39+
The JAX API consists of ``TensorProduct`` and ``TensorProductConv``
40+
classes that behave identically to their PyTorch counterparts. These classes
41+
do not conform exactly to the e3nn-jax API, but perform the same computation.
42+
43+
If you plan to use ``oeq.jax`` without PyTorch installed,
44+
you need to set ``OEQ_NOTORCH=1`` in your local environment (within Python,
45+
``os.environ["OEQ_NOTORCH"] = 1``). For the moment, we require this to avoid
46+
breaking the PyTorch version of OpenEquivariance.
47+
48+
49+
.. autoclass:: openequivariance.jax.TensorProduct
50+
:members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn
51+
:undoc-members:
52+
:exclude-members:
53+
54+
.. autoclass:: openequivariance.jax.TensorProductConv
55+
:members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn
56+
:undoc-members:
57+
:exclude-members:
58+
59+
Common API
60+
---------------------
61+
62+
.. autoclass:: openequivariance.TPProblem
63+
:members:
64+
:undoc-members:
65+
3866
API Identical to e3nn
3967
---------------------
4068

docs/conf.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,17 @@
2828
html_theme = "furo"
2929
# html_static_path = ["_static"]
3030

31-
extensions = [
32-
"sphinx.ext.autodoc",
31+
extensions = ["sphinx.ext.autodoc", "sphinx_inline_tabs"]
32+
33+
sys.path.insert(0, str(Path("../openequivariance").resolve()))
34+
35+
autodoc_mock_imports = [
36+
"torch",
37+
"jax",
38+
"openequivariance._torch.extlib",
39+
"openequivariance.jax.extlib",
40+
"openequivariance_extjax",
41+
"jinja2",
42+
"numpy",
3343
]
34-
35-
sys.path.insert(0, str(Path("..").resolve()))
36-
37-
autodoc_mock_imports = ["torch", "openequivariance.extlib", "jinja2", "numpy"]
3844
autodoc_typehints = "description"

0 commit comments

Comments
 (0)