|
13 | 13 | from openequivariance.benchmark.logging_utils import getLogger |
14 | 14 | logger = getLogger() |
15 | 15 |
|
@partial(jax.custom_vjp, nondiff_argnums=(3,4,5,6,7,8,9))
def forward(X, Y, W, rows, cols, sender_perm, workspace, L3_dim, irrep_dtype, attrs):
    """Run the fused tensor-product convolution forward pass via the
    "conv_forward" FFI target.

    Parameters
    ----------
    X, Y, W : jax arrays
        Node features, edge features, and weights consumed by the kernel.
    rows, cols : jax arrays
        Edge endpoint indices of the graph.
    sender_perm : jax array
        Permutation applied on the sender side (see caller).
    workspace : jax array
        Scratch buffer handed to the FFI kernel.
    L3_dim : int
        Feature width of the output irreps; fixes the output's second axis.
    irrep_dtype : dtype
        Element type of the output array.
    attrs : mapping
        Extra keyword attributes forwarded verbatim to the FFI call.

    Returns
    -------
    jax array of shape (X.shape[0], L3_dim) and dtype irrep_dtype.

    NOTE(review): argnums 3-6 (rows, cols, sender_perm, workspace) look like
    array arguments declared in nondiff_argnums; JAX recommends only static,
    hashable values there — confirm this works under the transformations used.
    """
    # One output row per row of X, L3_dim features each.
    out_spec = jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)
    conv_forward = jax.ffi.ffi_call("conv_forward", out_spec)
    return conv_forward(X, Y, W, rows, cols, sender_perm, workspace, **attrs)
| 21 | + |
def forward_with_inputs(X, Y, W, rows, cols, sender_perm, workspace, L3_dim, irrep_dtype, attrs):
    """custom_vjp forward rule for ``forward``.

    Returns the primal output together with the residuals that the backward
    rule needs (the array operands; the static arguments L3_dim, irrep_dtype
    and attrs travel through nondiff_argnums instead).
    """
    primal_out = forward(X, Y, W, rows, cols, sender_perm, workspace,
                         L3_dim, irrep_dtype, attrs)
    residuals = (X, Y, W, rows, cols, sender_perm, workspace)
    return primal_out, residuals
| 24 | + |
16 | 25 | class TensorProductConv(LoopUnrollConv): |
17 | 26 | def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool = False): |
18 | 27 | dp = extlib.DeviceProp(0) |
@@ -50,7 +59,19 @@ def forward( |
50 | 59 | rows: jax.ndarray, |
51 | 60 | cols: jax.ndarray, |
52 | 61 | sender_perm: Optional[jax.ndarray] = None) -> jax.ndarray: |
53 | | - pass |
| 62 | + |
| 63 | + if self.deterministic: |
| 64 | + sender_perm = self.dummy_transpose_perm |
| 65 | + else: |
| 66 | + assert sender_perm is not None, "Must provide sender_perm for non-deterministic convolutions." |
| 67 | + |
| 68 | + return forward( |
| 69 | + X, Y, W, |
| 70 | + rows, cols, sender_perm, |
| 71 | + self.workspace, |
| 72 | + self.L3_dim, |
| 73 | + self.config.irrep_dtype, |
| 74 | + self.attrs) |
54 | 75 |
|
55 | 76 | if __name__=="__main__": |
56 | 77 | X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e") |
|
0 commit comments