from openequivariance.benchmark.logging_utils import getLogger
logger = getLogger()

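+# Forward pass through the "conv_forward" FFI kernel. Arguments from
+# `rows` onward (graph indices, scratch workspace, and static metadata)
+# are marked non-differentiable, so gradients flow only to X, Y, and W.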
-# @partial(jax.custom_vjp, nondiff_argnums=(3,4,5,6,7,8,9))
+@partial(jax.custom_vjp, nondiff_argnums=(3,4,5,6,7,8,9))
def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs):
    forward_call = jax.ffi.ffi_call("conv_forward",
        jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype))
@@ -22,6 +22,33 @@ def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, at
def forward_with_inputs(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs):
    return forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs), (X, Y, W, rows, cols, sender_perm, workspace)

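+# VJP rule for `forward`: given the output cotangent dZ, the
+# "conv_backward" kernel produces gradients for X, Y, and W. It gets its
+# own custom_vjp so the backward pass is itself differentiable.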
+@partial(jax.custom_vjp, nondiff_argnums=(4,5,6,7,8,9))
+def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs):
+    backward_call = jax.ffi.ffi_call("conv_backward",
+        (jax.ShapeDtypeStruct(X.shape, irrep_dtype),
+         jax.ShapeDtypeStruct(Y.shape, irrep_dtype),
+         jax.ShapeDtypeStruct(W.shape, irrep_dtype)))
+    return backward_call(X, Y, W, dZ, rows, cols, workspace, sender_perm, **attrs)
+
+def backward_with_inputs(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs):
+    return backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs), (X, Y, W, dZ, rows, cols, sender_perm, workspace)
+
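+# VJP rule for `backward`: consumes the saved residuals plus the
+# cotangents of (dX, dY, dW) and returns gradients for (X, Y, W, dZ),
+# which is what enables second-order differentiation through the
+# convolution.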
+def double_backward(rows, cols, workspace, sender_perm, irrep_dtype, attrs, inputs, derivatives):
+    double_backward_call = jax.ffi.ffi_call("conv_double_backward",
+        (
+            jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
+            jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
+            jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
+            jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype),
+        ))
+    return double_backward_call(*inputs, *derivatives, rows, cols, workspace, sender_perm, **attrs)
+
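+# Adapter with the bwd signature forward.defvjp expects: JAX passes the
+# non-differentiable arguments first, then the residuals and the output
+# cotangent dZ.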
+def backward_autograd(rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs, inputs, dZ):
+    return backward(inputs[0], inputs[1], inputs[2], dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs)
+
+forward.defvjp(forward_with_inputs, backward_autograd)
+backward.defvjp(backward_with_inputs, double_backward)
+
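+# Minimal usage sketch (hypothetical arrays and metadata, assuming the
+# FFI targets are registered): with both rules wired up, jax.grad
+# differentiates through the kernel, and grad-of-grad reaches
+# double_backward.
+#
+#   loss = lambda X, Y, W: jnp.sum(
+#       forward(X, Y, W, rows, cols, workspace, sender_perm,
+#               L3_dim, irrep_dtype, attrs))
+#   dX, dY, dW = jax.grad(loss, argnums=(0, 1, 2))(X, Y, W)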
class TensorProductConv(LoopUnrollConv):
    def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool = False):
        dp = extlib.DeviceProp(0)
@@ -50,7 +77,6 @@ def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool =
        logger.info(f"Convolution requires {self.workspace_size // (2 ** 20)} MB of workspace.")
        self.dummy_transpose_perm = jnp.zeros((1,), dtype=jnp.int32)

-
    def forward(
        self,
        X: jax.numpy.ndarray,