Skip to content

Commit 9d8f5d8

Browse files
committed
Updated README.
1 parent ab87185 commit 9d8f5d8

1 file changed

Lines changed: 44 additions & 37 deletions

File tree

README.md

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
[![License](https://img.shields.io/badge/License-BSD_3--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause)
44

55
[[Examples]](#show-me-some-examples)
6+
[[JAX Examples]](#jax-examples)
67
[[Citation and Acknowledgements]](#citation-and-acknowledgements)
78

89
OpenEquivariance is a CUDA and HIP kernel generator for the Clebsch-Gordon tensor product,
@@ -29,48 +30,15 @@ computation and memory consumption significantly.
2930
For detailed instructions on tests, benchmarks, MACE / Nequip, and our API,
3031
check out the [documentation](https://passionlab.github.io/OpenEquivariance).
3132

32-
⭐️ **JAX Support**: Our latest update brings
33-
support for JAX. To install, execute the following commands in order:
33+
⭐️ **JAX**: Our latest update brings
34+
support for JAX. To install, execute the following
35+
commands in order:
3436

3537
```
3638
pip install openequivariance[jax]
3739
pip install openequivariance_extjax --no-build-isolation
3840
```
39-
40-
```python
41-
os.environ["OEQ_NOTORCH"] = "1"
42-
import openequivariance as oeq
43-
import jax
44-
45-
seed = 42
46-
key = jax.random.PRNGKey(seed)
47-
48-
node_ct, nonzero_ct = 3, 4
49-
edge_index = jax.numpy.array(
50-
[
51-
[0, 1, 1, 2],
52-
[1, 0, 2, 1],
53-
],
54-
dtype=jax.numpy.int32, # NOTE: This int32, not int64
55-
)
56-
57-
X = jax.random.uniform(key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32)
58-
Y = jax.random.uniform(key, shape=(nonzero_ct, Y_ir.dim),
59-
minval=0.0, maxval=1.0, dtype=jax.numpy.float32)
60-
W = jax.random.uniform(key, shape=(nonzero_ct, problem.weight_numel),
61-
minval=0.0, maxval=1.0, dtype=jax.numpy.float32)
62-
63-
# Reuse problem from earlier
64-
tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False)
65-
Z = tp_conv.forward(
66-
X, Y, W, edge_index[0], edge_index[1]
67-
)
68-
print(jax.numpy.linalg.norm(Z))
69-
```
70-
71-
📣 📣 OpenEquivariance was accepted to the 2025 SIAM Conference on Applied and
72-
Computational Discrete Algorithms (Proceedings Track)! Catch the talk in
73-
Montréal and check out the [camera-ready copy on Arxiv](https://arxiv.org/abs/2501.13986) (available May 12, 2025).
41+
See below for example usage.
7442

7543
## Show me some examples
7644
Here's a CG tensor product implemented by e3nn:
@@ -166,6 +134,45 @@ print(torch.norm(Z))
166134
`deterministic=False`, the `sender` and `receiver` indices can have
167135
arbitrary order.
168136

137+
## JAX Examples
138+
After installation, use the library
139+
as follows. Set `OEQ_NOTORCH=1`
140+
in your environment to avoid the PyTorch import in
141+
the regular `openequivariance` package.
142+
```python
143+
import jax
144+
import os
145+
146+
os.environ["OEQ_NOTORCH"] = "1"
147+
import openequivariance as oeq
148+
149+
seed = 42
150+
key = jax.random.PRNGKey(seed)
151+
152+
node_ct, nonzero_ct = 3, 4
153+
edge_index = jax.numpy.array(
154+
[
155+
[0, 1, 1, 2],
156+
[1, 0, 2, 1],
157+
],
158+
dtype=jax.numpy.int32, # NOTE: This int32, not int64
159+
)
160+
161+
X = jax.random.uniform(key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32)
162+
Y = jax.random.uniform(key, shape=(nonzero_ct, Y_ir.dim),
163+
minval=0.0, maxval=1.0, dtype=jax.numpy.float32)
164+
W = jax.random.uniform(key, shape=(nonzero_ct, problem.weight_numel),
165+
minval=0.0, maxval=1.0, dtype=jax.numpy.float32)
166+
167+
# Reuse problem from earlier
168+
# ...
169+
tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False)
170+
Z = tp_conv.forward(
171+
X, Y, W, edge_index[0], edge_index[1]
172+
)
173+
print(jax.numpy.linalg.norm(Z))
174+
```
175+
169176
## Citation and Acknowledgements
170177
If you find this code useful, please cite our paper:
171178

0 commit comments

Comments
 (0)