Skip to content

Commit 7f4ac06

Browse files
committed
Backward test is passing.
1 parent fa42654 commit 7f4ac06

1 file changed

Lines changed: 21 additions & 8 deletions

File tree

openequivariance/openequivariance/impl_jax/TensorProduct.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,16 +83,29 @@ def __call__(
8383
) -> jax.numpy.ndarray:
8484
return self.forward(X, Y, W)
8585

86-
def forward_cpu(
87-
self,
88-
L1_in: np.ndarray,
89-
L2_in: np.ndarray,
90-
L3_out: np.ndarray,
91-
weights: np.ndarray,
92-
) -> None:
86+
def forward_cpu(self, L1_in, L2_in, L3_out, weights) -> None:
9387
result = self.forward(
9488
jax.numpy.asarray(L1_in),
9589
jax.numpy.asarray(L2_in),
9690
jax.numpy.asarray(weights),
9791
)
98-
L3_out[:] = np.asarray(result)
92+
L3_out[:] = np.asarray(result)
93+
94+
def backward_cpu(
95+
self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad
96+
) -> None:
97+
backward_fn = jax.vjp(
98+
lambda X, Y, W: self.forward(X, Y, W),
99+
jax.numpy.asarray(L1_in),
100+
jax.numpy.asarray(L2_in),
101+
jax.numpy.asarray(weights),
102+
)[1]
103+
L1_grad_jax, L2_grad_jax, weights_grad_jax = backward_fn(
104+
jax.numpy.asarray(L3_grad)
105+
)
106+
L1_grad[:] = np.asarray(L1_grad_jax)
107+
L2_grad[:] = np.asarray(L2_grad_jax)
108+
weights_grad[:] = np.asarray(weights_grad_jax)
109+
110+
111+

0 commit comments

Comments
 (0)