import jax
from functools import partial

from openequivariance.impl_jax import extlib
from openequivariance.core.e3nn_lite import TPProblem
from openequivariance.core.LoopUnrollTP import LoopUnrollTP
from openequivariance.core.utils import hash_attributes
@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
    """Run the tensor-product forward kernel through the JAX FFI.

    X, Y are the irrep inputs and W the weights; L3_dim, irrep_dtype, and
    attrs are non-differentiable arguments that parameterize the FFI call.
    Returns an array of shape (X.shape[0], L3_dim) with dtype irrep_dtype.
    """
    out_spec = jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)
    call = jax.ffi.ffi_call("tp_forward", out_spec)
    return call(X, Y, W, **attrs)
1815
16+
def forward_with_inputs(X, Y, W, L3_dim, irrep_dtype, attrs):
    """custom_vjp fwd rule: evaluate forward and stash (X, Y, W) as residuals."""
    primal_out = forward(X, Y, W, L3_dim, irrep_dtype, attrs)
    residuals = (X, Y, W)
    return primal_out, residuals
2119
@partial(jax.custom_vjp, nondiff_argnums=(4, 5))
def backward(X, Y, W, dZ, irrep_dtype, attrs):
    """Run the tensor-product backward kernel through the JAX FFI.

    Given the forward inputs and the output cotangent dZ, returns the three
    gradients (dX, dY, dW), each shaped like its corresponding input.
    """
    grad_specs = tuple(
        jax.ShapeDtypeStruct(operand.shape, irrep_dtype) for operand in (X, Y, W)
    )
    call = jax.ffi.ffi_call("tp_backward", grad_specs)
    return call(X, Y, W, dZ, **attrs)
3233
34+
def backward_with_inputs(X, Y, W, dZ, irrep_dtype, attrs):
    """custom_vjp fwd rule for backward: stash (X, Y, W, dZ) as residuals."""
    grads = backward(X, Y, W, dZ, irrep_dtype, attrs)
    residuals = (X, Y, W, dZ)
    return grads, residuals
3537
38+
def double_backward(irrep_dtype, attrs, inputs, derivatives):
    """custom_vjp bwd rule for backward: the double-backward FFI kernel.

    inputs is the residual tuple (X, Y, W, dZ) saved by backward_with_inputs;
    derivatives holds the cotangents of backward's three outputs. Returns one
    cotangent per residual, each shaped like inputs[i].
    """
    out_specs = tuple(
        jax.ShapeDtypeStruct(inputs[i].shape, irrep_dtype) for i in range(4)
    )
    call = jax.ffi.ffi_call("tp_double_backward", out_specs)
    return call(*inputs, *derivatives, **attrs)
4550
51+
def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ):
    """custom_vjp bwd rule for forward: delegate to the backward kernel.

    L3_dim is unused here but must be accepted: custom_vjp passes every
    nondiff_argnum (3, 4, 5) ahead of the residuals and the cotangent.
    """
    X, Y, W = inputs
    return backward(X, Y, W, dZ, irrep_dtype, attrs)
54+
4855
# Wire the custom VJPs: differentiating forward dispatches the backward
# kernel, and differentiating backward dispatches double_backward.
forward.defvjp(forward_with_inputs, backward_autograd)
backward.defvjp(backward_with_inputs, double_backward)
5158
59+
5260class TensorProduct (LoopUnrollTP ):
5361 def __init__ (self , config : TPProblem ):
5462 dp = extlib .DeviceProp (0 )
@@ -59,18 +67,17 @@ def __init__(self, config: TPProblem):
5967 "forward_config" : vars (self .forward_schedule .launch_config ),
6068 "backward_config" : vars (self .backward_schedule .launch_config ),
6169 "double_backward_config" : vars (self .double_backward_schedule .launch_config ),
62- "kernel_prop" : self .kernelProp
70+ "kernel_prop" : self .kernelProp ,
6371 }
6472 hash_attributes (self .attrs )
65-
73+
6674 self .weight_numel = config .weight_numel
6775 self .L3_dim = self .config .irreps_out .dim
6876
6977 def forward (self , X : jax .ndarray , Y : jax .ndarray , W : jax .ndarray ) -> jax .ndarray :
7078 return forward (X , Y , W , self .L3_dim , self .config .irrep_dtype , self .attrs )
7179
72- def __call__ (self ,
73- X : jax .numpy .ndarray ,
74- Y : jax .numpy .ndarray ,
75- W : jax .numpy .ndarray ) -> jax .numpy .ndarray :
76- return self .forward (X , Y , W )
80+ def __call__ (
81+ self , X : jax .numpy .ndarray , Y : jax .numpy .ndarray , W : jax .numpy .ndarray
82+ ) -> jax .numpy .ndarray :
83+ return self .forward (X , Y , W )