Skip to content

Commit 45e2cd8

Browse files
asglover and Austin Glover authored
Torch scatter add (#143)
* replace custom scatter with torch scatter_add
* increase limit for backwards error
* move duplicate logic to an external function
* change to scatter_add_wrapper
---------
Co-authored-by: Austin Glover <austin_glover@berkeley.com>
1 parent 10aa002 commit 45e2cd8

6 files changed

Lines changed: 43 additions & 75 deletions

File tree

openequivariance/implementations/convolution/CUEConv.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
from typing import Iterator
44

55
from openequivariance.implementations.CUETensorProduct import CUETensorProduct
6-
from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase
6+
from openequivariance.implementations.convolution.ConvolutionBase import (
7+
ConvolutionBase,
8+
scatter_add_wrapper,
9+
)
710

811

912
class CUEConv(ConvolutionBase):
@@ -16,15 +19,9 @@ def __init__(self, config, *, idx_dtype=np.int64, torch_op=True):
1619
self.reference_tp = CUETensorProduct(config, torch_op)
1720
self.cue_tp = self.reference_tp.cue_tp
1821

19-
from openequivariance.implementations.convolution.scatter import scatter_sum
20-
21-
self.scatter_sum = scatter_sum
22-
2322
def forward(self, L1_in, L2_in, weights, rows, cols):
24-
tp_outputs = self.cue_tp(L1_in[cols], L2_in, weights)
25-
return self.scatter_sum(
26-
src=tp_outputs, index=rows, dim=0, dim_size=L1_in.shape[0]
27-
)
23+
messages = self.reference_tp(L1_in[cols], L2_in, weights)
24+
return scatter_add_wrapper(messages, rows, L1_in.size(0))
2825

2926
@staticmethod
3027
def name():

openequivariance/implementations/convolution/ConvolutionBase.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -713,7 +713,11 @@ def test_correctness_double_backward(
713713
in2_torch = torch.tensor(in2, device="cuda", requires_grad=True)
714714

715715
weights_reordered = np.zeros_like(weights)
716-
if i == 0 and self.reorder_weights_e3nn_to_oeq is not None:
716+
if (
717+
i == 0
718+
and hasattr(self, "reorder_weights_e3nn_to_oeq")
719+
and self.reorder_weights_e3nn_to_oeq is not None
720+
):
717721
self.reorder_weights_e3nn_to_oeq(
718722
weights, weights_reordered, not self.config.shared_weights
719723
)
@@ -750,7 +754,11 @@ def test_correctness_double_backward(
750754
)
751755

752756
weights_grad = weights_torch.grad.detach().cpu().numpy()
753-
if i == 0 and self.reorder_weights_oeq_to_e3nn is not None:
757+
if (
758+
i == 0
759+
and hasattr(self, "reorder_weights_e3nn_to_oeq")
760+
and self.reorder_weights_oeq_to_e3nn is not None
761+
):
754762
weights_grad_copy = weights_grad.copy()
755763
self.reorder_weights_oeq_to_e3nn(
756764
weights_grad_copy, weights_grad, not self.config.shared_weights
@@ -774,3 +782,15 @@ def test_correctness_double_backward(
774782
result[name] = check_similiarity(name, to_check, ground_truth, thresh)
775783

776784
return result
785+
786+
787+
def scatter_add_wrapper(messages, rows, target_dim):
    """Sum the rows of ``messages`` into ``target_dim`` output rows.

    Row ``e`` of ``messages`` is accumulated into row ``rows[e]`` of the
    result via ``torch.scatter_add`` along dimension 0. Output rows that no
    entry of ``rows`` points at stay zero.

    Args:
        messages: Tensor whose leading dimension indexes the messages. Any
            number of trailing dimensions is accepted (the original
            implementation only handled exactly one).
        rows: 1-D integer tensor of destination row indices, one per
            message; every value must lie in ``[0, target_dim)``.
        target_dim: Number of rows in the returned tensor.

    Returns:
        A new tensor of shape ``(target_dim, *messages.shape[1:])`` created
        with ``messages.new_zeros`` (so it shares dtype and device with
        ``messages``).
    """
    # torch.scatter_add only accepts int64 indices. Widen narrower integer
    # index tensors instead of failing; other dtypes are passed through so
    # torch still reports them as errors exactly as before.
    if rows.dtype in (torch.int32, torch.int16, torch.int8, torch.uint8):
        rows = rows.to(torch.int64)

    trailing = tuple(messages.shape[1:])

    # scatter_add needs an index with the same rank as src: append one
    # singleton axis per trailing dimension, then broadcast (without copying)
    # across those dimensions.
    idx = rows.reshape((-1,) + (1,) * len(trailing)).expand((-1,) + trailing)

    out = messages.new_zeros((target_dim,) + trailing)
    return torch.scatter_add(
        input=out,
        dim=0,
        index=idx,
        src=messages,
    )

openequivariance/implementations/convolution/E3NNConv.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import numpy as np
22

3-
from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase
3+
from openequivariance.implementations.convolution.ConvolutionBase import (
4+
ConvolutionBase,
5+
scatter_add_wrapper,
6+
)
47
from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct
58

69

@@ -34,15 +37,9 @@ def __init__(self, config, *, idx_dtype=np.int64, torch_op=True):
3437
if config.irrep_dtype == np.float64:
3538
torch.set_default_dtype(torch.float32) # Reset to default
3639

37-
from openequivariance.implementations.convolution.scatter import scatter_sum
38-
39-
self.scatter_sum = scatter_sum
40-
4140
def forward(self, L1_in, L2_in, weights, rows, cols):
42-
tp_outputs = self.reference_tp(L1_in[cols], L2_in, weights)
43-
return self.scatter_sum(
44-
src=tp_outputs, index=rows, dim=0, dim_size=L1_in.shape[0]
45-
)
41+
messages = self.reference_tp(L1_in[cols], L2_in, weights)
42+
return scatter_add_wrapper(messages, rows, L1_in.size(0))
4643

4744
@staticmethod
4845
def name():

openequivariance/implementations/convolution/TensorProductConv.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
import torch
55

66
from openequivariance import extlib
7-
from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase
7+
from openequivariance.implementations.convolution.ConvolutionBase import (
8+
ConvolutionBase,
9+
scatter_add_wrapper,
10+
)
811
from openequivariance.implementations.convolution.LoopUnrollConv import LoopUnrollConv
912
from openequivariance.implementations.TensorProduct import TensorProduct
1013
from openequivariance import TPProblem
@@ -414,15 +417,12 @@ def __init__(self, config, *, torch_op=True):
414417
super().__init__(config, torch_op=torch_op, deterministic=False)
415418

416419
self.reference_tp = TensorProduct(config, torch_op=torch_op)
417-
from openequivariance.implementations.convolution.scatter import scatter_sum
418-
419-
self.scatter_sum = scatter_sum
420+
self.reorder_weights_e3nn_to_oeq = self.reference_tp.reorder_weights_e3nn_to_oeq
421+
self.reorder_weights_oeq_to_e3nn = self.reference_tp.reorder_weights_oeq_to_e3nn
420422

421423
def forward(self, L1_in, L2_in, weights, rows, cols):
422-
tp_outputs = self.reference_tp(L1_in[cols], L2_in, weights)
423-
return self.scatter_sum(
424-
src=tp_outputs, index=rows, dim=0, dim_size=L1_in.shape[0]
425-
)
424+
messages = self.reference_tp(L1_in[cols], L2_in, weights)
425+
return scatter_add_wrapper(messages, rows, L1_in.size(0))
426426

427427
def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph):
428428
tp_outputs = np.zeros((graph.nnz, self.L3.dim), dtype=L3_out.dtype)

openequivariance/implementations/convolution/scatter.py

Lines changed: 0 additions & 46 deletions
This file was deleted.

tests/conv_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ class TestAtomicSharedWeights(ConvCorrectness):
224224
def thresh(self, direction):
225225
return {
226226
"fwd": 1e-5,
227-
"bwd": 5e-2, # Expect higher errors for shared weights
227+
"bwd": 7.5e-2, # Expect higher errors for shared weights
228228
"double_bwd": 5e-2,
229229
}[direction]
230230

0 commit comments

Comments
 (0)