11import numpy as np
22
33import jax
4+
5+ from functools import partial
46from openequivariance .impl_jax import extlib
57import hashlib
68from openequivariance .core .e3nn_lite import TPProblem , Irreps
@@ -15,13 +17,25 @@ def hash_attributes(attrs):
1517 hash = int (m .hexdigest ()[:16 ], 16 ) >> 1
1618 attrs ["hash" ] = hash
1719
18-
20+ @ partial ( jax . custom_vjp , nondiff_argnums = ( 3 , 4 , 5 ))
1921def forward (X , Y , W , L3_dim , irrep_dtype , attrs ):
2022 forward_call = jax .ffi .ffi_call ("tp_forward" ,
2123 jax .ShapeDtypeStruct ((X .shape [0 ], L3_dim ), irrep_dtype ))
2224 return forward_call (X , Y , W , ** attrs )
2325
24- #def backward()
26+ def forward_with_inputs (X , Y , W , L3_dim , irrep_dtype , attrs ):
27+ return forward (X , Y , W , L3_dim , irrep_dtype , attrs ), (X , Y , W )
28+
29+ def backward (attrs , irrep_dtype , L3_dim , inputs , dZ ):
30+ backward_call = jax .ffi .ffi_call ("tp_backward" ,
31+ (
32+ jax .ShapeDtypeStruct (inputs [0 ].shape , irrep_dtype ),
33+ jax .ShapeDtypeStruct (inputs [1 ].shape , irrep_dtype ),
34+ jax .ShapeDtypeStruct (inputs [2 ].shape , irrep_dtype ),
35+ ))
36+ return backward_call (* inputs , dZ , ** attrs )
37+
38+ forward .defvjp (forward_with_inputs , backward )
2539
2640class TensorProduct (LoopUnrollTP ):
2741 def __init__ (self , config ):
0 commit comments