Commit 8a0094a

Added double backward CPU function to jax TP conv.
1 parent e140c07 · commit 8a0094a

3 files changed: 77 additions & 57 deletions

openequivariance/openequivariance/benchmark/random_buffer_utils.py

Lines changed: 41 additions & 0 deletions
@@ -181,3 +181,44 @@ def get_random_buffers_backward_conv(
     return in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad
 
 
+def get_random_buffers_double_backward_conv(
+    tpp: TPProblem, node_count: int, edge_count: int, prng_seed: int
+):
+    rng = np.random.default_rng(prng_seed)
+    in1 = np.array(
+        rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype
+    )
+    in2 = np.array(
+        rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype
+    )
+    out_grad = np.array(
+        rng.uniform(size=(node_count, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype
+    )
+
+    weights_size = (
+        tuple([tpp.weight_numel])
+        if tpp.shared_weights
+        else tuple([edge_count, tpp.weight_numel])
+    )
+
+    weights = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype)
+    weights_grad = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype)
+    in1_grad = np.array(
+        rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype
+    )
+    in2_grad = np.array(
+        rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype
+    )
+    out_double_grad = np.array(
+        rng.uniform(size=(node_count, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype
+    )
+    return (
+        in1,
+        in2,
+        out_grad,
+        weights,
+        weights_grad,
+        in1_grad,
+        in2_grad,
+        out_double_grad,
+    )
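
The helper returns eight buffers whose positional order the test harness below relies on. A minimal usage sketch, not part of this commit, assuming an existing `TPProblem` instance `tpp` plus `node_count` and `edge_count` for the graph:

```python
# Usage sketch (illustrative, not from the commit): unpack the eight
# buffers in the order the test harness expects. Positions 5-7 are
# consumed as second-order cotangents; the last buffer is discarded.
from openequivariance.benchmark.random_buffer_utils import (
    get_random_buffers_double_backward_conv,
)

(
    in1,              # node features, shape (node_count, irreps_in1.dim)
    in2,              # edge features, shape (edge_count, irreps_in2.dim)
    out_grad,         # upstream gradient of the forward output
    weights,          # TP weights, shared or per-edge
    weights_dgrad,    # cotangent fed in for the weight gradient
    in1_dgrad,        # cotangent fed in for the in1 gradient
    in2_dgrad,        # cotangent fed in for the in2 gradient
    out_double_grad,  # output-shaped buffer, unused by the harness
) = get_random_buffers_double_backward_conv(tpp, node_count, edge_count, prng_seed=0)
```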

openequivariance/openequivariance/core/ConvolutionBase.py

Lines changed: 15 additions & 56 deletions
@@ -3,6 +3,7 @@
 from openequivariance.benchmark.random_buffer_utils import (
     get_random_buffers_forward_conv,
     get_random_buffers_backward_conv,
+    get_random_buffers_double_backward_conv,
 )
 
 from openequivariance.benchmark.logging_utils import getLogger, bcolors
@@ -13,7 +14,6 @@
 
 logger = getLogger()
 
-
 def flops_data_per_tp(config, direction):
     """
     Assumes all interactions are "uvu" for now
@@ -549,20 +549,12 @@ def test_correctness_double_backward(
         reference_implementation=None,
         high_precision_ref=False,
     ):
-        global torch
-        import torch
-
-        assert self.torch_op
-        buffers = get_random_buffers_backward_conv(
-            self.config, graph.node_count, graph.nnz, prng_seed
-        )
-
-        rng = np.random.default_rng(seed=prng_seed * 2)
-        dummy_grad_value = rng.standard_normal(1)[0]
+        buffers = get_random_buffers_double_backward_conv(
+            self.config, graph.node_count, graph.nnz, prng_seed
+        )
 
         if reference_implementation is None:
             from openequivariance.impl_torch.E3NNConv import E3NNConv
-
             reference_implementation = E3NNConv
 
         reference_problem = self.config
@@ -576,63 +568,30 @@ def test_correctness_double_backward(
         result = {"thresh": thresh}
         tensors = []
         for i, tp in enumerate([self, reference_tp]):
-            in1, in2, out_grad, weights, _, _, _ = [buf.copy() for buf in buffers]
+            buffers_copy = [buf.copy() for buf in buffers]
 
             if i == 1 and high_precision_ref:
-                in1, in2, out_grad, weights, _, _, _ = [
+                buffers_copy = [
                     np.array(el, dtype=np.float64) for el in buffers
                 ]
 
-            in1_torch = torch.tensor(in1, device="cuda", requires_grad=True)
-            in2_torch = torch.tensor(in2, device="cuda", requires_grad=True)
+            in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = buffers_copy
 
             weights_reordered = tp.reorder_weights_from_e3nn(
                 weights, not self.config.shared_weights
             )
-
-            weights_torch = torch.tensor(
-                weights_reordered, device="cuda", requires_grad=True
+            weights_dgrad_reordered = tp.reorder_weights_from_e3nn(
+                weights_dgrad, not self.config.shared_weights
             )
 
-            torch_rows = torch.tensor(graph.rows, device="cuda")
-            torch_cols = torch.tensor(graph.cols, device="cuda")
-            torch_transpose_perm = torch.tensor(graph.transpose_perm, device="cuda")
-
-            fwd_args = [in1_torch, in2_torch, weights_torch, torch_rows, torch_cols]
-            if tp.deterministic:
-                fwd_args.append(torch_transpose_perm)
-
-            out_torch = tp.forward(*fwd_args)
-            out_grad_torch = torch.tensor(out_grad, device="cuda", requires_grad=True)
-
-            in1_grad, in2_grad, w_grad = torch.autograd.grad(
-                outputs=[out_torch],
-                inputs=[in1_torch, in2_torch, weights_torch],
-                grad_outputs=[out_grad_torch],
-                create_graph=True,
-            )
-
-            dummy = torch.norm(in1_grad) + torch.norm(in2_grad) + torch.norm(w_grad)
-            dummy_grad = torch.tensor(
-                float(dummy_grad_value), device="cuda", requires_grad=True
-            )
-            dummy.backward(
-                dummy_grad, inputs=[out_grad_torch, in1_torch, in2_torch, weights_torch]
-            )
-
-            weights_grad = weights_torch.grad.detach().cpu().numpy()
-            weights_grad = tp.reorder_weights_to_e3nn(
-                weights_grad, not self.config.shared_weights
-            )
+            in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(in1, in2, out_grad, weights_reordered, weights_dgrad_reordered, in1_dgrad, in2_dgrad, graph)
 
             tensors.append(
-                (
-                    out_grad_torch.grad.detach().cpu().numpy().copy(),
-                    in1_torch.grad.detach().cpu().numpy().copy(),
-                    in2_torch.grad.detach().cpu().numpy().copy(),
-                    weights_grad.copy(),
-                )
-            )
+                (out_dgrad,
+                 in1_grad,
+                 in2_grad,
+                 tp.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights),
+                ))
 
         for name, to_check, ground_truth in [
            ("output_grad", tensors[0][0], tensors[1][0]),

openequivariance/openequivariance/impl_jax/TensorProductConv.py

Lines changed: 21 additions & 1 deletion
@@ -220,4 +220,24 @@ def backward_cpu(
         L1_grad[:] = np.asarray(L1_grad_jax)
         L2_grad[:] = np.asarray(L2_grad_jax)
         weights_grad[:] = np.asarray(weights_grad_jax)
-        weights_grad[:] = self.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights)
+        weights_grad[:] = self.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights)
+
+    def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, graph):
+        in1_jax = jax.numpy.asarray(in1)
+        in2_jax = jax.numpy.asarray(in2)
+        weights_jax = jax.numpy.asarray(weights)
+        out_grad_jax = jax.numpy.asarray(out_grad)
+        in1_dgrad_jax = jax.numpy.asarray(in1_dgrad)
+        in2_dgrad_jax = jax.numpy.asarray(in2_dgrad)
+        weights_dgrad_jax = jax.numpy.asarray(weights_dgrad)
+
+        rows_jax = jax.numpy.asarray(graph.rows.astype(self.idx_dtype))
+        cols_jax = jax.numpy.asarray(graph.cols.astype(self.idx_dtype))
+        sender_perm_jax = jax.numpy.asarray(graph.transpose_perm.astype(self.idx_dtype))
+
+        in1_grad, in2_grad, weights_grad, out_dgrad = jax.vjp(
+            lambda x, y, w, o: jax.vjp(lambda a, b, c: self.forward(a, b, c, rows_jax, cols_jax, sender_perm_jax), x, y, w)[1](o),
+            in1_jax, in2_jax, weights_jax, out_grad_jax
+        )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax))
+
+        return in1_grad, in2_grad, weights_grad, out_dgrad
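
The nested `jax.vjp` is dense: the inner `jax.vjp` builds the first backward pass (inputs and output cotangent in, input gradients out), and the outer `jax.vjp` differentiates that backward pass itself. An expanded sketch of the same pattern with the lambdas named, using a toy bilinear `forward` as a stand-in for the convolution (illustrative assumption, not the library's API):

```python
# Expanded form of the nested-vjp pattern in double_backward_cpu.
import jax
import jax.numpy as jnp

def forward(in1, in2, weights):
    return weights * in1 * in2  # toy stand-in for the TP convolution

def backward(in1, in2, weights, out_grad):
    # First backward pass: VJP of forward, evaluated at out_grad.
    _, vjp_fn = jax.vjp(forward, in1, in2, weights)
    return vjp_fn(out_grad)  # (in1_grad, in2_grad, weights_grad)

def double_backward(in1, in2, weights, out_grad,
                    in1_dgrad, in2_dgrad, weights_dgrad):
    # Second backward pass: VJP of `backward` itself. The cotangents match
    # backward's tuple output, and the result is ordered like backward's
    # positional inputs -- hence the unpack order in the commit.
    _, vjp_fn = jax.vjp(backward, in1, in2, weights, out_grad)
    return vjp_fn((in1_dgrad, in2_dgrad, weights_dgrad))

key = jax.random.PRNGKey(0)
in1, in2, w, og, d1, d2, dw = jax.random.normal(key, (7, 4))
in1_grad, in2_grad, weights_grad, out_dgrad = double_backward(
    in1, in2, w, og, d1, d2, dw
)
```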
