Skip to content

Commit 58b7957

Browse files
committed
Ready to modify the double backward correctness function.
1 parent c3f83ea commit 58b7957

3 files changed

Lines changed: 49 additions & 7 deletions

File tree

openequivariance/openequivariance/benchmark/random_buffer_utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,13 @@ def get_random_buffers_double_backward(
104104
)
105105
weights = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype)
106106

107-
weights_grad = np.zeros_like(weights)
108-
in1_grad = np.zeros_like(in1)
109-
in2_grad = np.zeros_like(in2)
110-
out_double_grad = np.zeros_like(out_grad)
107+
weights_grad = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype)
108+
in1_grad = np.array(
109+
rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype)
110+
in2_grad = np.array(
111+
rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype)
112+
out_double_grad = np.array(
113+
rng.uniform(size=(batch_size, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype)
111114

112115
return (
113116
in1,
@@ -176,3 +179,5 @@ def get_random_buffers_backward_conv(
176179
in2_grad = np.zeros_like(in2)
177180

178181
return in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad
182+
183+
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import torch
2+
3+
class NumpyDoubleBackwardMixin:
    """
    Adds a NumPy double-backward method to any TensorProduct whose
    forward pass is defined in PyTorch and whose derivatives are
    registered with autograd.
    """

    def double_backward_cpu(
        self,
        in1,
        in2,
        out_grad,
        weights,
        weights_dgrad,
        in1_dgrad,
        in2_dgrad,
        device="cuda",
    ):
        """
        Compute double-backward (second-order) gradients of the tensor
        product via PyTorch autograd, accepting and returning NumPy arrays.

        Parameters
        ----------
        in1, in2, weights : np.ndarray
            Forward-pass inputs.
        out_grad : np.ndarray
            Cotangent of the forward output (first-backward seed).
        weights_dgrad, in1_dgrad, in2_dgrad : np.ndarray
            Cotangents of the first-backward gradients w.r.t. weights,
            in1, and in2 respectively (double-backward seeds).
        device : str, optional
            Torch device to run on. Defaults to "cuda", preserving the
            original behavior; pass "cpu" to stay on the host.

        Returns
        -------
        tuple of np.ndarray
            ``(grad_in1, grad_in2, grad_weights, grad_out_grad)`` — the
            double-backward gradients w.r.t. the forward inputs and the
            first-backward seed.
        """
        # Double backward requires the torch op path (autograd derivatives
        # registered); the raw kernel path cannot build a grad graph.
        assert self.torch_op

        # Leaf tensors for the forward pass; gradients flow back to these.
        in1_t = torch.tensor(in1).to(device).requires_grad_(True)
        in2_t = torch.tensor(in2).to(device).requires_grad_(True)
        weights_t = torch.tensor(weights).to(device).requires_grad_(True)
        # out_grad is differentiated in the second pass, so it must also
        # require grad.
        out_grad_t = torch.tensor(out_grad).to(device).requires_grad_(True)

        # Double-backward seeds: plain tensors, never differentiated.
        in1_dgrad_t = torch.tensor(in1_dgrad).to(device)
        in2_dgrad_t = torch.tensor(in2_dgrad).to(device)
        weights_dgrad_t = torch.tensor(weights_dgrad).to(device)

        out_t = self.forward(in1_t, in2_t, weights_t)

        # First backward pass. create_graph=True makes the gradients
        # themselves differentiable (and implies retain_graph, so the
        # redundant retain_graph flag is dropped).
        in1_grad, in2_grad, weights_grad = torch.autograd.grad(
            outputs=out_t,
            inputs=[in1_t, in2_t, weights_t],
            grad_outputs=out_grad_t,
            create_graph=True,
        )

        # Second backward pass: differentiate the first-order gradients,
        # seeded by the *_dgrad cotangents (order matches `outputs`).
        d_in1, d_in2, d_weights, d_out_grad = torch.autograd.grad(
            outputs=[in1_grad, in2_grad, weights_grad],
            inputs=[in1_t, in2_t, weights_t, out_grad_t],
            grad_outputs=[in1_dgrad_t, in2_dgrad_t, weights_dgrad_t],
        )

        return (
            d_in1.detach().cpu().numpy(),
            d_in2.detach().cpu().numpy(),
            d_weights.detach().cpu().numpy(),
            d_out_grad.detach().cpu().numpy(),
        )

openequivariance/openequivariance/impl_torch/TensorProduct.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
from openequivariance.core.utils import torch_to_oeq_dtype
77
from openequivariance.benchmark.logging_utils import getLogger
88
from openequivariance.impl_torch.utils import reorder_torch
9-
9+
from openequivariance.impl_torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin
10+
from openequivariance.core.e3nn_lite import Irreps
1011

1112
logger = getLogger()
1213

1314

14-
class TensorProduct(torch.nn.Module, LoopUnrollTP):
15+
class TensorProduct(torch.nn.Module, LoopUnrollTP, NumpyDoubleBackwardMixin):
1516
r"""
1617
Drop-in replacement for ``o3.TensorProduct`` from e3nn. Supports forward,
1718
backward, and double-backward passes using JIT-compiled kernels. Initialization
@@ -347,7 +348,7 @@ def name():
347348
return "LoopUnrollTP"
348349

349350

350-
if extlib.TORCH_COMPILE:
351+
if extlib.TORCH_COMPILE and __name__ != "__main__":
351352
TensorProduct.register_torch_fakes()
352353
TensorProduct.register_autograd()
353354
TensorProduct.register_autocast()

0 commit comments

Comments
 (0)