Commit 6f288d6

Patched convolution for GPUs with lower shared memory. (#93)
1 parent: 5a9b80a

1 file changed: openequivariance/implementations/convolution/LoopUnrollConv.py (36 additions, 21 deletions)
@@ -26,27 +26,42 @@ def __init__(self, config, idx_dtype=np.int64,
         backward_schedule_type = 3
         template = env.get_template("loop_unroll_conv_det.cuh")
 
-        self.forward_schedule = ComputationSchedule(self.config,
-            smem_limit=dp.maxSharedMemPerBlock // 4 * 3, warps_per_block=6,
-            block_count=dp.multiprocessorCount,
-            direction = "forward",
-            irrep_dtype = config.irrep_dtype,
-            weight_dtype = config.weight_dtype,
-            schedule_type=forward_schedule_type,
-            warp_size=dp.warpsize,
-            include_scratch=self.is_uvw,
-            stream_weights=self.is_uvw)
-
-        self.backward_schedule = ComputationSchedule(self.config,
-            smem_limit=dp.maxSharedMemPerBlock, warps_per_block=6,
-            block_count=dp.multiprocessorCount * 2,
-            direction = "backward",
-            irrep_dtype = config.irrep_dtype,
-            weight_dtype = config.weight_dtype,
-            schedule_type=backward_schedule_type,
-            warp_size=dp.warpsize,
-            include_scratch=self.is_uvw,
-            stream_weights=self.is_uvw)
+        def generate_forward_schedule(warps_per_block):
+            self.forward_schedule = ComputationSchedule(self.config,
+                smem_limit=dp.maxSharedMemPerBlock // 4 * 3, warps_per_block=warps_per_block,
+                block_count=dp.multiprocessorCount,
+                direction = "forward",
+                irrep_dtype = config.irrep_dtype,
+                weight_dtype = config.weight_dtype,
+                schedule_type=forward_schedule_type,
+                warp_size=dp.warpsize,
+                include_scratch=self.is_uvw,
+                stream_weights=self.is_uvw)
+
+        def generate_backward_schedule(warps_per_block):
+            self.backward_schedule = ComputationSchedule(self.config,
+                smem_limit=dp.maxSharedMemPerBlock, warps_per_block=warps_per_block,
+                block_count=dp.multiprocessorCount * 2,
+                direction = "backward",
+                irrep_dtype = config.irrep_dtype,
+                weight_dtype = config.weight_dtype,
+                schedule_type=backward_schedule_type,
+                warp_size=dp.warpsize,
+                include_scratch=self.is_uvw,
+                stream_weights=self.is_uvw)
+
+        scheduler_generators = [generate_forward_schedule, generate_backward_schedule]
+
+        for generate_schedule in scheduler_generators:
+            warp_count = 6
+            while warp_count > 0:
+                try:
+                    generate_schedule(warp_count)
+                    break
+                except Exception as e:
+                    warp_count -= 1
+                    if warp_count == 0:
+                        raise RuntimeError("Tensor product schedule generation failed, shared memory inadequate!")
 
         if not deterministic:
             for segment in self.forward_schedule.segments:
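
For illustration, the fallback pattern this patch introduces can be distilled into a standalone sketch. This is a minimal sketch, not the library's API: build_schedule is a hypothetical stand-in for ComputationSchedule, and the 16 KiB per-warp shared-memory cost is an invented figure; only the descending-warp-count retry loop mirrors the committed code.

# Sketch of the commit's fallback strategy: retry a shared-memory-bound
# build with progressively fewer warps per block until it fits.
# build_schedule() is a hypothetical stand-in for ComputationSchedule;
# the 16 KiB per-warp cost is invented for demonstration.

def build_schedule(warps_per_block, smem_limit, smem_per_warp=16384):
    # Hypothetical builder: raises when the requested warp count no
    # longer fits in the shared-memory budget.
    if warps_per_block * smem_per_warp > smem_limit:
        raise RuntimeError("schedule does not fit in shared memory")
    return {"warps_per_block": warps_per_block}

def build_with_fallback(smem_limit, max_warps=6):
    # Mirrors the committed loop: start at 6 warps per block and step
    # down until construction succeeds or every warp count is exhausted.
    for warp_count in range(max_warps, 0, -1):
        try:
            return build_schedule(warp_count, smem_limit)
        except Exception:
            continue
    raise RuntimeError(
        "Tensor product schedule generation failed, shared memory inadequate!")

# A GPU with a 48 KiB shared-memory budget settles on 3 warps per block:
print(build_with_fallback(smem_limit=48 * 1024))  # {'warps_per_block': 3}

Writing the descent as a for-loop over range(6, 0, -1) behaves the same as the commit's while-loop with a manual counter: the first warp count that fits wins, and failure at warp_count == 1 surfaces as a hard error rather than a silent retry.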
