Skip to content

Commit d94db28

Browse files
committed
Registered the VJP rules for backward and double-backward.
1 parent ce68f69 commit d94db28

1 file changed

Lines changed: 28 additions & 2 deletions

File tree

openequivariance/openequivariance/impl_jax/TensorProductConv.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from openequivariance.benchmark.logging_utils import getLogger
1414
logger = getLogger()
1515

16-
#@partial(jax.custom_vjp, nondiff_argnums=(3,4,5,6,7,8,9))
16+
@partial(jax.custom_vjp, nondiff_argnums=(3,4,5,6,7,8,9))
1717
def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs):
1818
forward_call = jax.ffi.ffi_call("conv_forward",
1919
jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype))
@@ -22,6 +22,33 @@ def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, at
2222
def forward_with_inputs(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs):
    """VJP "fwd" rule for `forward`: run the primal computation and save residuals.

    Returns the primal output together with the residual tuple
    (X, Y, W, rows, cols, sender_perm, workspace) consumed by the
    backward rule. Signature mirrors `forward` exactly, as required
    by `jax.custom_vjp`.
    """
    primal_out = forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs)
    residuals = (X, Y, W, rows, cols, sender_perm, workspace)
    return primal_out, residuals
2424

25+
@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9))
def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs):
    """Gradient kernel for the convolution.

    Given the upstream cotangent dZ, invokes the "conv_backward" FFI
    target to produce (dX, dY, dW), each matching the shape of the
    corresponding primal input with dtype `irrep_dtype`.

    Arguments 4-9 (rows, cols, workspace, sender_perm, irrep_dtype,
    attrs) are marked non-differentiable; only X, Y, W, dZ carry
    gradients.
    """
    result_shapes = (
        jax.ShapeDtypeStruct(X.shape, irrep_dtype),
        jax.ShapeDtypeStruct(Y.shape, irrep_dtype),
        jax.ShapeDtypeStruct(W.shape, irrep_dtype),
    )
    call = jax.ffi.ffi_call("conv_backward", result_shapes)
    return call(X, Y, W, dZ, rows, cols, workspace, sender_perm, **attrs)
32+
33+
def backward_with_inputs(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs):
    """VJP "fwd" rule for `backward`: run the primal and save residuals.

    BUG FIX: the previous version took an extra `L3_dim` parameter and
    forwarded it to `backward`, which has no such parameter — an 11th
    positional argument to a 10-parameter function (TypeError), and a
    signature mismatch with what `jax.custom_vjp` passes to the fwd rule
    (it calls fwd with exactly the primal's arguments). The signature
    now mirrors `backward` exactly, just as `forward_with_inputs`
    mirrors `forward`.

    Returns the primal (dX, dY, dW) result together with the residual
    tuple consumed by `double_backward`.
    """
    primal_out = backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs)
    residuals = (X, Y, W, dZ, rows, cols, sender_perm, workspace)
    return primal_out, residuals
35+
36+
def double_backward(rows, cols, workspace, sender_perm, irrep_dtype, attrs, inputs, derivatives):
    """VJP "bwd" rule for `backward`: second-order gradients via FFI.

    `inputs` is the residual tuple saved by the fwd rule,
    (X, Y, W, dZ, rows, cols, sender_perm, workspace); `derivatives`
    holds the cotangents of `backward`'s three outputs. Emits four
    outputs shaped like inputs[0..3] (the cotangents for X, Y, W, dZ).

    NOTE(review): every residual — including rows, cols, sender_perm and
    workspace — is splatted into the FFI call, and rows/cols/workspace/
    sender_perm are then passed again explicitly, so they appear twice
    in the argument list. Confirm the "conv_double_backward" FFI target
    expects exactly this arity/order.
    """
    out_types = tuple(
        jax.ShapeDtypeStruct(inputs[i].shape, irrep_dtype) for i in range(4)
    )
    call = jax.ffi.ffi_call("conv_double_backward", out_types)
    return call(*inputs, *derivatives, rows, cols, workspace, sender_perm, **attrs)
45+
46+
def backward_autograd(rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs, inputs, dZ):
    """VJP "bwd" rule for `forward`: adapt `backward` to JAX's bwd signature.

    Receives `forward`'s non-differentiable arguments (argnums 3-9), the
    residual tuple saved by `forward_with_inputs`, and the output
    cotangent dZ; delegates to `backward`, whose FFI call yields the
    cotangents for the three differentiable inputs (X, Y, W). `L3_dim`
    is part of the nondiff signature but unused here.
    """
    X, Y, W = inputs[0], inputs[1], inputs[2]
    return backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs)
48+
49+
# Register the custom VJP rules: `forward` differentiates via `backward`,
# and `backward` itself differentiates via `double_backward`, enabling
# second-order (double-backward) gradients through the convolution.
forward.defvjp(forward_with_inputs, backward_autograd)
backward.defvjp(backward_with_inputs, double_backward)
51+
2552
class TensorProductConv(LoopUnrollConv):
2653
def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool = False):
2754
dp = extlib.DeviceProp(0)
@@ -50,7 +77,6 @@ def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool =
5077
logger.info(f"Convolution requires {self.workspace_size // (2 ** 20)}MB of workspace.")
5178
self.dummy_transpose_perm = jnp.zeros((1,), dtype=jnp.int32)
5279

53-
5480
def forward(
5581
self,
5682
X: jax.numpy.ndarray,

0 commit comments

Comments
 (0)