Skip to content

Commit b80d15c

Browse files
committed
Competed benchmarking.
1 parent f5b7a26 commit b80d15c

3 files changed

Lines changed: 191 additions & 0 deletions

File tree

openequivariance/openequivariance/benchmark/plotting/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
from openequivariance.benchmark.plotting.plot_double_backward import (
66
plot_double_backward,
77
)
8+
from openequivariance.benchmark.plotting.plot_layout import plot_layout
89

910
__all__ = [
1011
"plot_uvu",
1112
"plot_uvw",
1213
"plot_roofline",
1314
"plot_convolution",
1415
"plot_double_backward",
16+
"plot_layout",
1517
]
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import pathlib
2+
3+
import matplotlib.pyplot as plt
4+
import numpy as np
5+
6+
from openequivariance.benchmark.plotting.plotting_utils import (
7+
calculate_tp_per_sec,
8+
grouped_barchart,
9+
load_benchmarks,
10+
set_grid,
11+
)
12+
13+
14+
def _parse_layout_label(label: str):
15+
if label.endswith("[mul_ir]"):
16+
return label[: -len(" [mul_ir]")], "mul_ir"
17+
if label.endswith("[ir_mul]"):
18+
return label[: -len(" [ir_mul]")], "ir_mul"
19+
return label, None
20+
21+
22+
def plot_layout(data_folder):
23+
data_folder = pathlib.Path(data_folder)
24+
benchmarks, _ = load_benchmarks(data_folder)
25+
26+
grouped = {}
27+
dtype_order = []
28+
for benchmark in benchmarks:
29+
dtype = benchmark["benchmark results"]["rep_dtype"]
30+
if dtype not in dtype_order:
31+
dtype_order.append(dtype)
32+
33+
direction = benchmark["direction"]
34+
base_label, layout = _parse_layout_label(benchmark["config_label"])
35+
if layout is None:
36+
continue
37+
38+
grouped.setdefault(dtype, {}).setdefault(direction, {}).setdefault(
39+
base_label, {"mul_ir": 0.0, "ir_mul": 0.0}
40+
)
41+
grouped[dtype][direction][base_label][layout] = calculate_tp_per_sec(benchmark)
42+
43+
def _dtype_sort_key(dtype_name: str) -> int:
44+
if "float32" in dtype_name:
45+
return 0
46+
if "float64" in dtype_name:
47+
return 1
48+
return 2
49+
50+
dtype_order = sorted(dtype_order, key=_dtype_sort_key)
51+
52+
directions = [d for d in ["forward", "backward"] if any(d in grouped[x] for x in grouped)]
53+
if not directions:
54+
raise ValueError("No forward/backward layout benchmark entries found to plot.")
55+
56+
fig = plt.figure(figsize=(7, 7))
57+
gs = fig.add_gridspec(len(directions), max(1, len(dtype_order)))
58+
axs = gs.subplots(sharex="col")
59+
60+
if len(directions) == 1 and len(dtype_order) == 1:
61+
axs = np.array([[axs]])
62+
elif len(directions) == 1:
63+
axs = np.array([axs])
64+
elif len(dtype_order) == 1:
65+
axs = np.array([[ax] for ax in axs])
66+
67+
colormap = {"mul_ir": "#1f77b4", "ir_mul": "#2ca02c"}
68+
69+
for row, direction in enumerate(directions):
70+
for col, dtype in enumerate(dtype_order):
71+
axis = axs[row][col]
72+
source = grouped.get(dtype, {}).get(direction, {})
73+
data = {
74+
label: {
75+
"mul_ir": vals["mul_ir"],
76+
"ir_mul": vals["ir_mul"],
77+
}
78+
for label, vals in source.items()
79+
}
80+
grouped_barchart(
81+
data,
82+
axis,
83+
bar_height_fontsize=0,
84+
colormap=colormap,
85+
group_spacing=6.0,
86+
xticklabel=(row == len(directions) - 1),
87+
)
88+
set_grid(axis)
89+
90+
if row == 0:
91+
axis.set_title(dtype.replace("<class 'numpy.", "").replace("'>", ""))
92+
if col == 0:
93+
axis.set_ylabel(direction.capitalize())
94+
if row < len(directions) - 1:
95+
axis.tick_params(axis="x", labelbottom=False)
96+
97+
fig.supylabel("Throughput (# tensor products / s)", x=0.03, y=0.56)
98+
fig.supxlabel("Problem")
99+
100+
handles, labels = axs[0][0].get_legend_handles_labels()
101+
unique = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]]
102+
if unique:
103+
axs[0][0].legend(*zip(*unique))
104+
105+
fig.tight_layout(rect=(0.03, 0.03, 1.0, 1.0))
106+
fig.savefig(str(data_folder / "layout_throughput_comparison.pdf"))
107+
108+
print("Layout speedups (ir_mul / mul_ir):")
109+
print("\t".join(["dtype", "direction", "min", "mean", "median", "max"]))
110+
for dtype in dtype_order:
111+
for direction in directions:
112+
ratios = []
113+
for _, values in grouped.get(dtype, {}).get(direction, {}).items():
114+
if values["mul_ir"] > 0:
115+
ratios.append(values["ir_mul"] / values["mul_ir"])
116+
if ratios:
117+
stats = [np.min(ratios), np.mean(ratios), np.median(ratios), np.max(ratios)]
118+
stats_fmt = [f"{val:.3f}" for val in stats]
119+
print("\t".join([dtype, direction] + stats_fmt))

