@@ -1,5 +1,4 @@
 import numpy as np
-import tempfile
 import json
 import os
 import itertools
@@ -13,7 +12,6 @@
     FullyConnectedTPProblem,
     SingleInstruction,
 )
-from openequivariance.extlib import GPUTimer
 from openequivariance.implementations.utils import count_cg_non_zero
 
 os.environ["CUEQUIVARIANCE_OPS_USE_JIT"] = "1"
@@ -123,6 +121,12 @@ def iterator(cls) -> Iterator["O3_e3nn"]:
         self.tp_correctness.to("cuda")
         self.forward_correctness = lambda x, y, W: self.tp_correctness(W, x, y)
 
+        self.kernel_names = [
+            "TensorProductUniform1dKernel",
+            "channelwise_kernel_fwd",
+            "channelwise_kernel_bwd",
+        ]
+
     def forward_cpu(
         self,
         L1_in: np.ndarray,
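The `kernel_names` list added above is presumably what the base class matches against profiler events when `with_torch_overhead=False`, so that only device time spent in the cuEquivariance kernels is counted (the removed docstring below confirms the profiler-based device-time approach). A minimal sketch of that filtering, assuming substring matching over CUDA events; the helper name `device_time_ms` and the exact event attributes are assumptions, not part of this diff:

```python
import numpy as np
from torch.profiler import profile, ProfilerActivity

def device_time_ms(fn, kernel_names, num_iter):
    """Per-iteration device time (ms) summed over events matching kernel_names."""
    times = np.zeros(num_iter, dtype=np.float32)
    for i in range(num_iter):
        with profile(activities=[ProfilerActivity.CUDA]) as prof:
            fn()
        # key_averages() reports totals in microseconds; keep only events
        # whose name contains one of the target kernel-name substrings.
        matched_us = sum(
            evt.device_time_total
            for evt in prof.key_averages()
            if any(name in evt.key for name in kernel_names)
        )
        times[i] = matched_us / 1.0e3
    return times
```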
@@ -197,42 +201,18 @@ def benchmark_forward(
         L2_in: np.ndarray,
         L3_buffer: np.ndarray,
         weights: np.ndarray,
+        with_torch_overhead: bool = True,
     ) -> np.ndarray:
-        """
-        When we don't want to include torch overhead, we use the Pytorch profiler
-        to extract the device time that the kernel takes.
-        """
-        if self.torch_op:
-            return super().benchmark_forward(
-                num_warmup, num_iter, L1_in, L2_in, L3_buffer, weights
-            )
-        else:
-            from torch.profiler import profile, record_function, ProfilerActivity
-
-            time_millis = np.zeros(num_iter, dtype=np.float32)
-            torch_L1_in = torch.tensor(L1_in).to(device="cuda").detach()
-            torch_L2_in = torch.tensor(L2_in).to(device="cuda").detach()
-            torch_weights = torch.tensor(weights).to(device="cuda").detach()
-
-            timer = GPUTimer()
-
-            for i in range(num_warmup):
-                self.forward(torch_L1_in, torch_L2_in, torch_weights)
-
-            trace_file = tempfile.NamedTemporaryFile().name
-
-            for i in range(num_iter):
-                timer.clear_L2_cache()
-                with profile(
-                    activities=[ProfilerActivity.CUDA], record_shapes=True
-                ) as prof:
-                    with record_function("cue_forward"):
-                        self.forward(torch_L1_in, torch_L2_in, torch_weights)
-
-                prof.export_chrome_trace(trace_file)
-                time_millis[i] = self.analyze_trace(trace_file)
-
-            return time_millis
+        return super().benchmark_forward(
+            num_warmup,
+            num_iter,
+            L1_in,
+            L2_in,
+            L3_buffer,
+            weights,
+            with_torch_overhead,
+            kernel_names=self.kernel_names,
+        )
 
     def benchmark_backward(
         self,
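With the profiler plumbing lifted into the base class, a call site toggles overhead measurement with a single flag. A hypothetical comparison of end-to-end versus kernel-only time; `impl` and the input arrays `L1, L2, L3, W` stand for an instance of this wrapper and a problem setup constructed elsewhere, and are illustrative, not from this diff:

```python
import numpy as np

# Illustrative only: `impl` is an instance of this wrapper class, and
# L1, L2, L3, W are inputs shaped by the tensor-product problem.
num_warmup, num_iter = 10, 30

end_to_end = impl.benchmark_forward(
    num_warmup, num_iter, L1, L2, L3, W, with_torch_overhead=True
)
kernel_only = impl.benchmark_forward(
    num_warmup, num_iter, L1, L2, L3, W, with_torch_overhead=False
)
print(f"framework overhead ~ {np.mean(end_to_end) - np.mean(kernel_only):.3f} ms")
```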
@@ -242,60 +222,18 @@ def benchmark_backward(
         L2_in: np.ndarray,
         L3_buffer: np.ndarray,
         weights: np.ndarray,
-        L1_grad: np.ndarray,
-        L2_grad: np.ndarray,
-        weights_grad: np.ndarray,
+        with_torch_overhead: bool = True,
     ) -> np.ndarray:
-        if self.torch_op:
-            return super().benchmark_backward(
-                num_warmup,
-                num_iter,
-                L1_in,
-                L2_in,
-                L3_buffer,
-                weights,
-                L1_grad,
-                L2_grad,
-                weights_grad,
-            )
-        else:
-            from torch.profiler import profile, record_function, ProfilerActivity
-
-            time_millis = np.zeros(num_iter, dtype=np.float32)
-
-            torch_L1_in = torch.tensor(L1_in, requires_grad=True, device="cuda")
-            torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda")
-            torch_weights = torch.tensor(weights, requires_grad=True, device="cuda")
-            torch_out = self.forward(torch_L1_in, torch_L2_in, torch_weights)
-            torch_L3_grad_in = torch.tensor(L3_buffer, device="cuda")
-
-            timer = GPUTimer()
-
-            for i in range(num_warmup):
-                torch_out.backward(
-                    gradient=torch_L3_grad_in,
-                    retain_graph=True,
-                    inputs=[torch_L1_in, torch_L2_in, torch_weights],
-                )
-
-            trace_file = tempfile.NamedTemporaryFile().name
-
-            for i in range(num_iter):
-                timer.clear_L2_cache()
-                with profile(
-                    activities=[ProfilerActivity.CUDA], record_shapes=True
-                ) as prof:
-                    with record_function("cue_backward"):
-                        torch_out.backward(
-                            gradient=torch_L3_grad_in,
-                            retain_graph=True,
-                            inputs=[torch_L1_in, torch_L2_in, torch_weights],
-                        )
-
-                prof.export_chrome_trace(trace_file)
-                time_millis[i] = self.analyze_trace(trace_file)
-
-            return time_millis
+        return super().benchmark_backward(
+            num_warmup,
+            num_iter,
+            L1_in,
+            L2_in,
+            L3_buffer,
+            weights,
+            with_torch_overhead,
+            kernel_names=self.kernel_names,
+        )
 
     # Copied over from loop unroller to match arithmetic intensity on roofline plots
     def calculate_flops_forward(self, batch_size: int) -> dict:
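The comment above ties the flop counts to roofline plots, where attainable throughput is the minimum of peak compute and arithmetic intensity times peak bandwidth. A self-contained sketch of that bound; all device numbers are placeholders, and the byte count would come from a matching data-movement estimate not shown in this diff:

```python
def attainable_flops_per_s(flops, bytes_moved, peak_flops, peak_bw):
    """Roofline model: min(peak compute, arithmetic intensity * bandwidth)."""
    intensity = flops / bytes_moved  # flop per byte
    return min(peak_flops, intensity * peak_bw)

# Placeholder numbers, not measurements:
fwd_flops = 4.0e9   # e.g. a total from calculate_flops_forward(batch_size)
fwd_bytes = 1.0e9   # corresponding data-movement estimate
print(attainable_flops_per_s(fwd_flops, fwd_bytes,
                             peak_flops=6.0e13, peak_bw=2.0e12))
```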