@@ -41,22 +41,21 @@ def backward(X, Y, W, dZ, irrep_dtype, attrs):
def backward_with_inputs(X, Y, W, dZ, irrep_dtype, attrs):
    """VJP *forward* rule for ``backward``: run the backward pass and save
    the primals needed by the double-backward rule.

    Returns an ``(output, residuals)`` pair, as required by
    ``jax.custom_vjp``; the residuals ``(X, Y, W, dZ)`` are later fed to
    ``double_backward``.
    """
    return backward(X, Y, W, dZ, irrep_dtype, attrs), (X, Y, W, dZ)
4343
def double_backward(irrep_dtype, attrs, inputs, derivatives):
    """VJP *backward* rule for ``backward`` (the double-backward pass).

    Parameters
    ----------
    irrep_dtype:
        dtype used for all four output buffers (non-differentiated arg).
    attrs:
        Static attributes forwarded as keyword arguments to the FFI call.
    inputs:
        Residuals saved by ``backward_with_inputs`` — the tuple
        ``(X, Y, W, dZ)``.
    derivatives:
        Cotangents for the outputs of ``backward`` (one per output;
        presumably ``(ddX, ddY, ddW)`` — confirm against the FFI kernel).

    Returns
    -------
    Four arrays, one per saved primal, each matching that primal's shape.
    """
    # One output buffer per residual, same shape as the corresponding input.
    double_backward_call = jax.ffi.ffi_call(
        "tp_double_backward",
        (
            jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
            jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
            jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
            jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype),
        ),
    )
    # Positional args: residuals first, then the incoming cotangents.
    return double_backward_call(*inputs, *derivatives, **attrs)
5453
def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ):
    """VJP *backward* rule for ``forward``: delegate to ``backward`` with
    the saved primals ``(X, Y, W)`` and the output cotangent ``dZ``.

    NOTE(review): ``L3_dim`` is accepted for signature compatibility but is
    unused in this body — presumably a non-differentiated argument declared
    on ``forward.defvjp``; confirm against the custom_vjp declaration.
    """
    return backward(inputs[0], inputs[1], inputs[2], dZ, irrep_dtype, attrs)
5756
# Register the custom VJP rules:
#  - forward:  forward_with_inputs saves primals; backward_autograd
#    produces the first-order gradients.
#  - backward: backward_with_inputs saves primals; double_backward
#    produces the gradients of the backward pass (second-order support).
forward.defvjp(forward_with_inputs, backward_autograd)
backward.defvjp(backward_with_inputs, double_backward)
6059
6160class TensorProduct (LoopUnrollTP ):
6261 def __init__ (self , config ):
@@ -89,17 +88,26 @@ def forward(self, X, Y, W):
tensor_product = TensorProduct(problem)
batch_size = 1

# Random inputs; fixed PRNG keys keep the demo reproducible.
X = jax.random.uniform(jax.random.PRNGKey(0), (batch_size, X_ir.dim), dtype=jax.numpy.float32)
Y = jax.random.uniform(jax.random.PRNGKey(1), (batch_size, Y_ir.dim), dtype=jax.numpy.float32)
W = jax.random.uniform(jax.random.PRNGKey(2), (batch_size, tensor_product.weight_numel), dtype=jax.numpy.float32)
Z = tensor_product.forward(X, Y, W)

# Test forward jax vjp
ctZ = jnp.ones_like(Z)
result = jax.vjp(lambda x, y, w: tensor_product.forward(x, y, w), X, Y, W)[1](ctZ)

print(result)
print("COMPLETED FORWARD PASS!")

# Test the double backward pass: differentiate the first-order VJP again.
# The inner jax.vjp(...)[1](ctZ) maps (x, y, w) -> (dX, dY, dW); the outer
# vjp pulls (ddX, ddY, ddW) back through that map.
ddX = jnp.ones_like(X)
ddY = jnp.ones_like(Y)
ddW = jnp.ones_like(W)
result_double_backward = jax.vjp(
    lambda x, y, w: jax.vjp(lambda a, b, c: tensor_product.forward(a, b, c), x, y, w)[1](ctZ),
    X, Y, W,
)[1]((ddX, ddY, ddW))

print(result_double_backward)
print("COMPLETED DOUBLE BACKWARD PASS!")
0 commit comments