tests/benchmark.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,47 @@ def benchmark_kahan_accuracy(params):
385385
)
386386

387387

388+
def benchmark_layouts(params):
389+
base_problems = mace_problems() + nequip_problems()
390+
directions = params.directions
391+
dtypes = [datatype_map[dtype] for dtype in params.datatypes]
392+
393+
tests = []
394+
for dtype in dtypes:
395+
for base_problem in base_problems:
396+
for layout in ["mul_ir", "ir_mul"]:
397+
layout_problem = copy.deepcopy(base_problem)
398+
layout_problem.layout = layout
399+
base_label = layout_problem.label or "layout_problem"
400+
layout_problem.label = f"{base_label} [{layout}]"
401+
layout_problem.irrep_dtype = dtype
402+
layout_problem.weight_dtype = dtype
403+
404+
for direction in directions:
405+
tests.append(
406+
TestDefinition(
407+
TensorProduct,
408+
layout_problem,
409+
direction,
410+
correctness=False,
411+
benchmark=True,
412+
)
413+
)
414+
415+
bench_suite = TestBenchmarkSuite(
416+
num_warmup=100,
417+
num_iter=100,
418+
bench_batch_size=params.batch_size,
419+
prng_seed=11111,
420+
test_name="layouts",
421+
)
422+
423+
data_folder = bench_suite.run(tests, params.output_folder)
424+
425+
if params.plot:
426+
plot({"data_folder": data_folder})
427+
428+
388429
def plot(params):
389430
import openequivariance.benchmark.plotting as plotting
390431

@@ -402,6 +443,8 @@ def plot(params):
402443
plotting.plot_uvu(data_folder)
403444
elif test_name == "uvw":
404445
plotting.plot_uvw(data_folder)
446+
elif test_name == "layouts":
447+
plotting.plot_layout(data_folder)
405448
elif test_name == "roofline":
406449
plotting.plot_roofline(data_folder)
407450
elif test_name == "convolution":
@@ -532,6 +575,33 @@ def plot(params):
532575
parser_uvw.add_argument("--plot", action="store_true", help="Plot the results.")
533576
parser_uvw.set_defaults(func=benchmark_uvw)
534577

578+
parser_layouts = subparsers.add_parser(
579+
"layouts", help="Run benchmark comparing mul_ir vs ir_mul layouts"
580+
)
581+
parser_layouts.add_argument(
582+
"--batch_size", "-b", type=int, default=50000, help="Batch size for benchmark"
583+
)
584+
parser_layouts.add_argument(
585+
"--directions",
586+
"-d",
587+
type=str,
588+
nargs="+",
589+
default=["forward", "backward"],
590+
help="Directions to benchmark",
591+
choices=["forward", "backward"],
592+
)
593+
parser_layouts.add_argument(
594+
"--datatypes",
595+
"-t",
596+
type=str,
597+
nargs="+",
598+
default=["float32", "float64"],
599+
help="Data types to benchmark",
600+
choices=["float32", "float64"],
601+
)
602+
parser_layouts.add_argument("--plot", action="store_true", help="Plot the results.")
603+
parser_layouts.set_defaults(func=benchmark_layouts)
604+
535605
parser_double_bwd = subparsers.add_parser(
536606
"double_backward", help="Run the higher derivative kernel benchmark"
537607
)

0 commit comments

Comments
 (0)