Skip to content

Commit 54efff0

Browse files
Compacted benchmarking code. (#146)
* Making progress on updated benchmarking. * Changed the first benchmarking function. * More progress on benchmarking. * Updated benchmarking code for double backward pass. * Removed unneeded arguments. * Cut down some more benchmarking code. * Backward benchmarking code updated. * Removed unnecessary benchmarking code from CUETensorProduct. * More small changes. * Linted.
1 parent 22c8bb9 commit 54efff0

10 files changed

Lines changed: 243 additions & 548 deletions

File tree

openequivariance/benchmark/TestBenchmarkSuite.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def run(
184184
num_warmup=self.num_warmup,
185185
num_iter=self.num_iter,
186186
prng_seed=self.prng_seed,
187-
torch_op=self.torch_op,
187+
with_torch_overhead=self.torch_op,
188188
)
189189

190190
if test.direction == "backward":
@@ -207,7 +207,7 @@ def run(
207207
num_warmup=self.num_warmup,
208208
num_iter=self.num_iter,
209209
prng_seed=self.prng_seed,
210-
torch_op=self.torch_op,
210+
with_torch_overhead=self.torch_op,
211211
)
212212

213213
if test.direction == "double_backward":
@@ -230,7 +230,7 @@ def run(
230230
num_warmup=self.num_warmup,
231231
num_iter=self.num_iter,
232232
prng_seed=self.prng_seed,
233-
torch_op=self.torch_op,
233+
with_torch_overhead=self.torch_op,
234234
)
235235

236236
fname = pathlib.Path(f"{output_folder}/{test_ID}_{impl.name()}.json")

openequivariance/benchmark/benchmark_utils.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def benchmark_forward(
7070
num_warmup: int,
7171
num_iter: int,
7272
prng_seed: int,
73-
torch_op: bool,
73+
with_torch_overhead: bool,
7474
) -> dict:
7575
"""
7676
This function sets up the necessary materials and calls the internal benchmarker
@@ -89,7 +89,7 @@ def benchmark_forward(
8989
weights = weights[np.newaxis, :]
9090

9191
logger.info("Initialized input / output data.")
92-
tp = implementation(problem, torch_op=torch_op)
92+
tp = implementation(problem)
9393

9494
# BENCHMARK
9595
try:
@@ -100,6 +100,7 @@ def benchmark_forward(
100100
L2_in=L2_in,
101101
weights=weights,
102102
L3_buffer=L3_buffer,
103+
with_torch_overhead=with_torch_overhead,
103104
)
104105
except NotImplementedError:
105106
logger.warning(
@@ -145,7 +146,7 @@ def benchmark_backward(
145146
num_warmup: int,
146147
num_iter: int,
147148
prng_seed: int,
148-
torch_op: bool,
149+
with_torch_overhead: bool,
149150
) -> dict:
150151
result = {
151152
"tp_direction": "backward",
@@ -161,7 +162,7 @@ def benchmark_backward(
161162
weights = weights[np.newaxis, :]
162163

163164
logger.info("Initialized input / output data.")
164-
tp = implementation(problem, torch_op=torch_op)
165+
tp = implementation(problem)
165166

166167
try:
167168
time_millis = tp.benchmark_backward(
@@ -171,9 +172,7 @@ def benchmark_backward(
171172
L2_in=in2,
172173
L3_buffer=out_grad,
173174
weights=weights,
174-
L1_grad=in1_grad,
175-
L2_grad=in2_grad,
176-
weights_grad=weights_grad,
175+
with_torch_overhead=with_torch_overhead,
177176
)
178177
except NotImplementedError:
179178
logger.warning(
@@ -223,7 +222,7 @@ def benchmark_double_backward(
223222
num_warmup: int,
224223
num_iter: int,
225224
prng_seed: int,
226-
torch_op: bool,
225+
with_torch_overhead: bool,
227226
) -> dict:
228227
result = {
229228
"tp_direction": "double_backward",
@@ -240,20 +239,17 @@ def benchmark_double_backward(
240239
weights = weights[np.newaxis, :]
241240

242241
logger.info("Initialized input / output data.")
243-
tp = implementation(problem, torch_op=torch_op)
242+
tp = implementation(problem)
244243

245244
try:
246245
time_millis = tp.benchmark_double_backward(
247246
num_warmup=num_warmup,
248247
num_iter=num_iter,
249248
L1_in=in1,
250249
L2_in=in2,
251-
L3_buffer=out_grad,
252250
weights=weights,
253-
L1_grad=in1_grad,
254-
L2_grad=in2_grad,
255251
weights_grad=weights_grad,
256-
L3_double_grad=out_double_grad,
252+
with_torch_overhead=with_torch_overhead,
257253
)
258254
except NotImplementedError:
259255
logger.warning(

openequivariance/benchmark/perf_metrics_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
count_cg_non_zero,
55
sparse_outer_product_work,
66
)
7-
from openequivariance.implementations.TensorProductBase import TensorProductBase
8-
from openequivariance.implementations.e3nn_lite import TPProblem
7+
8+
from openequivariance.implementations.e3nn_lite import TPProblem, wigner_3j
99
from openequivariance.benchmark.logging_utils import getLogger
1010
import numpy as np
1111

@@ -70,7 +70,7 @@ def calculate_minimum_flops_forward(tpp: TPProblem, batch_size: int) -> dict:
7070
)
7171

7272
flops_count["outer_products"] += sparse_outer_product_work(
73-
TensorProductBase.load_cg_tensor(l1, l2, l3)
73+
wigner_3j(l1, l2, l3)
7474
)
7575
flops_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * (
7676
ins.path_shape[0] * ins.path_shape[1]

openequivariance/implementations/CUETensorProduct.py

Lines changed: 28 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
import tempfile
32
import json
43
import os
54
import itertools
@@ -13,7 +12,6 @@
1312
FullyConnectedTPProblem,
1413
SingleInstruction,
1514
)
16-
from openequivariance.extlib import GPUTimer
1715
from openequivariance.implementations.utils import count_cg_non_zero
1816

1917
os.environ["CUEQUIVARIANCE_OPS_USE_JIT"] = "1"
@@ -123,6 +121,12 @@ def iterator(cls) -> Iterator["O3_e3nn"]:
123121
self.tp_correctness.to("cuda")
124122
self.forward_correctness = lambda x, y, W: self.tp_correctness(W, x, y)
125123

124+
self.kernel_names = [
125+
"TensorProductUniform1dKernel",
126+
"channelwise_kernel_fwd",
127+
"channelwise_kernel_bwd",
128+
]
129+
126130
def forward_cpu(
127131
self,
128132
L1_in: np.ndarray,
@@ -197,42 +201,18 @@ def benchmark_forward(
197201
L2_in: np.ndarray,
198202
L3_buffer: np.ndarray,
199203
weights: np.ndarray,
204+
with_torch_overhead: bool = True,
200205
) -> np.ndarray:
201-
"""
202-
When we don't want to include torch overhead, we use the Pytorch profiler
203-
to extract the device time that the kernel takes.
204-
"""
205-
if self.torch_op:
206-
return super().benchmark_forward(
207-
num_warmup, num_iter, L1_in, L2_in, L3_buffer, weights
208-
)
209-
else:
210-
from torch.profiler import profile, record_function, ProfilerActivity
211-
212-
time_millis = np.zeros(num_iter, dtype=np.float32)
213-
torch_L1_in = torch.tensor(L1_in).to(device="cuda").detach()
214-
torch_L2_in = torch.tensor(L2_in).to(device="cuda").detach()
215-
torch_weights = torch.tensor(weights).to(device="cuda").detach()
216-
217-
timer = GPUTimer()
218-
219-
for i in range(num_warmup):
220-
self.forward(torch_L1_in, torch_L2_in, torch_weights)
221-
222-
trace_file = tempfile.NamedTemporaryFile().name
223-
224-
for i in range(num_iter):
225-
timer.clear_L2_cache()
226-
with profile(
227-
activities=[ProfilerActivity.CUDA], record_shapes=True
228-
) as prof:
229-
with record_function("cue_forward"):
230-
self.forward(torch_L1_in, torch_L2_in, torch_weights)
231-
232-
prof.export_chrome_trace(trace_file)
233-
time_millis[i] = self.analyze_trace(trace_file)
234-
235-
return time_millis
206+
return super().benchmark_forward(
207+
num_warmup,
208+
num_iter,
209+
L1_in,
210+
L2_in,
211+
L3_buffer,
212+
weights,
213+
with_torch_overhead,
214+
kernel_names=["TensorProductUniform1DKernel", "channelwise_kernel_"],
215+
)
236216

237217
def benchmark_backward(
238218
self,
@@ -242,60 +222,18 @@ def benchmark_backward(
242222
L2_in: np.ndarray,
243223
L3_buffer: np.ndarray,
244224
weights: np.ndarray,
245-
L1_grad: np.ndarray,
246-
L2_grad: np.ndarray,
247-
weights_grad: np.ndarray,
225+
with_torch_overhead: bool = True,
248226
) -> np.ndarray:
249-
if self.torch_op:
250-
return super().benchmark_backward(
251-
num_warmup,
252-
num_iter,
253-
L1_in,
254-
L2_in,
255-
L3_buffer,
256-
weights,
257-
L1_grad,
258-
L2_grad,
259-
weights_grad,
260-
)
261-
else:
262-
from torch.profiler import profile, record_function, ProfilerActivity
263-
264-
time_millis = np.zeros(num_iter, dtype=np.float32)
265-
266-
torch_L1_in = torch.tensor(L1_in, requires_grad=True, device="cuda")
267-
torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda")
268-
torch_weights = torch.tensor(weights, requires_grad=True, device="cuda")
269-
torch_out = self.forward(torch_L1_in, torch_L2_in, torch_weights)
270-
torch_L3_grad_in = torch.tensor(L3_buffer, device="cuda")
271-
272-
timer = GPUTimer()
273-
274-
for i in range(num_warmup):
275-
torch_out.backward(
276-
gradient=torch_L3_grad_in,
277-
retain_graph=True,
278-
inputs=[torch_L1_in, torch_L2_in, torch_weights],
279-
)
280-
281-
trace_file = tempfile.NamedTemporaryFile().name
282-
283-
for i in range(num_iter):
284-
timer.clear_L2_cache()
285-
with profile(
286-
activities=[ProfilerActivity.CUDA], record_shapes=True
287-
) as prof:
288-
with record_function("cue_backward"):
289-
torch_out.backward(
290-
gradient=torch_L3_grad_in,
291-
retain_graph=True,
292-
inputs=[torch_L1_in, torch_L2_in, torch_weights],
293-
)
294-
295-
prof.export_chrome_trace(trace_file)
296-
time_millis[i] = self.analyze_trace(trace_file)
297-
298-
return time_millis
227+
return super().benchmark_backward(
228+
num_warmup,
229+
num_iter,
230+
L1_in,
231+
L2_in,
232+
L3_buffer,
233+
weights,
234+
with_torch_overhead,
235+
kernel_names=self.kernel_names,
236+
)
299237

300238
# Copied over from loop unroller to match arithmetic intensity on roofline plots
301239
def calculate_flops_forward(self, batch_size: int) -> dict:

openequivariance/implementations/ComputationSchedule.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import numpy as np
2-
from openequivariance.implementations.e3nn_lite import Irreps, TPProblem
2+
from openequivariance.implementations.e3nn_lite import Irreps, TPProblem, wigner_3j
33
from itertools import accumulate
44
from openequivariance.benchmark.logging_utils import getLogger
5-
from openequivariance.implementations.TensorProductBase import TensorProductBase
65

76
logger = getLogger()
87

@@ -60,7 +59,7 @@ class CGTensor:
6059
def __init__(self, l1, l2, l3, normalization_factor, dtype):
6160
suffix_map = {np.float32: "f", np.float64: "L"}
6261

63-
tensor = TensorProductBase.load_cg_tensor(l1, l2, l3)
62+
tensor = wigner_3j(l1, l2, l3)
6463
coord1, coord2, coord3 = [
6564
arr.astype(np.int32).copy() for arr in np.nonzero(tensor)
6665
]

0 commit comments

Comments
 (0)