Skip to content

Commit bfa52a5

Browse files
committed
Skeleton of rule implemented.
1 parent c15f4f7 commit bfa52a5

1 file changed

Lines changed: 16 additions & 2 deletions

File tree

openequivariance/openequivariance/impl_jax/TensorProduct.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
22

33
import jax
4+
5+
from functools import partial
46
from openequivariance.impl_jax import extlib
57
import hashlib
68
from openequivariance.core.e3nn_lite import TPProblem, Irreps
@@ -15,13 +17,25 @@ def hash_attributes(attrs):
1517
hash = int(m.hexdigest()[:16], 16) >> 1
1618
attrs["hash"] = hash
1719

18-
20+
@partial(jax.custom_vjp, nondiff_argnums=(3,4,5))
1921
def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
2022
forward_call = jax.ffi.ffi_call("tp_forward",
2123
jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype))
2224
return forward_call(X, Y, W, **attrs)
2325

24-
#def backward()
26+
def forward_with_inputs(X, Y, W, L3_dim, irrep_dtype, attrs):
27+
return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W)
28+
29+
def backward(attrs, irrep_dtype, L3_dim, inputs, dZ):
30+
backward_call = jax.ffi.ffi_call("tp_backward",
31+
(
32+
jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
33+
jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
34+
jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
35+
))
36+
return backward_call(*inputs, dZ, **attrs)
37+
38+
forward.defvjp(forward_with_inputs, backward)
2539

2640
class TensorProduct(LoopUnrollTP):
2741
def __init__(self, config):

0 commit comments

Comments
 (0)