    SMEMCapacityException,
)

-from openequivariance.core.dtype_enum import (
-    dtype_to_enum,
-    enum_to_torch_dtype,
-)
+from openequivariance.core.dtype_enum import dtype_to_enum
from openequivariance.templates.jinja_utils import get_jinja_environment
-import openequivariance.impl_torch.extlib as extlib
-from openequivariance.impl_torch.extlib import JITConvImpl, postprocess_kernel, DeviceProp
-
from openequivariance.core.utils import filter_and_analyze_problem
-from openequivariance.benchmark.logging_utils import getLogger
-
-logger = getLogger()
-

class LoopUnrollConv(ConvolutionBase):
    def __init__(
        self,
        config,
+        dp, postprocess_kernel,
        *,
        idx_dtype: type[np.generic] = np.int64,
        torch_op: bool = False,
@@ -39,7 +30,6 @@ def __init__(

        env = get_jinja_environment()
        template = env.get_template("loop_unroll_conv_atomic.cuh")
-        dp = DeviceProp(0)

        analysis = filter_and_analyze_problem(config)
        self.is_uvw = analysis["is_uvw"]
@@ -141,10 +131,10 @@ def generate_double_backward_schedule(warps_per_block):
        self.backward_workspace_offset = None
        self.double_backwardB_offset = None

-        workspace_size = 1
+        self.workspace_size = 1
        if deterministic:
            destination_index_bytes = 32  # Add extra to account for padding
-            workspace_size = max(
+            self.workspace_size = max(
                (
                    self.forward_schedule.L3.dim * np.dtype(config.irrep_dtype).itemsize
                    + destination_index_bytes
@@ -186,7 +176,19 @@ def generate_double_backward_schedule(warps_per_block):
        )
        self.double_backwardB_offset = (self.double_backwardB_offset + 7) // 8 * 8

-        self.allocate_workspace(workspace_size)
+        self.kernel_prop = {
+            "L1_dim": self.L1.dim,
+            "L2_dim": self.L2.dim,
+            "L3_dim": self.L3.dim,
+            "weight_numel": self.config.weight_numel,
+            "workspace_size": self.workspace_size,
+            "opt_level": 3,
+            "shared_weights": int(config.shared_weights),
+            "deterministic": int(self.deterministic),
+            "irrep_dtype": dtype_to_enum[self.config.irrep_dtype],
+            "weight_dtype": dtype_to_enum[self.config.weight_dtype],
+            "idx_dtype": dtype_to_enum[self.idx_dtype],
+        }

        self.jit_kernel = template.render(
            forward_schedule=self.forward_schedule,
@@ -199,255 +201,5 @@ def generate_double_backward_schedule(warps_per_block):
        )
        self.jit_kernel = postprocess_kernel(self.jit_kernel)

-        if self.torch_op and extlib.TORCH_COMPILE:
-            global torch
-            import torch
-
-            internal_cls = torch.classes.libtorch_tp_jit.TorchJITConv
-        else:
-            internal_cls = JITConvImpl
-
-        logger.info("Starting kernel compiler...")
-        self.internal = internal_cls(
-            self.jit_kernel,
-            vars(self.forward_schedule.launch_config),
-            vars(self.backward_schedule.launch_config),
-            vars(self.double_backward_schedule.launch_config),
-            {
-                "L1_dim": self.L1.dim,
-                "L2_dim": self.L2.dim,
-                "L3_dim": self.L3.dim,
-                "weight_numel": self.config.weight_numel,
-                "workspace_size": self.workspace_size,
-                "opt_level": 3,
-                "shared_weights": int(config.shared_weights),
-                "deterministic": int(self.deterministic),
-                "irrep_dtype": dtype_to_enum[self.config.irrep_dtype],
-                "weight_dtype": dtype_to_enum[self.config.weight_dtype],
-                "idx_dtype": dtype_to_enum[self.idx_dtype],
-            },
-        )
-        logger.info("Kernel compiled!")
-
        # with open("scratch.txt", "w") as f:
        #     f.write(self.jit_kernel)
-
-    def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
-        return self.forward_schedule.reorder_weights_from_e3nn(weights, has_batch_dim)
-
-    def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
-        return self.forward_schedule.reorder_weights_to_e3nn(weights, has_batch_dim)
-
-    @staticmethod
-    def name():
-        return "LoopUnrollConv"
-
-    @classmethod
-    def register_torch_fakes(cls):
-        global torch
-        import torch
-
-        @torch._library.register_fake_class("libtorch_tp_jit::TorchJITConv")
-        class TorchJITConv:
-            def __init__(
-                self,
-                kernel_plaintext: str,
-                fwd_config: dict[str, int],
-                bwd_config: dict[str, int],
-                dbl_bwd_config: dict[str, int],
-                kernel_dims: dict[str, int],
-            ) -> None:
-                (
-                    self.kernel_plaintext,
-                    self.fwd_config,
-                    self.bwd_config,
-                    self.dbl_bwd_config,
-                    self.kernel_dims,
-                ) = (
-                    kernel_plaintext,
-                    fwd_config,
-                    bwd_config,
-                    dbl_bwd_config,
-                    kernel_dims,
-                )
-
-            @classmethod
-            def __obj_unflatten__(cls, flattened_product):
-                return cls(**dict(flattened_product))
-
-            def __len__(self):
-                return 0
-
-            def __setstate__(self, state):
-                (
-                    self.kernel_plaintext,
-                    self.fwd_config,
-                    self.bwd_config,
-                    self.dbl_bwd_config,
-                    self.kernel_dims,
-                ) = state
-
-            def exec_conv_rawptrs(*args, **kwargs):
-                pass
-
-            def backward_rawptrs(*args, **kwargs):
-                pass
-
-            def double_backward_rawptrs(*args, **kwargs):
-                pass
-
-            def L3_dim_getter(self):
-                return self.kernel_dims["L3_dim"]
-
-            def irrep_dtype_getter(self):
-                return self.kernel_dims["irrep_dtype"]
-
-        @torch.library.register_fake("libtorch_tp_jit::jit_conv_forward")
-        def fake_forward(
-            jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm
-        ):
-            L3_dim, irrep_dtype = None, None
-            if hasattr(jit, "wrapped_obj"):
-                L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"]
-                irrep_dtype = jit.wrapped_obj.kernel_dims["irrep_dtype"]
-            else:
-                L3_dim = jit.L3_dim
-                irrep_dtype = jit.irrep_dtype
-
-            return torch.empty(
-                L1_in.shape[0],
-                L3_dim,
-                device="cuda",
-                dtype=enum_to_torch_dtype[irrep_dtype],
-            )
-
-        @torch.library.register_fake("libtorch_tp_jit::jit_conv_backward")
-        def fake_backward(
-            jit, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, sender_perm
-        ):
-            return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W)
-
-        @torch.library.register_fake("libtorch_tp_jit::jit_conv_double_backward")
-        def fake_double_backward(
-            jit,
-            L1_in,
-            L2_in,
-            W,
-            L3_grad,
-            L1_dgrad,
-            L2_dgrad,
-            w_dgrad,
-            rows,
-            cols,
-            workspace_buffer,
-            transpose_perm=None,
-        ):
-            return [
-                L1_in.new_empty(*L1_in.shape),
-                L2_in.new_empty(*L2_in.shape),
-                W.new_empty(*W.shape),
-                L3_grad.new_empty(*L3_grad.shape),
-            ]
-
-    @classmethod
-    def register_autograd(cls):
-        backward_op = torch.ops.libtorch_tp_jit.jit_conv_backward
-        double_backward_op = torch.ops.libtorch_tp_jit.jit_conv_double_backward
-
-        def setup_context(ctx, inputs, output):
-            (
-                ctx.jit,
-                ctx.L1_in,
-                ctx.L2_in,
-                ctx.W,
-                ctx.rows,
-                ctx.cols,
-                ctx.workspace_buffer,
-                ctx.sender_perm,
-            ) = inputs
-
-        def backward(ctx, grad_output):
-            L1_grad, L2_grad, W_grad = backward_op(
-                ctx.jit,
-                ctx.L1_in,
-                ctx.L2_in,
-                ctx.W,
-                grad_output,
-                ctx.rows,
-                ctx.cols,
-                ctx.workspace_buffer,
-                ctx.sender_perm,
-            )
-            return None, L1_grad, L2_grad, W_grad, None, None, None, None
-
-        torch.library.register_autograd(
-            "libtorch_tp_jit::jit_conv_forward", backward, setup_context=setup_context
-        )
-
-        def setup_context_double_backward(ctx, inputs, output):
-            (
-                ctx.jit,
-                ctx.L1_in,
-                ctx.L2_in,
-                ctx.W,
-                ctx.grad_output,
-                ctx.rows,
-                ctx.cols,
-                ctx.workspace_buffer,
-                ctx.sender_perm,
-            ) = inputs
-            ctx.inputs = inputs
-
-        def double_backward(ctx, E, F, G):
-            result = double_backward_op(
-                ctx.jit,
-                ctx.L1_in,
-                ctx.L2_in,
-                ctx.W,
-                ctx.grad_output,
-                E,
-                F,
-                G,
-                ctx.rows,
-                ctx.cols,
-                ctx.workspace_buffer,
-                ctx.sender_perm,
-            )
-            return (
-                None,
-                result[0],
-                result[1],
-                result[2],
-                result[3],
-                None,
-                None,
-                None,
-                None,
-            )
-
-        torch.library.register_autograd(
-            "libtorch_tp_jit::jit_conv_backward",
-            double_backward,
-            setup_context=setup_context_double_backward,
-        )
-
-    @classmethod
-    def register_autocast(cls):
-        global torch
-        import torch
-
-        torch.library.register_autocast(
-            "libtorch_tp_jit::jit_conv_forward", "cuda", torch.float32
-        )
-        torch.library.register_autocast(
-            "libtorch_tp_jit::jit_conv_backward", "cuda", torch.float32
-        )
-        torch.library.register_autocast(
-            "libtorch_tp_jit::jit_conv_double_backward", "cuda", torch.float32
-        )
-
-
-if extlib.TORCH_COMPILE:
-    LoopUnrollConv.register_torch_fakes()
-    LoopUnrollConv.register_autograd()
-    LoopUnrollConv.register_autocast()
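
Usage sketch: with this diff, LoopUnrollConv no longer queries DeviceProp or selects a kernel postprocessor itself; the caller injects dp and postprocess_kernel, and the kernel metadata formerly passed straight to the JIT implementation is kept on self.kernel_prop for the caller to consume. A minimal sketch of the new calling convention, assuming the caller still sources DeviceProp and postprocess_kernel from extlib as the removed imports did, with config construction elided:

    from openequivariance.impl_torch.extlib import DeviceProp, postprocess_kernel

    dp = DeviceProp(0)  # properties of CUDA device 0, as the removed inline call did
    conv = LoopUnrollConv(config, dp, postprocess_kernel, torch_op=True)
    # conv.jit_kernel holds the rendered, postprocessed kernel source;
    # conv.kernel_prop holds the dims/dtypes a JIT backend needs to launch it.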