Skip to content

Commit 71ca862

Browse files
committed
All double backward tests passing.
1 parent 79522a1 commit 71ca862

2 files changed

Lines changed: 4 additions & 3 deletions

File tree

openequivariance/openequivariance/benchmark/correctness_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,11 +208,12 @@ def correctness_double_backward(
208208
for _, impl in enumerate([test_implementation, reference_implementation]):
209209
tp = instantiate_implementation(impl, problem)
210210
weights_reordered = tp.reorder_weights_from_e3nn(weights, has_batch_dim=not problem.shared_weights)
211+
weights_dgrad_reordered = tp.reorder_weights_from_e3nn(weights_dgrad, has_batch_dim=not problem.shared_weights)
211212

212213
if impl == CUETensorProduct and problem.shared_weights:
213214
weights_reordered = weights_reordered[np.newaxis, :]
214215

215-
in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(in1, in2, out_grad, weights_reordered, weights_dgrad, in1_dgrad, in2_dgrad)
216+
in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(in1, in2, out_grad, weights_reordered, weights_dgrad_reordered, in1_dgrad, in2_dgrad)
216217
tensors.append(
217218
( out_dgrad,
218219
in1_grad,

openequivariance/openequivariance/impl_jax/TensorProduct.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dg
127127
weights_dgrad_jax = jax.numpy.asarray(weights_dgrad)
128128

129129
in1_grad, in2_grad, weights_grad, out_dgrad = jax.vjp(
130-
lambda x, y, w: jax.vjp(lambda a, b, c: self.forward(a, b, c), x, y, w)[1](out_grad_jax),
131-
in1_jax, in2_jax, weights_jax
130+
lambda x, y, w, o: jax.vjp(lambda a, b, c: self.forward(a, b, c), x, y, w)[1](o),
131+
in1_jax, in2_jax, weights_jax, out_grad_jax
132132
)[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax))
133133

134134
return in1_grad, in2_grad, weights_grad, out_dgrad

0 commit comments

Comments
 (0)