1+ import numpy as np
2+ import matplotlib .pyplot as plt
3+ import os , json , pathlib , sys
4+ from openequivariance .benchmark .plotting import *
5+
6+ def plot_double_backward (data_folder ):
7+ data_folder = pathlib .Path (data_folder )
8+ benchmarks , metadata = load_benchmarks (data_folder )
9+
10+ configs = metadata ["config_labels" ]
11+ implementations = ["E3NNTensorProduct" , "CUETensorProduct" , "LoopUnrollTP" ]
12+
13+ def calculate_tp_per_sec (exp ):
14+ return exp ["benchmark results" ]["batch_size" ] / (np .mean (exp ["benchmark results" ]["time_millis" ]) * 0.001 )
15+
16+ dataf32 = {"double_backward" : {}}
17+ for i , desc in enumerate (configs ):
18+ for direction in ["double_backward" ]:
19+ dataf32 [direction ][desc ] = {}
20+ for impl in implementations :
21+ f32_benches = [b for b in benchmarks if b ["benchmark results" ]["rep_dtype" ] == "<class 'numpy.float32'>" ]
22+ exp = filter (f32_benches , {"config_label" : desc ,
23+ "direction" : direction ,
24+ "implementation_name" : impl
25+ }, match_one = True )
26+ dataf32 [direction ][desc ][labelmap [impl ]] = calculate_tp_per_sec (exp )
27+
28+ dataf64 = {"double_backward" : {}}
29+ for i , desc in enumerate (configs ):
30+ for direction in ["double_backward" ]:
31+ dataf64 [direction ][desc ] = {}
32+ for impl in implementations :
33+ f64_benches = [b for b in benchmarks if 'float64' in b ["benchmark results" ]["rep_dtype" ]]
34+
35+ exp = filter (f64_benches , {"config_label" : desc ,
36+ "direction" : direction ,
37+ "implementation_name" : impl
38+ }, match_one = True )
39+
40+ if exp is None :
41+ print (desc )
42+ print (direction )
43+ print (impl )
44+
45+ dataf64 [direction ][desc ][labelmap [impl ]] = calculate_tp_per_sec (exp )
46+
47+ fig = plt .figure (figsize = (7 , 3 ))
48+ gs = fig .add_gridspec (1 , 2 , hspace = 0 , wspace = 0.1 )
49+ axs = gs .subplots (sharex = 'col' , sharey = 'row' )
50+
51+ grouped_barchart (dataf32 ["double_backward" ], axs [0 ], bar_height_fontsize = 0 , colormap = colormap , group_spacing = 6.0 )
52+ grouped_barchart (dataf64 ["double_backward" ], axs [1 ], bar_height_fontsize = 0 , colormap = colormap , group_spacing = 6.0 )
53+
54+ for i in range (2 ):
55+ set_grid (axs [i ])
56+ set_grid (axs [i ])
57+
58+ axs [0 ].set_xlabel ("float32" )
59+ axs [1 ].set_xlabel ("float64" )
60+
61+ handles , labels = axs [0 ].get_legend_handles_labels ()
62+ unique = [(h , l ) for i , (h , l ) in enumerate (zip (handles , labels )) if l not in labels [:i ]]
63+ axs [0 ].legend (* zip (* unique ))
64+
65+ for ax in fig .get_axes ():
66+ ax .label_outer ()
67+
68+ fig .supylabel ("2nd Deriv. Throughput\n (# tensor products / s)" , y = 0.5 )
69+
70+ speedup_table = []
71+ for direction in ['double_backward' ]:
72+ for impl in ['e3nn' , 'cuE' ]:
73+ for dtype_label , dtype_set in [('f32' , dataf32 ), ('f64' , dataf64 )]:
74+ speedups = [measurement ['ours' ] / measurement [impl ] for _ , measurement in dtype_set [direction ].items () if impl in measurement ]
75+ stats = np .min (speedups ), np .mean (speedups ), np .median (speedups ), np .max (speedups )
76+ stats = [f"{ stat :.2f} " for stat in stats ]
77+
78+ dir_print = direction
79+ result = [dir_print , impl , dtype_label ] + stats
80+ speedup_table .append (result )
81+
82+ print ('\t \t ' .join (['Direction' , 'Base' , 'dtype' , 'min' , 'mean' , 'med' , 'max' ]))
83+ for row in speedup_table :
84+ print ('\t \t ' .join (row ))
85+
86+ fig .show ()
87+ fig .tight_layout ()
88+ fig .savefig (str (data_folder / "double_backward_throughput.pdf" ), bbox_inches = 'tight' )
0 commit comments