Skip to content

Commit 19f284b

Browse files
committed
Ready to start JAX support.
1 parent 745e4e0 commit 19f284b

1 file changed

Lines changed: 10 additions & 0 deletions

File tree

openequivariance/openequivariance/impl_jax/TensorProductConv.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,14 @@ def forward(
8282
internal_weights=False)
8383

8484
conv = TensorProductConv(problem, deterministic=False, kahan=False)
85+
86+
node_ct, nonzero_ct = 3, 4
87+
X = jax.random.uniform(jax.random.PRNGKey(0), (node_ct, X_ir.dim), dtype=jax.numpy.float32)
88+
Y = jax.random.uniform(jax.random.PRNGKey(1), (nonzero_ct, Y_ir.dim), dtype=jax.numpy.float32)
89+
W = jax.random.uniform(jax.random.PRNGKey(2), (nonzero_ct, conv.weight_numel), dtype=jax.numpy.float32)
90+
rows = jnp.array([0, 1, 1, 2], dtype=jnp.int32)
91+
cols = jnp.array([1, 0, 2, 1], dtype=jnp.int32)
92+
Z = conv.forward(X, Y, W, rows, cols)
93+
print("Z:", Z)
94+
8595
print("COMPLETE!")

0 commit comments

Comments
 (0)