Skip to content

Commit 63ed1c0

Browse files
committed
Finished the double backward VJP registration.
1 parent ac9b3db commit 63ed1c0

1 file changed

Lines changed: 23 additions & 4 deletions

File tree

openequivariance/openequivariance/impl_jax/TensorProduct.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,38 @@ def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
2525
return forward_call(X, Y, W, **attrs)
2626

2727
def forward_with_inputs(X, Y, W, L3_dim, irrep_dtype, attrs):
28-
return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W)
28+
return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W)
2929

30-
def backward(L3_dim, irrep_dtype, attrs, inputs, dZ):
30+
@partial(jax.custom_vjp, nondiff_argnums=(4,5))
31+
def backward(X, Y, W, dZ, irrep_dtype, attrs):
3132
backward_call = jax.ffi.ffi_call("tp_backward",
33+
(
34+
jax.ShapeDtypeStruct(X.shape, irrep_dtype),
35+
jax.ShapeDtypeStruct(Y.shape, irrep_dtype),
36+
jax.ShapeDtypeStruct(W.shape, irrep_dtype),
37+
))
38+
39+
return backward_call(X, Y, W, dZ, **attrs)
40+
41+
def backward_with_inputs(X, Y, W, dZ, irrep_dtype, attrs):
42+
return backward(X, Y, W, dZ, irrep_dtype, attrs), (X, Y, W, dZ)
43+
44+
def double_backward(irrep_dtype, attrs, inputs, ddX, ddY, ddW):
45+
double_backward_call = jax.ffi.ffi_call("tp_double_backward",
3246
(
3347
jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
3448
jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
3549
jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
50+
jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype),
3651
))
3752

38-
return backward_call(*inputs, dZ, **attrs)
53+
return double_backward_call(*inputs, ddX, ddY, ddW, **attrs)
54+
55+
def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ):
56+
return backward(inputs[0], inputs[1], inputs[2], dZ, irrep_dtype, attrs)
3957

40-
forward.defvjp(forward_with_inputs, backward)
58+
forward.defvjp(forward_with_inputs, backward_autograd)
59+
backward.defvjp(backward_with_inputs, backward_autograd)
4160

4261
class TensorProduct(LoopUnrollTP):
4362
def __init__(self, config):

0 commit comments

Comments
 (0)