33[ ![ License] ( https://img.shields.io/badge/License-BSD_3--Clause-blue.svg )] ( https://opensource.org/licenses/BSD-3-Clause )
44
55[[ Examples]] ( #show-me-some-examples )
6+ [[ JAX Examples]] ( #jax-examples )
67[[ Citation and Acknowledgements]] ( #citation-and-acknowledgements )
78
89OpenEquivariance is a CUDA and HIP kernel generator for the Clebsch-Gordon tensor product,
@@ -29,48 +30,15 @@ computation and memory consumption significantly.
2930For detailed instructions on tests, benchmarks, MACE / Nequip, and our API,
3031check out the [ documentation] ( https://passionlab.github.io/OpenEquivariance ) .
3132
32- ⭐️ ** JAX Support** : Our latest update brings
33- support for JAX. To install, execute the following commands in order:
33+ ⭐️ ** JAX** : Our latest update brings
34+ support for JAX. To install, execute the following
35+ commands in order:
3436
3537```
3638pip install openequivariance[jax]
3739pip install openequivariance_extjax --no-build-isolation
3840```
39-
40- ``` python
41- os.environ[" OEQ_NOTORCH" ] = " 1"
42- import openequivariance as oeq
43- import jax
44-
45- seed = 42
46- key = jax.random.PRNGKey(seed)
47-
48- node_ct, nonzero_ct = 3 , 4
49- edge_index = jax.numpy.array(
50- [
51- [0 , 1 , 1 , 2 ],
52- [1 , 0 , 2 , 1 ],
53- ],
54- dtype = jax.numpy.int32, # NOTE : This int32, not int64
55- )
56-
57- X = jax.random.uniform(key, shape = (node_ct, X_ir.dim), minval = 0.0 , maxval = 1.0 , dtype = jax.numpy.float32)
58- Y = jax.random.uniform(key, shape = (nonzero_ct, Y_ir.dim),
59- minval = 0.0 , maxval = 1.0 , dtype = jax.numpy.float32)
60- W = jax.random.uniform(key, shape = (nonzero_ct, problem.weight_numel),
61- minval = 0.0 , maxval = 1.0 , dtype = jax.numpy.float32)
62-
63- # Reuse problem from earlier
64- tp_conv = oeq.jax.TensorProductConv(problem, deterministic = False )
65- Z = tp_conv.forward(
66- X, Y, W, edge_index[0 ], edge_index[1 ]
67- )
68- print (jax.numpy.linalg.norm(Z))
69- ```
70-
71- 📣 📣 OpenEquivariance was accepted to the 2025 SIAM Conference on Applied and
72- Computational Discrete Algorithms (Proceedings Track)! Catch the talk in
73- Montréal and check out the [ camera-ready copy on Arxiv] ( https://arxiv.org/abs/2501.13986 ) (available May 12, 2025).
41+ See below for example usage.
7442
7543## Show me some examples
7644Here's a CG tensor product implemented by e3nn:
@@ -166,6 +134,45 @@ print(torch.norm(Z))
166134` deterministic=False ` , the ` sender ` and ` receiver ` indices can have
167135arbitrary order.
168136
137+ ## JAX Examples
138+ After installation, use the library
139+ as follows. Set ` OEQ_NOTORCH=1 `
140+ in your environment to avoid the PyTorch import in
141+ the regular ` openequivariance ` package.
142+ ``` python
143+ import jax
144+ import os
145+
146+ os.environ[" OEQ_NOTORCH" ] = " 1"
147+ import openequivariance as oeq
148+
149+ seed = 42
150+ key = jax.random.PRNGKey(seed)
151+
152+ node_ct, nonzero_ct = 3 , 4
153+ edge_index = jax.numpy.array(
154+ [
155+ [0 , 1 , 1 , 2 ],
156+ [1 , 0 , 2 , 1 ],
157+ ],
158+ dtype = jax.numpy.int32, # NOTE : This int32, not int64
159+ )
160+
161+ X = jax.random.uniform(key, shape = (node_ct, X_ir.dim), minval = 0.0 , maxval = 1.0 , dtype = jax.numpy.float32)
162+ Y = jax.random.uniform(key, shape = (nonzero_ct, Y_ir.dim),
163+ minval = 0.0 , maxval = 1.0 , dtype = jax.numpy.float32)
164+ W = jax.random.uniform(key, shape = (nonzero_ct, problem.weight_numel),
165+ minval = 0.0 , maxval = 1.0 , dtype = jax.numpy.float32)
166+
167+ # Reuse problem from earlier
168+ # ...
169+ tp_conv = oeq.jax.TensorProductConv(problem, deterministic = False )
170+ Z = tp_conv.forward(
171+ X, Y, W, edge_index[0 ], edge_index[1 ]
172+ )
173+ print (jax.numpy.linalg.norm(Z))
174+ ```
175+
169176## Citation and Acknowledgements
170177If you find this code useful, please cite our paper:
171178
0 commit comments