  import numpy as np
  from functools import partial
+ from typing import Optional
  from openequivariance.impl_jax import extlib

  from openequivariance.core.e3nn_lite import TPProblem, Irreps

  logger = getLogger()
  class TensorProductConv(LoopUnrollConv):
-     def __init__(self, config, deterministic=False, kahan=False):
+     def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool = False):
          dp = extlib.DeviceProp(0)
          super().__init__(
-             self,
-             config,
+             config,
              dp, extlib.postprocess_kernel,
-             idx_dtype=np.int64,
+             idx_dtype=np.int32,  # Note: this is distinct from PyTorch
              torch_op=False,
              deterministic=deterministic,
              kahan=kahan
@@ -30,7 +30,7 @@ def __init__(self, config, deterministic=False, kahan=False):
              "forward_config": vars(self.forward_schedule.launch_config),
              "backward_config": vars(self.backward_schedule.launch_config),
              "double_backward_config": vars(self.double_backward_schedule.launch_config),
-             "kernel_prop": self.kernelProp
+             "kernel_prop": self.kernel_prop
          }
          hash_attributes(self.attrs)

@@ -39,8 +39,18 @@ def __init__(self, config, deterministic=False, kahan=False):

          self.workspace = jnp.zeros((self.workspace_size,), dtype=jnp.uint8)
          logger.info(f"Convolution requires {self.workspace_size // (2 ** 20)} MB of workspace.")
-         self.dummy_transpose_perm = jnp.zeros((1,), dtype=jnp.int64)
+         self.dummy_transpose_perm = jnp.zeros((1,), dtype=jnp.int32)
+

45+ def forward (
46+ self ,
47+ X : jax .ndarray ,
48+ Y : jax .ndarray ,
49+ W : jax .ndarray ,
50+ rows : jax .ndarray ,
51+ cols : jax .ndarray ,
52+ sender_perm : Optional [jax .ndarray ] = None ) -> jax .ndarray :
53+ pass

  if __name__ == "__main__":
      X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e")
0 commit comments