@@ -25,19 +25,38 @@ def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
2525 return forward_call (X , Y , W , ** attrs )
2626
def forward_with_inputs(X, Y, W, L3_dim, irrep_dtype, attrs):
    """custom_vjp fwd rule for `forward`: run the primal and stash residuals.

    Returns the primal output plus the residual tuple (X, Y, W) that the
    bwd rule will need to compute input cotangents.
    """
    primal_out = forward(X, Y, W, L3_dim, irrep_dtype, attrs)
    residuals = (X, Y, W)
    return primal_out, residuals
@partial(jax.custom_vjp, nondiff_argnums=(4, 5))
def backward(X, Y, W, dZ, irrep_dtype, attrs):
    """First-order backward pass via the "tp_backward" FFI target.

    Given the primal inputs (X, Y, W) and the output cotangent dZ, emits
    one cotangent per differentiable input — three arrays shaped like
    X, Y, and W respectively, all with dtype `irrep_dtype`.
    `irrep_dtype` and `attrs` are non-differentiable (nondiff_argnums)
    and `attrs` is forwarded to the FFI call as keyword attributes.
    """
    # One output spec per input cotangent, mirroring each input's shape.
    out_specs = tuple(
        jax.ShapeDtypeStruct(arr.shape, irrep_dtype) for arr in (X, Y, W)
    )
    backward_call = jax.ffi.ffi_call("tp_backward", out_specs)
    return backward_call(X, Y, W, dZ, **attrs)
def backward_with_inputs(X, Y, W, dZ, irrep_dtype, attrs):
    """custom_vjp fwd rule for `backward`: run it and save all four inputs.

    The residual tuple (X, Y, W, dZ) is exactly what the double-backward
    kernel consumes.
    """
    cotangents = backward(X, Y, W, dZ, irrep_dtype, attrs)
    residuals = (X, Y, W, dZ)
    return cotangents, residuals
def double_backward(irrep_dtype, attrs, inputs, ddX, ddY, ddW):
    """Second-order backward pass via the "tp_double_backward" FFI target.

    `inputs` is the residual tuple (X, Y, W, dZ) saved by
    `backward_with_inputs`; ddX/ddY/ddW are the incoming cotangents of
    `backward`'s three outputs. Emits four arrays shaped like the four
    residuals, i.e. the cotangents w.r.t. (X, Y, W, dZ), with dtype
    `irrep_dtype`. `attrs` is forwarded as FFI keyword attributes.
    """
    # Four output specs, one per residual (X, Y, W, dZ).
    out_specs = tuple(
        jax.ShapeDtypeStruct(arr.shape, irrep_dtype) for arr in inputs
    )
    double_backward_call = jax.ffi.ffi_call("tp_double_backward", out_specs)
    return double_backward_call(*inputs, ddX, ddY, ddW, **attrs)
def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ):
    """custom_vjp bwd rule for `forward`: unpack residuals, call `backward`.

    `L3_dim` is one of `forward`'s non-differentiable arguments; the
    backward kernel does not need it, so it is accepted and ignored here.
    """
    X, Y, W = inputs
    return backward(X, Y, W, dZ, irrep_dtype, attrs)
def _backward_vjp(irrep_dtype, attrs, inputs, cotangents):
    """custom_vjp bwd rule for `backward`.

    With nondiff_argnums=(4, 5), JAX calls the bwd rule as
    bwd(irrep_dtype, attrs, residuals, cotangent) where `cotangent` is a
    SINGLE object — here a 3-tuple (ddX, ddY, ddW), since `backward`
    returns three arrays. The previous registration passed
    `backward_autograd`, whose (L3_dim, irrep_dtype, attrs, inputs, dZ)
    signature does not match that calling convention and which computes
    the first-order backward (three cotangents) instead of cotangents
    for `backward`'s four differentiable args. Route to
    `double_backward`, whose FFI call emits four arrays shaped like the
    residuals (X, Y, W, dZ).
    """
    ddX, ddY, ddW = cotangents
    return double_backward(irrep_dtype, attrs, inputs, ddX, ddY, ddW)


forward.defvjp(forward_with_inputs, backward_autograd)
backward.defvjp(backward_with_inputs, _backward_vjp)
4261class TensorProduct (LoopUnrollTP ):
4362 def __init__ (self , config ):
0 commit comments