
Commit 38017b0

Reorg of LoopUnrollConv.py
1 parent 7b1ce90 commit 38017b0

3 files changed: 271 additions & 281 deletions


openequivariance/openequivariance/core/ConvolutionBase.py

Lines changed: 0 additions & 12 deletions
@@ -1,6 +1,5 @@
 import copy
 import numpy as np
-from openequivariance.impl_torch.extlib import DeviceBuffer
 from openequivariance.benchmark.random_buffer_utils import (
     get_random_buffers_forward_conv,
     get_random_buffers_backward_conv,
@@ -130,17 +129,6 @@ def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
         """
         return weights
 
-    def allocate_workspace(self, size_bytes):
-        self.workspace_size = size_bytes
-        if self.torch_op:
-            self.workspace_buffer = torch.zeros(
-                size_bytes, dtype=torch.uint8, device="cuda"
-            )
-        else:
-            self.workspace_buffer = DeviceBuffer(size_bytes)
-        self.workspace_ptr = self.workspace_buffer.data_ptr()
-        logger.info(f"Convolution requires {size_bytes // 1000000}MB of workspace.")
-
     @staticmethod
     def name():
         raise NotImplementedError()
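
Note on the deletion above: `allocate_workspace` was the one place where the backend-neutral `ConvolutionBase` touched `torch` and `impl_torch.extlib.DeviceBuffer`. With it gone (and `LoopUnrollConv` now recording `self.workspace_size` instead, see the next file), allocation can live with the backend. A minimal sketch of the torch-side equivalent; the helper name and call site here are illustrative, not part of this commit:

    import torch

    def allocate_workspace_torch(conv):
        # Hypothetical helper mirroring the deleted ConvolutionBase.allocate_workspace:
        # one uint8 CUDA buffer of conv.workspace_size bytes, plus its raw pointer.
        conv.workspace_buffer = torch.zeros(
            conv.workspace_size, dtype=torch.uint8, device="cuda"
        )
        conv.workspace_ptr = conv.workspace_buffer.data_ptr()
        return conv.workspace_buffer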

openequivariance/openequivariance/core/LoopUnrollConv.py

Lines changed: 17 additions & 265 deletions
@@ -6,24 +6,15 @@
     SMEMCapacityException,
 )
 
-from openequivariance.core.dtype_enum import (
-    dtype_to_enum,
-    enum_to_torch_dtype,
-)
+from openequivariance.core.dtype_enum import dtype_to_enum
 from openequivariance.templates.jinja_utils import get_jinja_environment
-import openequivariance.impl_torch.extlib as extlib
-from openequivariance.impl_torch.extlib import JITConvImpl, postprocess_kernel, DeviceProp
-
 from openequivariance.core.utils import filter_and_analyze_problem
-from openequivariance.benchmark.logging_utils import getLogger
-
-logger = getLogger()
-
 
 class LoopUnrollConv(ConvolutionBase):
     def __init__(
         self,
         config,
+        dp, postprocess_kernel,
         *,
         idx_dtype: type[np.generic] = np.int64,
         torch_op: bool = False,
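
The two new constructor parameters replace imports that the deletions above remove: `postprocess_kernel` previously came from `impl_torch.extlib`, and `dp` was built as `DeviceProp(0)` inside `__init__` (removed in the next hunk). Callers now inject both, so the core class no longer imports from the torch implementation layer. A hedged sketch of a call site under that assumption (the actual call site lives in a file not shown here):

    # Hypothetical torch-side construction; `config` is assumed in scope.
    from openequivariance.impl_torch.extlib import DeviceProp, postprocess_kernel

    conv = LoopUnrollConv(
        config,
        DeviceProp(0),       # previously constructed inside __init__
        postprocess_kernel,  # previously imported directly by LoopUnrollConv
        torch_op=True,
    )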
@@ -39,7 +30,6 @@ def __init__(
 
         env = get_jinja_environment()
         template = env.get_template("loop_unroll_conv_atomic.cuh")
-        dp = DeviceProp(0)
 
         analysis = filter_and_analyze_problem(config)
         self.is_uvw = analysis["is_uvw"]
@@ -141,10 +131,10 @@ def generate_double_backward_schedule(warps_per_block):
         self.backward_workspace_offset = None
         self.double_backwardB_offset = None
 
-        workspace_size = 1
+        self.workspace_size = 1
         if deterministic:
             destination_index_bytes = 32  # Add extra to account for padding
-            workspace_size = max(
+            self.workspace_size = max(
                 (
                     self.forward_schedule.L3.dim * np.dtype(config.irrep_dtype).itemsize
                     + destination_index_bytes
@@ -186,7 +176,19 @@ def generate_double_backward_schedule(warps_per_block):
             )
             self.double_backwardB_offset = (self.double_backwardB_offset + 7) // 8 * 8
 
-        self.allocate_workspace(workspace_size)
+        self.kernel_prop = {
+            "L1_dim": self.L1.dim,
+            "L2_dim": self.L2.dim,
+            "L3_dim": self.L3.dim,
+            "weight_numel": self.config.weight_numel,
+            "workspace_size": self.workspace_size,
+            "opt_level": 3,
+            "shared_weights": int(config.shared_weights),
+            "deterministic": int(self.deterministic),
+            "irrep_dtype": dtype_to_enum[self.config.irrep_dtype],
+            "weight_dtype": dtype_to_enum[self.config.weight_dtype],
+            "idx_dtype": dtype_to_enum[self.idx_dtype],
+        }
 
         self.jit_kernel = template.render(
             forward_schedule=self.forward_schedule,
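
The added `self.kernel_prop` dict is the same literal that the deleted code in the next hunk passed as the fifth argument to the internal JIT class, so a backend can now perform that construction itself. A rough sketch, assuming `JITConvImpl` keeps the signature visible in the deleted block:

    # Hypothetical backend-side compilation step (mirrors the deleted code below):
    internal = JITConvImpl(
        conv.jit_kernel,
        vars(conv.forward_schedule.launch_config),
        vars(conv.backward_schedule.launch_config),
        vars(conv.double_backward_schedule.launch_config),
        conv.kernel_prop,  # replaces the inline dict of dims, dtypes, and flags
    )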
@@ -199,255 +201,5 @@ def generate_double_backward_schedule(warps_per_block):
         )
         self.jit_kernel = postprocess_kernel(self.jit_kernel)
 
-        if self.torch_op and extlib.TORCH_COMPILE:
-            global torch
-            import torch
-
-            internal_cls = torch.classes.libtorch_tp_jit.TorchJITConv
-        else:
-            internal_cls = JITConvImpl
-
-        logger.info("Starting kernel compiler...")
-        self.internal = internal_cls(
-            self.jit_kernel,
-            vars(self.forward_schedule.launch_config),
-            vars(self.backward_schedule.launch_config),
-            vars(self.double_backward_schedule.launch_config),
-            {
-                "L1_dim": self.L1.dim,
-                "L2_dim": self.L2.dim,
-                "L3_dim": self.L3.dim,
-                "weight_numel": self.config.weight_numel,
-                "workspace_size": self.workspace_size,
-                "opt_level": 3,
-                "shared_weights": int(config.shared_weights),
-                "deterministic": int(self.deterministic),
-                "irrep_dtype": dtype_to_enum[self.config.irrep_dtype],
-                "weight_dtype": dtype_to_enum[self.config.weight_dtype],
-                "idx_dtype": dtype_to_enum[self.idx_dtype],
-            },
-        )
-        logger.info("Kernel compiled!")
-
         # with open("scratch.txt", "w") as f:
         #     f.write(self.jit_kernel)
-
-    def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
-        return self.forward_schedule.reorder_weights_from_e3nn(weights, has_batch_dim)
-
-    def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
-        return self.forward_schedule.reorder_weights_to_e3nn(weights, has_batch_dim)
-
-    @staticmethod
-    def name():
-        return "LoopUnrollConv"
-
-    @classmethod
-    def register_torch_fakes(cls):
-        global torch
-        import torch
-
-        @torch._library.register_fake_class("libtorch_tp_jit::TorchJITConv")
-        class TorchJITConv:
-            def __init__(
-                self,
-                kernel_plaintext: str,
-                fwd_config: dict[str, int],
-                bwd_config: dict[str, int],
-                dbl_bwd_config: dict[str, int],
-                kernel_dims: dict[str, int],
-            ) -> None:
-                (
-                    self.kernel_plaintext,
-                    self.fwd_config,
-                    self.bwd_config,
-                    self.dbl_bwd_config,
-                    self.kernel_dims,
-                ) = (
-                    kernel_plaintext,
-                    fwd_config,
-                    bwd_config,
-                    dbl_bwd_config,
-                    kernel_dims,
-                )
-
-            @classmethod
-            def __obj_unflatten__(cls, flattened_product):
-                return cls(**dict(flattened_product))
-
-            def __len__(self):
-                return 0
-
-            def __setstate__(self, state):
-                (
-                    self.kernel_plaintext,
-                    self.fwd_config,
-                    self.bwd_config,
-                    self.dbl_bwd_config,
-                    self.kernel_dims,
-                ) = state
-
-            def exec_conv_rawptrs(*args, **kwargs):
-                pass
-
-            def backward_rawptrs(*args, **kwargs):
-                pass
-
-            def double_backward_rawptrs(*args, **kwargs):
-                pass
-
-            def L3_dim_getter(self):
-                return self.kernel_dims["L3_dim"]
-
-            def irrep_dtype_getter(self):
-                return self.kernel_dims["irrep_dtype"]
-
-        @torch.library.register_fake("libtorch_tp_jit::jit_conv_forward")
-        def fake_forward(
-            jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm
-        ):
-            L3_dim, irrep_dtype = None, None
-            if hasattr(jit, "wrapped_obj"):
-                L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"]
-                irrep_dtype = jit.wrapped_obj.kernel_dims["irrep_dtype"]
-            else:
-                L3_dim = jit.L3_dim
-                irrep_dtype = jit.irrep_dtype
-
-            return torch.empty(
-                L1_in.shape[0],
-                L3_dim,
-                device="cuda",
-                dtype=enum_to_torch_dtype[irrep_dtype],
-            )
-
-        @torch.library.register_fake("libtorch_tp_jit::jit_conv_backward")
-        def fake_backward(
-            jit, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, sender_perm
-        ):
-            return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W)
-
-        @torch.library.register_fake("libtorch_tp_jit::jit_conv_double_backward")
-        def fake_double_backward(
-            jit,
-            L1_in,
-            L2_in,
-            W,
-            L3_grad,
-            L1_dgrad,
-            L2_dgrad,
-            w_dgrad,
-            rows,
-            cols,
-            workspace_buffer,
-            transpose_perm=None,
-        ):
-            return [
-                L1_in.new_empty(*L1_in.shape),
-                L2_in.new_empty(*L2_in.shape),
-                W.new_empty(*W.shape),
-                L3_grad.new_empty(*L3_grad.shape),
-            ]
-
-    @classmethod
-    def register_autograd(cls):
-        backward_op = torch.ops.libtorch_tp_jit.jit_conv_backward
-        double_backward_op = torch.ops.libtorch_tp_jit.jit_conv_double_backward
-
-        def setup_context(ctx, inputs, output):
-            (
-                ctx.jit,
-                ctx.L1_in,
-                ctx.L2_in,
-                ctx.W,
-                ctx.rows,
-                ctx.cols,
-                ctx.workspace_buffer,
-                ctx.sender_perm,
-            ) = inputs
-
-        def backward(ctx, grad_output):
-            L1_grad, L2_grad, W_grad = backward_op(
-                ctx.jit,
-                ctx.L1_in,
-                ctx.L2_in,
-                ctx.W,
-                grad_output,
-                ctx.rows,
-                ctx.cols,
-                ctx.workspace_buffer,
-                ctx.sender_perm,
-            )
-            return None, L1_grad, L2_grad, W_grad, None, None, None, None
-
-        torch.library.register_autograd(
-            "libtorch_tp_jit::jit_conv_forward", backward, setup_context=setup_context
-        )
-
-        def setup_context_double_backward(ctx, inputs, output):
-            (
-                ctx.jit,
-                ctx.L1_in,
-                ctx.L2_in,
-                ctx.W,
-                ctx.grad_output,
-                ctx.rows,
-                ctx.cols,
-                ctx.workspace_buffer,
-                ctx.sender_perm,
-            ) = inputs
-            ctx.inputs = inputs
-
-        def double_backward(ctx, E, F, G):
-            result = double_backward_op(
-                ctx.jit,
-                ctx.L1_in,
-                ctx.L2_in,
-                ctx.W,
-                ctx.grad_output,
-                E,
-                F,
-                G,
-                ctx.rows,
-                ctx.cols,
-                ctx.workspace_buffer,
-                ctx.sender_perm,
-            )
-            return (
-                None,
-                result[0],
-                result[1],
-                result[2],
-                result[3],
-                None,
-                None,
-                None,
-                None,
-            )
-
-        torch.library.register_autograd(
-            "libtorch_tp_jit::jit_conv_backward",
-            double_backward,
-            setup_context=setup_context_double_backward,
-        )
-
-    @classmethod
-    def register_autocast(cls):
-        global torch
-        import torch
-
-        torch.library.register_autocast(
-            "libtorch_tp_jit::jit_conv_forward", "cuda", torch.float32
-        )
-        torch.library.register_autocast(
-            "libtorch_tp_jit::jit_conv_backward", "cuda", torch.float32
-        )
-        torch.library.register_autocast(
-            "libtorch_tp_jit::jit_conv_double_backward", "cuda", torch.float32
-        )
-
-
-if extlib.TORCH_COMPILE:
-    LoopUnrollConv.register_torch_fakes()
-    LoopUnrollConv.register_autograd()
-    LoopUnrollConv.register_autocast()
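
After this deletion, `LoopUnrollConv` produces only backend-neutral artifacts: the rendered `jit_kernel` source, the three launch configs, `kernel_prop`, and `workspace_size`. The compiler invocation, workspace allocation, the `torch.library` fake/autograd/autocast registrations, and the `TORCH_COMPILE` guard removed here presumably move to the torch implementation layer (the third changed file in this commit is not shown). A hedged end-to-end sketch, combining the pieces sketched above under those assumptions:

    # Hypothetical glue in impl_torch; `config` is assumed in scope.
    import torch
    from openequivariance.impl_torch.extlib import (
        DeviceProp, JITConvImpl, postprocess_kernel,
    )

    conv = LoopUnrollConv(config, DeviceProp(0), postprocess_kernel, torch_op=True)
    internal = JITConvImpl(
        conv.jit_kernel,
        vars(conv.forward_schedule.launch_config),
        vars(conv.backward_schedule.launch_config),
        vars(conv.double_backward_schedule.launch_config),
        conv.kernel_prop,
    )
    workspace = torch.zeros(conv.workspace_size, dtype=torch.uint8, device="cuda")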
