Skip to content

Commit e140c07

Browse files
committed
Added the mixins.
1 parent 71ca862 commit e140c07

4 files changed

Lines changed: 45 additions & 6 deletions

File tree

openequivariance/openequivariance/impl_torch/E3NNConv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
scatter_add_wrapper,
66
)
77
from openequivariance.impl_torch.E3NNTensorProduct import E3NNTensorProduct
8+
from openequivariance.impl_torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv
89

9-
10-
class E3NNConv(ConvolutionBase):
10+
class E3NNConv(ConvolutionBase, NumpyDoubleBackwardMixinConv):
1111
def __init__(self, config, *, idx_dtype=np.int64, torch_op=True):
1212
assert torch_op
1313
super().__init__(config, idx_dtype=idx_dtype, torch_op=torch_op)

openequivariance/openequivariance/impl_torch/NPDoubleBackwardMixin.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,42 @@ def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dg
3434

3535
return a.detach().cpu().numpy(), b.detach().cpu().numpy(), c.detach().cpu().numpy(), d.detach().cpu().numpy()
3636

37+
38+
class NumpyDoubleBackwardMixinConv:
    """
    Mixin providing a numpy-in / numpy-out double-backward entry point for
    fused graph convolution modules.

    Similar to ``NumpyDoubleBackwardMixin``, but the forward pass additionally
    takes the graph structure (row indices, column indices, and the transpose
    permutation).

    Host classes must provide ``self.forward(in1, in2, weights, rows, cols,
    transpose_perm)`` and a truthy ``self.torch_op`` attribute.
    """

    def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad,
                            in1_dgrad, in2_dgrad, graph, *, device="cuda"):
        """
        Compute second-order gradients of the fused convolution using two
        passes of ``torch.autograd.grad``.

        Parameters
        ----------
        in1, in2, weights : numpy arrays
            Forward-pass inputs.
        out_grad : numpy array
            Upstream gradient with respect to the forward output.
        weights_dgrad, in1_dgrad, in2_dgrad : numpy arrays
            Upstream gradients with respect to the first-order gradients
            (the "double backward" seed vectors).
        graph : object
            Must expose ``rows``, ``cols`` and ``transpose_perm`` index
            arrays describing the convolution graph.
        device : str, keyword-only
            Torch device to run on. Defaults to ``"cuda"``, preserving the
            original hard-coded behavior.

        Returns
        -------
        tuple of four numpy arrays
            Double-backward gradients with respect to ``in1``, ``in2``,
            ``weights`` and ``out_grad``, in that order.
        """
        assert self.torch_op

        def _leaf(arr):
            # Differentiable leaf tensor on the target device.
            return torch.tensor(arr).to(device).requires_grad_(True)

        def _const(arr):
            # Non-differentiable constant tensor on the target device.
            return torch.tensor(arr).to(device)

        in1_t, in2_t, weights_t, out_grad_t = (
            _leaf(a) for a in (in1, in2, weights, out_grad)
        )
        in1_dgrad_t, in2_dgrad_t, weights_dgrad_t = (
            _const(a) for a in (in1_dgrad, in2_dgrad, weights_dgrad)
        )

        rows_t = torch.tensor(graph.rows, device=device)
        cols_t = torch.tensor(graph.cols, device=device)
        transpose_perm_t = torch.tensor(graph.transpose_perm, device=device)

        out_t = self.forward(in1_t, in2_t, weights_t,
                             rows_t, cols_t, transpose_perm_t)

        # First-order gradients; create_graph=True keeps the autograd graph
        # alive so the gradients themselves can be differentiated below.
        in1_grad, in2_grad, weights_grad = torch.autograd.grad(
            outputs=out_t,
            inputs=[in1_t, in2_t, weights_t],
            grad_outputs=out_grad_t,
            create_graph=True,
            retain_graph=True,
        )

        # Second-order gradients with respect to the original inputs and
        # the upstream output gradient.
        a, b, c, d = torch.autograd.grad(
            outputs=[in1_grad, in2_grad, weights_grad],
            inputs=[in1_t, in2_t, weights_t, out_grad_t],
            grad_outputs=[in1_dgrad_t, in2_dgrad_t, weights_dgrad_t],
        )

        return (a.detach().cpu().numpy(), b.detach().cpu().numpy(),
                c.detach().cpu().numpy(), d.detach().cpu().numpy())

openequivariance/openequivariance/impl_torch/TensorProduct.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from openequivariance.benchmark.logging_utils import getLogger
88
from openequivariance.impl_torch.utils import reorder_torch
99
from openequivariance.impl_torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin
10-
from openequivariance.core.e3nn_lite import Irreps
1110

1211
logger = getLogger()
1312

@@ -348,7 +347,7 @@ def name():
348347
return "LoopUnrollTP"
349348

350349

351-
if extlib.TORCH_COMPILE and __name__ != "__main__":
350+
if extlib.TORCH_COMPILE:
352351
TensorProduct.register_torch_fakes()
353352
TensorProduct.register_autograd()
354353
TensorProduct.register_autocast()

openequivariance/openequivariance/impl_torch/TensorProductConv.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@
2222
from openequivariance.impl_torch.utils import reorder_torch
2323

2424
from openequivariance.benchmark.logging_utils import getLogger
25+
from openequivariance.impl_torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv
2526

26-
logger = getLogger()
2727

28+
logger = getLogger()
2829

29-
class TensorProductConv(torch.nn.Module, LoopUnrollConv):
30+
class TensorProductConv(torch.nn.Module, LoopUnrollConv, NumpyDoubleBackwardMixinConv):
3031
r"""
3132
Given a **symmetric, directed** graph :math:`G = (V, E)`, inputs :math:`x_1...x_{|V|}`,
3233
:math:`y_1...y_{|E|}`, and weights :math:`W_1...W_{|E|}`, computes

0 commit comments

Comments
 (0)