Skip to content

Commit c15f4f7

Browse files
committed
Encapsulated the forward call.
1 parent 6790bd0 commit c15f4f7

2 files changed

Lines changed: 10 additions & 9 deletions

File tree

openequivariance/openequivariance/impl_jax/TensorProduct.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ def hash_attributes(attrs):
1515
hash = int(m.hexdigest()[:16], 16) >> 1
1616
attrs["hash"] = hash
1717

18+
19+
def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
    """Run the externally registered "tp_forward" XLA FFI kernel.

    Parameters
    ----------
    X, Y : jax arrays
        Input irrep batches; ``X.shape[0]`` is taken as the batch size
        (presumably shared with ``Y`` and ``W`` — confirm against caller).
    W : jax array
        Tensor-product weights.
    L3_dim : int
        Dimension of the output irreps; fixes the trailing axis of the result.
    irrep_dtype : dtype
        Element type of the output array.
    attrs : dict
        Static attributes forwarded to the FFI call as keyword arguments
        (includes the config hash set by ``hash_attributes``).

    Returns
    -------
    jax array of shape ``(X.shape[0], L3_dim)`` and dtype ``irrep_dtype``
    computed by the "tp_forward" FFI target registered by the C++ extension.
    """
    forward_call = jax.ffi.ffi_call(
        "tp_forward",
        jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype),
    )
    return forward_call(X, Y, W, **attrs)


# TODO: add a matching backward() wrapper for the backward FFI target
# (see tp_backward_impl in libjax_tp_jit.cpp).
1826
class TensorProduct(LoopUnrollTP):
1927
def __init__(self, config):
2028
dp = extlib.DeviceProp(0)
@@ -33,10 +41,7 @@ def __init__(self, config):
3341
self.L3_dim = self.config.irreps_out.dim
3442

3543
def forward(self, X, Y, W):
    """Compute the tensor product of X and Y with weights W.

    Delegates to the module-level ``forward`` helper, supplying this
    instance's output dimension, irrep dtype, and static FFI attributes.
    """
    out_dim = self.L3_dim
    out_dtype = self.config.irrep_dtype
    return forward(X, Y, W, out_dim, out_dtype, self.attrs)
4045

4146
if __name__ == "__main__":
4247
tp_problem = None
@@ -47,11 +52,7 @@ def forward(self, X, Y, W):
4752
shared_weights=False,
4853
internal_weights=False)
4954
tensor_product = TensorProduct(problem)
50-
5155
batch_size = 1000
52-
#X = torch.rand(batch_size, X_ir.dim, device='cuda', generator=gen)
53-
#Y = torch.rand(batch_size, Y_ir.dim, device='cuda', generator=gen)
54-
#W = torch.rand(batch_size, tp_e3nn.weight_numel, device='cuda', generator=gen)
5556

5657
# Convert the above to JAX Arrays
5758
X = jax.random.uniform(jax.random.PRNGKey(0), (batch_size, X_ir.dim), dtype=jax.numpy.float32)

openequivariance_extjax/src/libjax_tp_jit.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ ffi::Error tp_backward_impl(
247247
// Need to zero out W_grad
248248
}
249249

250-
jit_kernel->exec_tensor_product_backward(
250+
jit_kernel->backward(
251251
num_batch,
252252
data_ptr(L1_in),
253253
data_ptr(L1_grad),

0 commit comments

Comments
 (0)