@@ -37,6 +37,36 @@ pip install openequivariance[jax]
3737pip install openequivariance_extjax --no-build-isolation
3838```
3939
40+ ``` python
41+ os.environ[" OEQ_NOTORCH" ] = " 1"
42+ import openequivariance as oeq
43+ import jax
44+
45+ seed = 42
46+ key = jax.random.PRNGKey(seed)
47+
48+ node_ct, nonzero_ct = 3 , 4
49+ edge_index = jax.numpy.array(
50+ [
51+ [0 , 1 , 1 , 2 ],
52+ [1 , 0 , 2 , 1 ],
53+ ],
54+ dtype = jax.numpy.int32, # NOTE : This int32, not int64
55+ )
56+
57+ X = jax.random.uniform(key, shape = (node_ct, X_ir.dim), minval = 0.0 , maxval = 1.0 , dtype = jax.numpy.float32)
58+ Y = jax.random.uniform(key, shape = (nonzero_ct, Y_ir.dim),
59+ minval = 0.0 , maxval = 1.0 , dtype = jax.numpy.float32)
60+ W = jax.random.uniform(key, shape = (nonzero_ct, problem.weight_numel),
61+ minval = 0.0 , maxval = 1.0 , dtype = jax.numpy.float32)
62+
63+ # Reuse problem from earlier
64+ tp_conv = oeq.jax.TensorProductConv(problem, deterministic = False )
65+ Z = tp_conv.forward(
66+ X, Y, W, edge_index[0 ], edge_index[1 ]
67+ )
68+ print (jax.numpy.linalg.norm(Z))
69+ ```
4070
4171📣 📣 OpenEquivariance was accepted to the 2025 SIAM Conference on Applied and
4272Computational Discrete Algorithms (Proceedings Track)! Catch the talk in
0 commit comments