77import hashlib
88from openequivariance .core .e3nn_lite import TPProblem , Irreps
99from openequivariance .core .LoopUnrollTP import LoopUnrollTP
10+ import jax .numpy as jnp
1011
1112def hash_attributes (attrs ):
1213 m = hashlib .sha256 ()
@@ -26,13 +27,14 @@ def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
def forward_with_inputs(X, Y, W, L3_dim, irrep_dtype, attrs):
    """custom_vjp forward rule for `forward`.

    Runs the primal computation and saves (X, Y, W) as residuals so that
    `backward` can recompute the input gradients from them.

    Returns:
        (primal_output, residuals) — the tuple shape required by jax.custom_vjp.
    """
    residuals = (X, Y, W)
    primal = forward(X, Y, W, L3_dim, irrep_dtype, attrs)
    return primal, residuals
2829
def backward(L3_dim, irrep_dtype, attrs, inputs, dZ):
    """custom_vjp backward rule for `forward`.

    The first three parameters are the non-differentiable arguments of the
    primal (in the same positional order as in `forward`'s signature, as
    jax.custom_vjp requires); `inputs` is the (X, Y, W) residual tuple saved
    by the forward rule, and `dZ` is the cotangent of the primal output.

    Note: `L3_dim` is not read here — it is kept only so the signature lines
    up with the nondiff arguments handed over by jax.custom_vjp.

    Returns:
        A 3-tuple of gradients matching the shapes of X, Y and W.
    """
    # Each output gradient mirrors the shape of the corresponding saved input.
    grad_shapes = tuple(
        jax.ShapeDtypeStruct(inputs[i].shape, irrep_dtype) for i in range(3)
    )
    ffi_backward = jax.ffi.ffi_call("tp_backward", grad_shapes)
    # attrs presumably carries the kernel's static attributes — forwarded
    # verbatim as FFI keyword attributes.
    return ffi_backward(*inputs, dZ, **attrs)
3739
# Wire up the custom VJP: forward_with_inputs saves (X, Y, W) as residuals,
# backward consumes them (receiving the nondiff args L3_dim/irrep_dtype/attrs
# first, per jax.custom_vjp's calling convention).
forward.defvjp(forward_with_inputs, backward)
@@ -66,13 +68,19 @@ def forward(self, X, Y, W):
6668 shared_weights = False ,
6769 internal_weights = False )
6870 tensor_product = TensorProduct (problem )
69- batch_size = 1000
71+ batch_size = 1
7072
7173 # Convert the above to JAX Arrays
7274 X = jax .random .uniform (jax .random .PRNGKey (0 ), (batch_size , X_ir .dim ), dtype = jax .numpy .float32 )
7375 Y = jax .random .uniform (jax .random .PRNGKey (1 ), (batch_size , Y_ir .dim ), dtype = jax .numpy .float32 )
7476 W = jax .random .uniform (jax .random .PRNGKey (2 ), (batch_size , tensor_product .weight_numel ), dtype = jax .numpy .float32 )
7577
7678 Z = tensor_product .forward (X , Y , W )
79+
80+ # Test via jax vjp
81+
82+ ctZ = jnp .ones_like (Z )
83+ result = jax .vjp (lambda x , y , w : tensor_product .forward (x , y , w ), X , Y , W )[1 ](ctZ )
84+
85+ print (result )
7786 print ("COMPLETE!" )
78- print (Z )
0 commit comments