Commit 79522a1

Wrote double backward function for JAX.
1 parent 4dc31dc commit 79522a1

2 files changed: 22 additions & 4 deletions


openequivariance/openequivariance/benchmark/correctness_utils.py

Lines changed: 4 additions & 3 deletions
@@ -207,16 +207,17 @@ def correctness_double_backward(
     tensors = []
     for _, impl in enumerate([test_implementation, reference_implementation]):
         tp = instantiate_implementation(impl, problem)
+        weights_reordered = tp.reorder_weights_from_e3nn(weights, has_batch_dim=not problem.shared_weights)

         if impl == CUETensorProduct and problem.shared_weights:
-            weights = weights[np.newaxis, :]
+            weights_reordered = weights_reordered[np.newaxis, :]

-        in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad)
+        in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(in1, in2, out_grad, weights_reordered, weights_dgrad, in1_dgrad, in2_dgrad)
         tensors.append(
             ( out_dgrad,
               in1_grad,
               in2_grad,
-              weights_grad
+              tp.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not problem.shared_weights)
             ))

     for name, to_check, ground_truth in [
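
The round trip through the reorder helpers is what keeps the comparison fair: each implementation receives weights in its native layout, but the weight gradients it produces live in that same native layout and must be mapped back before they are checked against the e3nn-layout reference. A self-contained sketch of that invariant in JAX, with a fixed permutation standing in for reorder_weights_from_e3nn / reorder_weights_to_e3nn (the permutation and loss_native below are hypothetical, for illustration only):

import jax
import jax.numpy as jnp

# Hypothetical stand-ins: a fixed permutation plays the role of
# reorder_weights_from_e3nn, and its inverse plays reorder_weights_to_e3nn.
perm = jnp.array([2, 0, 1])
inv_perm = jnp.argsort(perm)

def loss_native(w_native):
    # Some scalar loss computed in the kernel's native weight layout.
    return jnp.sum(w_native ** 2 * jnp.array([1.0, 2.0, 3.0]))

w_e3nn = jnp.array([0.5, -1.0, 2.0])

# A gradient taken in the native layout ...
g_native = jax.grad(loss_native)(w_e3nn[perm])
# ... must go back through the inverse reordering before comparison.
g_e3nn = g_native[inv_perm]

assert jnp.allclose(g_e3nn, jax.grad(lambda w: loss_native(w[perm]))(w_e3nn))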

openequivariance/openequivariance/impl_jax/TensorProduct.py

Lines changed: 18 additions & 1 deletion
@@ -114,4 +114,21 @@ def backward_cpu(
         L1_grad[:] = np.asarray(L1_grad_jax)
         L2_grad[:] = np.asarray(L2_grad_jax)
         weights_grad[:] = np.asarray(weights_grad_jax)
-        weights_grad[:] = self.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights)
+        weights_grad[:] = self.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights)
+
+
+    def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad):
+        in1_jax = jax.numpy.asarray(in1)
+        in2_jax = jax.numpy.asarray(in2)
+        weights_jax = jax.numpy.asarray(weights)
+        out_grad_jax = jax.numpy.asarray(out_grad)
+        in1_dgrad_jax = jax.numpy.asarray(in1_dgrad)
+        in2_dgrad_jax = jax.numpy.asarray(in2_dgrad)
+        weights_dgrad_jax = jax.numpy.asarray(weights_dgrad)
+
+        in1_grad, in2_grad, weights_grad, out_dgrad = jax.vjp(
+            lambda x, y, w, g: jax.vjp(lambda a, b, c: self.forward(a, b, c), x, y, w)[1](g),
+            in1_jax, in2_jax, weights_jax, out_grad_jax
+        )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax))
+
+        return in1_grad, in2_grad, weights_grad, out_dgrad
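
The nested jax.vjp above is the standard double-backward pattern: the inner vjp builds the backward pass of forward, and the outer vjp differentiates that backward pass with respect to all four of its inputs (in1, in2, weights, out_grad), so four cotangents come back. A self-contained toy version of the same pattern, with a hypothetical elementwise bilinear forward standing in for the real tensor-product kernel:

import jax
import jax.numpy as jnp

def forward(x, y, w):
    # Hypothetical stand-in for the tensor-product kernel.
    return w * x * y

def backward(x, y, w, g):
    # First backward pass: VJP of forward, contracted with cotangent g.
    return jax.vjp(forward, x, y, w)[1](g)

x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])
w = jnp.array([0.5])
out_grad = jnp.array([1.0, 1.0])   # cotangent on the forward output
x_dgrad = jnp.array([0.1, 0.2])    # upstream grads on the three
y_dgrad = jnp.array([0.3, 0.4])    # outputs of the backward pass
w_dgrad = jnp.array([0.7])

# Double backward: VJP of the backward pass itself, one cotangent per
# input, including out_grad.
x_grad, y_grad, w_grad, out_dgrad = jax.vjp(
    backward, x, y, w, out_grad
)[1]((x_dgrad, y_dgrad, w_dgrad))

# Analytic check: backward returns (w*y*g, w*x*g, sum(x*y*g)), so
# out_dgrad = w*y*x_dgrad + w*x*y_dgrad + x*y*w_dgrad.
assert jnp.allclose(out_dgrad, w * y * x_dgrad + w * x * y_dgrad + x * y * w_dgrad)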
