Skip to content

Commit 745e4e0

Browse files
committed
Finished the forward call.
1 parent 4ea49dc commit 745e4e0

1 file changed

Lines changed: 22 additions & 1 deletion

File tree

openequivariance/openequivariance/impl_jax/TensorProductConv.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@
1313
from openequivariance.benchmark.logging_utils import getLogger
1414
logger = getLogger()
1515

16+
@partial(jax.custom_vjp, nondiff_argnums=(3,4,5,6,7,8,9))
17+
def forward(X, Y, W, rows, cols, sender_perm, workspace, L3_dim, irrep_dtype, attrs):
18+
forward_call = jax.ffi.ffi_call("conv_forward",
19+
jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype))
20+
return forward_call(X, Y, W, rows, cols, sender_perm, workspace, **attrs)
21+
22+
def forward_with_inputs(X, Y, W, rows, cols, sender_perm, workspace, L3_dim, irrep_dtype, attrs):
23+
return forward(X, Y, W, rows, cols, sender_perm, workspace, L3_dim, irrep_dtype, attrs), (X, Y, W, rows, cols, sender_perm, workspace)
24+
1625
class TensorProductConv(LoopUnrollConv):
1726
def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool = False):
1827
dp = extlib.DeviceProp(0)
@@ -50,7 +59,19 @@ def forward(
5059
rows: jax.ndarray,
5160
cols: jax.ndarray,
5261
sender_perm: Optional[jax.ndarray] = None) -> jax.ndarray:
53-
pass
62+
63+
if self.deterministic:
64+
sender_perm = self.dummy_transpose_perm
65+
else:
66+
assert sender_perm is not None, "Must provide sender_perm for non-deterministic convolutions."
67+
68+
return forward(
69+
X, Y, W,
70+
rows, cols, sender_perm,
71+
self.workspace,
72+
self.L3_dim,
73+
self.config.irrep_dtype,
74+
self.attrs)
5475

5576
if __name__=="__main__":
5677
X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e")

0 commit comments

Comments
 (0)