Skip to content

Commit d1131fa

Browse files
committed
Backward call is working.
1 parent bfa52a5 commit d1131fa

1 file changed

Lines changed: 11 additions & 3 deletions

File tree

openequivariance/openequivariance/impl_jax/TensorProduct.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import hashlib
88
from openequivariance.core.e3nn_lite import TPProblem, Irreps
99
from openequivariance.core.LoopUnrollTP import LoopUnrollTP
10+
import jax.numpy as jnp
1011

1112
def hash_attributes(attrs):
1213
m = hashlib.sha256()
@@ -26,13 +27,14 @@ def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
2627
def forward_with_inputs(X, Y, W, L3_dim, irrep_dtype, attrs):
2728
return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W)
2829

29-
def backward(attrs, irrep_dtype, L3_dim, inputs, dZ):
30+
def backward(L3_dim, irrep_dtype, attrs, inputs, dZ):
3031
backward_call = jax.ffi.ffi_call("tp_backward",
3132
(
3233
jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
3334
jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
3435
jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
3536
))
37+
3638
return backward_call(*inputs, dZ, **attrs)
3739

3840
forward.defvjp(forward_with_inputs, backward)
@@ -66,13 +68,19 @@ def forward(self, X, Y, W):
6668
shared_weights=False,
6769
internal_weights=False)
6870
tensor_product = TensorProduct(problem)
69-
batch_size = 1000
71+
batch_size = 1
7072

7173
# Convert the above to JAX Arrays
7274
X = jax.random.uniform(jax.random.PRNGKey(0), (batch_size, X_ir.dim), dtype=jax.numpy.float32)
7375
Y = jax.random.uniform(jax.random.PRNGKey(1), (batch_size, Y_ir.dim), dtype=jax.numpy.float32)
7476
W = jax.random.uniform(jax.random.PRNGKey(2), (batch_size, tensor_product.weight_numel), dtype=jax.numpy.float32)
7577

7678
Z = tensor_product.forward(X, Y, W)
79+
80+
# Test via jax vjp
81+
82+
ctZ = jnp.ones_like(Z)
83+
result = jax.vjp(lambda x, y, w: tensor_product.forward(x, y, w), X, Y, W)[1](ctZ)
84+
85+
print(result)
7786
print("COMPLETE!")
78-
print(Z)

0 commit comments

Comments
 (0)