Skip to content

Commit ab87185

Browse files
committed
Added examples.
1 parent 259ea20 commit ab87185

2 files changed

Lines changed: 30 additions & 2 deletions

File tree

README.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,36 @@ pip install openequivariance[jax]
3737
pip install openequivariance_extjax --no-build-isolation
3838
```
3939

40+
```python
41+
os.environ["OEQ_NOTORCH"] = "1"
42+
import openequivariance as oeq
43+
import jax
44+
45+
seed = 42
46+
key = jax.random.PRNGKey(seed)
47+
48+
node_ct, nonzero_ct = 3, 4
49+
edge_index = jax.numpy.array(
50+
[
51+
[0, 1, 1, 2],
52+
[1, 0, 2, 1],
53+
],
54+
dtype=jax.numpy.int32, # NOTE: This int32, not int64
55+
)
56+
57+
X = jax.random.uniform(key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32)
58+
Y = jax.random.uniform(key, shape=(nonzero_ct, Y_ir.dim),
59+
minval=0.0, maxval=1.0, dtype=jax.numpy.float32)
60+
W = jax.random.uniform(key, shape=(nonzero_ct, problem.weight_numel),
61+
minval=0.0, maxval=1.0, dtype=jax.numpy.float32)
62+
63+
# Reuse problem from earlier
64+
tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False)
65+
Z = tp_conv.forward(
66+
X, Y, W, edge_index[0], edge_index[1]
67+
)
68+
print(jax.numpy.linalg.norm(Z))
69+
```
4070

4171
📣 📣 OpenEquivariance was accepted to the 2025 SIAM Conference on Applied and
4272
Computational Discrete Algorithms (Proceedings Track)! Catch the talk in

tests/examples_test.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,5 +131,3 @@ def test_tutorial_jax(with_jax):
131131
X, Y, W, edge_index[0], edge_index[1]
132132
)
133133
print(jax.numpy.linalg.norm(Z))
134-
135-

0 commit comments

Comments
 (0)