Skip to content

Commit 673b5ee

Browse files
committed
Double backward pass seems to work.
1 parent 63ed1c0 commit 673b5ee

1 file changed

Lines changed: 17 additions & 9 deletions

File tree

openequivariance/openequivariance/impl_jax/TensorProduct.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,22 +41,21 @@ def backward(X, Y, W, dZ, irrep_dtype, attrs):
4141
def backward_with_inputs(X, Y, W, dZ, irrep_dtype, attrs):
4242
return backward(X, Y, W, dZ, irrep_dtype, attrs), (X, Y, W, dZ)
4343

44-
def double_backward(irrep_dtype, attrs, inputs, ddX, ddY, ddW):
44+
def double_backward(irrep_dtype, attrs, inputs, derivatives):
4545
double_backward_call = jax.ffi.ffi_call("tp_double_backward",
4646
(
4747
jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
4848
jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
4949
jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
5050
jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype),
5151
))
52-
53-
return double_backward_call(*inputs, ddX, ddY, ddW, **attrs)
52+
return double_backward_call(*inputs, *derivatives, **attrs)
5453

5554
def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ):
5655
return backward(inputs[0], inputs[1], inputs[2], dZ, irrep_dtype, attrs)
5756

5857
forward.defvjp(forward_with_inputs, backward_autograd)
59-
backward.defvjp(backward_with_inputs, backward_autograd)
58+
backward.defvjp(backward_with_inputs, double_backward)
6059

6160
class TensorProduct(LoopUnrollTP):
6261
def __init__(self, config):
@@ -89,17 +88,26 @@ def forward(self, X, Y, W):
8988
tensor_product = TensorProduct(problem)
9089
batch_size = 1
9190

92-
# Convert the above to JAX Arrays
9391
X = jax.random.uniform(jax.random.PRNGKey(0), (batch_size, X_ir.dim), dtype=jax.numpy.float32)
9492
Y = jax.random.uniform(jax.random.PRNGKey(1), (batch_size, Y_ir.dim), dtype=jax.numpy.float32)
9593
W = jax.random.uniform(jax.random.PRNGKey(2), (batch_size, tensor_product.weight_numel), dtype=jax.numpy.float32)
96-
9794
Z = tensor_product.forward(X, Y, W)
9895

99-
# Test via jax vjp
100-
96+
# Test forward jax vjp
10197
ctZ = jnp.ones_like(Z)
10298
result = jax.vjp(lambda x, y, w: tensor_product.forward(x, y, w), X, Y, W)[1](ctZ)
10399

104100
print(result)
105-
print("COMPLETE!")
101+
print("COMPLETED FORWARD PASS!")
102+
103+
# Test the double backward pass
104+
ddX = jnp.ones_like(X)
105+
ddY = jnp.ones_like(Y)
106+
ddW = jnp.ones_like(W)
107+
result_double_backward = jax.vjp(
108+
lambda x, y, w: jax.vjp(lambda a, b, c: tensor_product.forward(a, b, c), x, y, w)[1](ctZ),
109+
X, Y, W
110+
)[1]((ddX, ddY, ddW))
111+
112+
print(result_double_backward)
113+
print("COMPLETED DOUBLE BACKWARD PASS!")

0 commit comments

Comments
 (0)