@@ -15,6 +15,14 @@ def hash_attributes(attrs):
1515 hash = int (m .hexdigest ()[:16 ], 16 ) >> 1
1616 attrs ["hash" ] = hash
1717
18+
def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
    """Run the registered "tp_forward" XLA FFI kernel on one batch.

    Lives at module level (rather than on TensorProduct) so it can be
    traced/jitted independently of the instance; the instance method
    delegates here.

    Args:
        X, Y: batched input arrays; only ``X.shape[0]`` (the batch size)
            is inspected in this wrapper.
        W: weight array, forwarded untouched to the kernel.
        L3_dim: flattened dimension of the output irreps.
        irrep_dtype: dtype of the output buffer.
        attrs: extra keyword attributes splatted into the FFI call
            (presumably includes the "hash" set by hash_attributes —
            confirm against callers).

    Returns:
        Array of shape ``(X.shape[0], L3_dim)`` with dtype ``irrep_dtype``,
        as produced by the "tp_forward" FFI target.
    """
    out_spec = jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)
    kernel = jax.ffi.ffi_call("tp_forward", out_spec)
    return kernel(X, Y, W, **attrs)


# def backward()
1826class TensorProduct (LoopUnrollTP ):
1927 def __init__ (self , config ):
2028 dp = extlib .DeviceProp (0 )
@@ -33,10 +41,7 @@ def __init__(self, config):
3341 self .L3_dim = self .config .irreps_out .dim
3442
3543 def forward (self , X , Y , W ):
36- forward_call = jax .ffi .ffi_call ("tp_forward" ,
37- jax .ShapeDtypeStruct ((X .shape [0 ], self .L3_dim ), self .config .irrep_dtype ))
38- return forward_call (X , Y , W , ** self .attrs )
39-
44+ return forward (X , Y , W , self .L3_dim , self .config .irrep_dtype , self .attrs )
4045
4146if __name__ == "__main__" :
4247 tp_problem = None
@@ -47,11 +52,7 @@ def forward(self, X, Y, W):
4752 shared_weights = False ,
4853 internal_weights = False )
4954 tensor_product = TensorProduct (problem )
50-
5155 batch_size = 1000
52- #X = torch.rand(batch_size, X_ir.dim, device='cuda', generator=gen)
53- #Y = torch.rand(batch_size, Y_ir.dim, device='cuda', generator=gen)
54- #W = torch.rand(batch_size, tp_e3nn.weight_numel, device='cuda', generator=gen)
5556
5657 # Convert the above to JAX Arrays
5758 X = jax .random .uniform (jax .random .PRNGKey (0 ), (batch_size , X_ir .dim ), dtype = jax .numpy .float32 )
0 commit comments