Skip to content

Commit 73aa9b6

Browse files
Changes for Camera-Ready Version of Paper (#97)
* UVU and UVW plots updated. * Benchmarked against fused cuE kernel for convolution test. * Added double backward benchmark and plotting. * Modified the double backward plots. * Minor changes to double backward x-labels. * Updated the README and citation. * Proceeding with final MACE benchmarking. * Ready to wrap up camera-ready. * Updated double backward plot. * Forced unsafe atomic add on AMD HIP to boost performance. * Updated message about camera-ready copy.
1 parent 496f81d commit 73aa9b6

14 files changed

Lines changed: 281 additions & 36 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ scratch.txt
3333
triton_autotuning
3434
paper_benchmarks
3535
paper_benchmarks_v2
36+
paper_benchmarks_v3
3637
openequivariance/extlib/*.so
3738

3839
get_node.sh

README.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ We currently support NVIDIA GPUs and just added beta support on AMD GPUs for
2929
all tensor products! See [the coverage table](#tensor-products-we-accelerate) for more
3030
details.
3131

32-
**Warning**: This is an early release, bug reports are welcome.
32+
📣 📣 OpenEquivariance was accepted to the 2025 SIAM Conference on Applied and
33+
Computational Discrete Algorithms (Proceedings Track)! Catch the talk in
34+
Montréal and check out the [camera-ready copy on Arxiv](https://arxiv.org/abs/2501.13986) (available May 12, 2025).
3335

3436
## Show me some examples
3537
Here's a CG tensor product implemented by e3nn:
@@ -279,14 +281,12 @@ If you have a use case for any of the unsupported features above, let us know.
279281
If you find this code useful, please cite our paper:
280282

281283
```bibtex
282-
@misc{openequivariance,
283-
title={An Efficient Sparse Kernel Generator for O(3)-Equivariant Deep Networks},
284-
author={Vivek Bharadwaj and Austin Glover and Aydin Buluc and James Demmel},
285-
year={2025},
286-
eprint={2501.13986},
287-
archivePrefix={arXiv},
288-
primaryClass={cs.LG},
289-
url={https://arxiv.org/abs/2501.13986},
284+
@inbook{openequivariance,
285+
author={Vivek Bharadwaj and Austin Glover and Aydin Buluc and James Demmel},
286+
title={An Efficient Sparse Kernel Generator for O(3)-Equivariant Deep Networks},
287+
booktitle = {SIAM Conference on Applied and Computational Discrete Algorithms (ACDA25)},
288+
chapter = {},
289+
url={https://arxiv.org/abs/2501.13986}
290290
}
291291
```
292292

openequivariance/benchmark/benchmark_routines/paper_benchmark_uvw.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22
import numpy as np
33

44
from openequivariance.benchmark.logging_utils import getLogger
5-
from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProductCompiledCUDAGraphs
5+
from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct, E3NNTensorProductCompiledCUDAGraphs
66
from openequivariance.implementations.CUETensorProduct import CUETensorProduct
77
from openequivariance.implementations.TensorProduct import TensorProduct
88
from openequivariance.benchmark.TestBenchmarkSuite import TestBenchmarkSuite, TestDefinition, Direction
99
from openequivariance.benchmark.tpp_creation_utils import FullyConnectedTPProblem
1010
from openequivariance.benchmark.benchmark_configs import e3nn_torch_tetris_polynomial, diffdock_configs
1111

1212
logger = getLogger()
13+
import torch
14+
from torch._functorch import config
1315

16+
@config.patch("donated_buffer", False)
1417
def run_paper_uvw_benchmark(params) -> pathlib.Path:
1518
FCTPP = FullyConnectedTPProblem
1619

@@ -27,16 +30,15 @@ def run_paper_uvw_benchmark(params) -> pathlib.Path:
2730
problems += float64_problems
2831

2932
implementations = [
30-
#E3NNTensorProductCompiledCUDAGraphs,
31-
#CUETensorProduct,
33+
E3NNTensorProductCompiledCUDAGraphs,
34+
CUETensorProduct,
3235
TensorProduct]
3336

34-
tests = [TestDefinition(implementation, problem, direction, correctness=True, benchmark=True)
37+
tests = [TestDefinition(implementation, problem, direction, correctness=False, benchmark=True)
3538
for problem, direction, implementation
3639
in itertools.product(problems, params.directions, implementations)]
3740

3841
bench_suite = TestBenchmarkSuite(
39-
correctness_threshold = 5e-5,
4042
num_warmup=100,
4143
num_iter=100,
4244
bench_batch_size=params.batch_size,

openequivariance/benchmark/plotting/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
from openequivariance.benchmark.plotting.plot_uvu import plot_uvu
33
from openequivariance.benchmark.plotting.plot_uvw import plot_uvw
44
from openequivariance.benchmark.plotting.plot_roofline import plot_roofline
5-
from openequivariance.benchmark.plotting.plot_convolution import plot_convolution
5+
from openequivariance.benchmark.plotting.plot_convolution import plot_convolution
6+
from openequivariance.benchmark.plotting.plot_double_backward import plot_double_backward

openequivariance/benchmark/plotting/plot_convolution.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ def plot_convolution(data_folder):
88
benchmarks, metadata = load_benchmarks(data_folder)
99

1010
implementations = ["CUEConvolution",
11+
"CUEConvolutionFused",
1112
"LoopUnrollConvScatterSum",
1213
"LoopUnrollConvAtomic",
1314
"LoopUnrollConvDeterministic"
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
import os, json, pathlib, sys
4+
from openequivariance.benchmark.plotting import *
5+
6+
def plot_double_backward(data_folder):
7+
data_folder = pathlib.Path(data_folder)
8+
benchmarks, metadata = load_benchmarks(data_folder)
9+
10+
configs = metadata["config_labels"]
11+
implementations = ["E3NNTensorProduct", "CUETensorProduct", "LoopUnrollTP"]
12+
13+
def calculate_tp_per_sec(exp):
14+
return exp["benchmark results"]["batch_size"] / (np.mean(exp["benchmark results"]["time_millis"]) * 0.001)
15+
16+
dataf32 = {"double_backward": {}}
17+
for i, desc in enumerate(configs):
18+
for direction in ["double_backward"]:
19+
dataf32[direction][desc] = {}
20+
for impl in implementations:
21+
f32_benches = [b for b in benchmarks if b["benchmark results"]["rep_dtype"] == "<class 'numpy.float32'>"]
22+
exp = filter(f32_benches, {"config_label": desc,
23+
"direction": direction,
24+
"implementation_name": impl
25+
}, match_one=True)
26+
dataf32[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp)
27+
28+
dataf64 = {"double_backward": {}}
29+
for i, desc in enumerate(configs):
30+
for direction in ["double_backward"]:
31+
dataf64[direction][desc] = {}
32+
for impl in implementations:
33+
f64_benches = [b for b in benchmarks if 'float64' in b["benchmark results"]["rep_dtype"]]
34+
35+
exp = filter(f64_benches, {"config_label": desc,
36+
"direction": direction,
37+
"implementation_name": impl
38+
}, match_one=True)
39+
40+
if exp is None:
41+
print(desc)
42+
print(direction)
43+
print(impl)
44+
45+
dataf64[direction][desc][labelmap[impl]] = calculate_tp_per_sec(exp)
46+
47+
fig = plt.figure(figsize=(7, 3))
48+
gs = fig.add_gridspec(1, 2, hspace=0, wspace=0.1)
49+
axs = gs.subplots(sharex='col', sharey='row')
50+
51+
grouped_barchart(dataf32["double_backward"], axs[0], bar_height_fontsize=0, colormap=colormap, group_spacing=6.0)
52+
grouped_barchart(dataf64["double_backward"], axs[1], bar_height_fontsize=0, colormap=colormap, group_spacing=6.0)
53+
54+
for i in range(2):
55+
set_grid(axs[i])
56+
set_grid(axs[i])
57+
58+
axs[0].set_xlabel("float32")
59+
axs[1].set_xlabel("float64")
60+
61+
handles, labels = axs[0].get_legend_handles_labels()
62+
unique = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]]
63+
axs[0].legend(*zip(*unique))
64+
65+
for ax in fig.get_axes():
66+
ax.label_outer()
67+
68+
fig.supylabel("2nd Deriv. Throughput\n(# tensor products / s)", y=0.5)
69+
70+
speedup_table = []
71+
for direction in ['double_backward']:
72+
for impl in ['e3nn', 'cuE']:
73+
for dtype_label, dtype_set in [('f32', dataf32), ('f64', dataf64)]:
74+
speedups = [measurement['ours'] / measurement[impl] for _, measurement in dtype_set[direction].items() if impl in measurement]
75+
stats = np.min(speedups), np.mean(speedups), np.median(speedups), np.max(speedups)
76+
stats = [f"{stat:.2f}" for stat in stats]
77+
78+
dir_print = direction
79+
result = [dir_print, impl, dtype_label] + stats
80+
speedup_table.append(result)
81+
82+
print('\t\t'.join(['Direction', 'Base', 'dtype', 'min', 'mean', 'med', 'max']))
83+
for row in speedup_table:
84+
print('\t\t'.join(row))
85+
86+
fig.show()
87+
fig.tight_layout()
88+
fig.savefig(str(data_folder / "double_backward_throughput.pdf"), bbox_inches='tight')

openequivariance/benchmark/plotting/plotting_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,9 @@ def set_size(w, h, ax=None):
296296

297297
labelmap = {"E3NNTensorProduct": "e3nn", "CUETensorProduct": "cuE", "LoopUnrollTP": "ours",
298298
"E3NNTensorProductCompiledCUDAGraphs": "e3nn",
299-
"LoopUnrollConvScatterSum": "fast-scattersum", "CUEConvolution": "cuE-scattersum",
299+
"LoopUnrollConvScatterSum": "fast-scattersum",
300+
"CUEConvolution": "cuE-scattersum",
301+
"CUEConvolutionFused": "cuE-fused",
300302
"LoopUnrollConvDeterministic": "fast-fused-det", "LoopUnrollConvAtomic": "fast-fused-atomic"
301303
}
302304
colormap = {"e3nn": "lightblue", "cuE": "orange", "ours": "g"}
@@ -305,7 +307,8 @@ def set_size(w, h, ax=None):
305307
colormap[key] = colormap["ours"]
306308

307309
colormap["cuE-scattersum"] = colormap["cuE"]
308-
hatchmap = {"fast-fused-det": "oo", "fast-fused-atomic": "//"}
310+
colormap["cuE-fused"] = colormap["cuE"]
311+
hatchmap = {"fast-fused-det": "oo", "fast-fused-atomic": "//", "cuE-fused": "//"}
309312

310313
directions = ["forward", "backward"]
311314
dtypes = ["<class 'numpy.float32'>", "<class 'numpy.float64'>"]

openequivariance/extlib/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
def postprocess(kernel):
4343
kernel = kernel.replace("__syncwarp();", "__threadfence_block();")
4444
kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(")
45+
kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd")
4546
return kernel
4647
postprocess_kernel = postprocess
4748

openequivariance/implementations/CUETensorProduct.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,9 @@ def analyze_trace(self, trace_file):
171171
event_time_ms = event["dur"] / 1000
172172
total += event_time_ms
173173

174-
if "TensorProductUniform1dKernel" in event["name"]:
174+
if "TensorProductUniform1dKernel" in event["name"] \
175+
or "channelwise_kernel_fwd" in event["name"] \
176+
or "channelwise_kernel_bwd" in event["name"]:
175177
tp_time += event_time_ms
176178

177179
return tp_time
@@ -210,7 +212,7 @@ def benchmark_forward(
210212
with record_function("cue_forward"):
211213
torch_L3_out = self.forward(torch_L1_in, torch_L2_in, torch_weights)
212214

213-
prof.export_chrome_trace(trace_file)
215+
prof.export_chrome_trace(trace_file)
214216
time_millis[i] = self.analyze_trace(trace_file)
215217

216218
return time_millis

openequivariance/implementations/convolution/CUEConv.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import numpy.linalg as la
3+
import itertools
34

45
from openequivariance.implementations.CUETensorProduct import CUETensorProduct
56
from openequivariance.implementations.convolution.ConvolutionBase import *
@@ -25,3 +26,65 @@ def forward(self, L1_in, L2_in, weights, rows, cols):
2526
@staticmethod
2627
def name():
2728
return "CUEConvolution"
29+
30+
class CUEConvFused(ConvolutionBase):
31+
def __init__(self, config, idx_dtype=np.int64, torch_op=True):
32+
super().__init__(config, idx_dtype, torch_op)
33+
34+
global torch
35+
import torch
36+
import e3nn.o3 as o3
37+
38+
np_to_torch_dtype = {
39+
np.float32: torch.float32,
40+
np.float64: torch.float64
41+
}
42+
43+
import cuequivariance as cue
44+
import cuequivariance_torch as cuet
45+
from cuequivariance_torch.primitives.tensor_product import TensorProductUniform4x1dIndexed
46+
47+
class O3_e3nn(cue.O3):
48+
def __mul__( # pylint: disable=no-self-argument
49+
rep1: "O3_e3nn", rep2: "O3_e3nn"
50+
) -> Iterator["O3_e3nn"]:
51+
return [O3_e3nn(l=ir.l, p=ir.p) for ir in cue.O3.__mul__(rep1, rep2)]
52+
53+
@classmethod
54+
def clebsch_gordan(
55+
cls, rep1: "O3_e3nn", rep2: "O3_e3nn", rep3: "O3_e3nn"
56+
) -> np.ndarray:
57+
rep1, rep2, rep3 = cls._from(rep1), cls._from(rep2), cls._from(rep3)
58+
59+
if rep1.p * rep2.p == rep3.p:
60+
return o3.wigner_3j(rep1.l, rep2.l, rep3.l).numpy()[None] * np.sqrt(
61+
rep3.dim
62+
)
63+
return np.zeros((0, rep1.dim, rep2.dim, rep3.dim))
64+
65+
def __lt__( # pylint: disable=no-self-argument
66+
rep1: "O3_e3nn", rep2: "O3_e3nn"
67+
) -> bool:
68+
rep2 = rep1._from(rep2)
69+
return (rep1.l, rep1.p) < (rep2.l, rep2.p)
70+
71+
@classmethod
72+
def iterator(cls) -> Iterator["O3_e3nn"]:
73+
for l in itertools.count(0):
74+
yield O3_e3nn(l=l, p=1 * (-1) ** l)
75+
yield O3_e3nn(l=l, p=-1 * (-1) ** l)
76+
77+
descriptor = (cue.descriptors.channelwise_tensor_product(
78+
cue.Irreps(O3_e3nn, str(config.irreps_in1)),
79+
cue.Irreps(O3_e3nn, str(config.irreps_in2)),
80+
cue.Irreps(O3_e3nn, str(config.irreps_out))
81+
).squeeze_modes().flatten_coefficient_modes())
82+
83+
self.tp = TensorProductUniform4x1dIndexed(descriptor.polynomial.operations[0][1], 'cuda', math_dtype=np_to_torch_dtype[config.irrep_dtype])
84+
85+
def forward(self, L1_in, L2_in, weights, rows, cols):
86+
return self.tp(weights, L1_in, L2_in, None, rows, None, cols, L1_in.shape[0])
87+
88+
@staticmethod
89+
def name():
90+
return "CUEConvolutionFused"

0 commit comments

Comments
 (0)