|
| 1 | +import pathlib |
| 2 | + |
| 3 | +import matplotlib.pyplot as plt |
| 4 | +import numpy as np |
| 5 | + |
| 6 | +from openequivariance.benchmark.plotting.plotting_utils import ( |
| 7 | + calculate_tp_per_sec, |
| 8 | + grouped_barchart, |
| 9 | + load_benchmarks, |
| 10 | + set_grid, |
| 11 | +) |
| 12 | + |
| 13 | + |
| 14 | +def _parse_layout_label(label: str): |
| 15 | + if label.endswith("[mul_ir]"): |
| 16 | + return label[: -len(" [mul_ir]")], "mul_ir" |
| 17 | + if label.endswith("[ir_mul]"): |
| 18 | + return label[: -len(" [ir_mul]")], "ir_mul" |
| 19 | + return label, None |
| 20 | + |
| 21 | + |
| 22 | +def plot_layout(data_folder): |
| 23 | + data_folder = pathlib.Path(data_folder) |
| 24 | + benchmarks, _ = load_benchmarks(data_folder) |
| 25 | + |
| 26 | + grouped = {} |
| 27 | + dtype_order = [] |
| 28 | + for benchmark in benchmarks: |
| 29 | + dtype = benchmark["benchmark results"]["rep_dtype"] |
| 30 | + if dtype not in dtype_order: |
| 31 | + dtype_order.append(dtype) |
| 32 | + |
| 33 | + direction = benchmark["direction"] |
| 34 | + base_label, layout = _parse_layout_label(benchmark["config_label"]) |
| 35 | + if layout is None: |
| 36 | + continue |
| 37 | + |
| 38 | + grouped.setdefault(dtype, {}).setdefault(direction, {}).setdefault( |
| 39 | + base_label, {"mul_ir": 0.0, "ir_mul": 0.0} |
| 40 | + ) |
| 41 | + grouped[dtype][direction][base_label][layout] = calculate_tp_per_sec(benchmark) |
| 42 | + |
| 43 | + def _dtype_sort_key(dtype_name: str) -> int: |
| 44 | + if "float32" in dtype_name: |
| 45 | + return 0 |
| 46 | + if "float64" in dtype_name: |
| 47 | + return 1 |
| 48 | + return 2 |
| 49 | + |
| 50 | + dtype_order = sorted(dtype_order, key=_dtype_sort_key) |
| 51 | + |
| 52 | + directions = [d for d in ["forward", "backward"] if any(d in grouped[x] for x in grouped)] |
| 53 | + if not directions: |
| 54 | + raise ValueError("No forward/backward layout benchmark entries found to plot.") |
| 55 | + |
| 56 | + fig = plt.figure(figsize=(7, 7)) |
| 57 | + gs = fig.add_gridspec(len(directions), max(1, len(dtype_order))) |
| 58 | + axs = gs.subplots(sharex="col") |
| 59 | + |
| 60 | + if len(directions) == 1 and len(dtype_order) == 1: |
| 61 | + axs = np.array([[axs]]) |
| 62 | + elif len(directions) == 1: |
| 63 | + axs = np.array([axs]) |
| 64 | + elif len(dtype_order) == 1: |
| 65 | + axs = np.array([[ax] for ax in axs]) |
| 66 | + |
| 67 | + colormap = {"mul_ir": "#1f77b4", "ir_mul": "#2ca02c"} |
| 68 | + |
| 69 | + for row, direction in enumerate(directions): |
| 70 | + for col, dtype in enumerate(dtype_order): |
| 71 | + axis = axs[row][col] |
| 72 | + source = grouped.get(dtype, {}).get(direction, {}) |
| 73 | + data = { |
| 74 | + label: { |
| 75 | + "mul_ir": vals["mul_ir"], |
| 76 | + "ir_mul": vals["ir_mul"], |
| 77 | + } |
| 78 | + for label, vals in source.items() |
| 79 | + } |
| 80 | + grouped_barchart( |
| 81 | + data, |
| 82 | + axis, |
| 83 | + bar_height_fontsize=0, |
| 84 | + colormap=colormap, |
| 85 | + group_spacing=6.0, |
| 86 | + xticklabel=(row == len(directions) - 1), |
| 87 | + ) |
| 88 | + set_grid(axis) |
| 89 | + |
| 90 | + if row == 0: |
| 91 | + axis.set_title(dtype.replace("<class 'numpy.", "").replace("'>", "")) |
| 92 | + if col == 0: |
| 93 | + axis.set_ylabel(direction.capitalize()) |
| 94 | + if row < len(directions) - 1: |
| 95 | + axis.tick_params(axis="x", labelbottom=False) |
| 96 | + |
| 97 | + fig.supylabel("Throughput (# tensor products / s)", x=0.03, y=0.56) |
| 98 | + fig.supxlabel("Problem") |
| 99 | + |
| 100 | + handles, labels = axs[0][0].get_legend_handles_labels() |
| 101 | + unique = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]] |
| 102 | + if unique: |
| 103 | + axs[0][0].legend(*zip(*unique)) |
| 104 | + |
| 105 | + fig.tight_layout(rect=(0.03, 0.03, 1.0, 1.0)) |
| 106 | + fig.savefig(str(data_folder / "layout_throughput_comparison.pdf")) |
| 107 | + |
| 108 | + print("Layout speedups (ir_mul / mul_ir):") |
| 109 | + print("\t".join(["dtype", "direction", "min", "mean", "median", "max"])) |
| 110 | + for dtype in dtype_order: |
| 111 | + for direction in directions: |
| 112 | + ratios = [] |
| 113 | + for _, values in grouped.get(dtype, {}).get(direction, {}).items(): |
| 114 | + if values["mul_ir"] > 0: |
| 115 | + ratios.append(values["ir_mul"] / values["mul_ir"]) |
| 116 | + if ratios: |
| 117 | + stats = [np.min(ratios), np.mean(ratios), np.median(ratios), np.max(ratios)] |
| 118 | + stats_fmt = [f"{val:.3f}" for val in stats] |
| 119 | + print("\t".join([dtype, direction] + stats_fmt)) |
0 commit comments