@@ -26,27 +26,42 @@ def __init__(self, config, idx_dtype=np.int64,
2626 backward_schedule_type = 3
2727 template = env .get_template ("loop_unroll_conv_det.cuh" )
2828
29- self .forward_schedule = ComputationSchedule (self .config ,
30- smem_limit = dp .maxSharedMemPerBlock // 4 * 3 , warps_per_block = 6 ,
31- block_count = dp .multiprocessorCount ,
32- direction = "forward" ,
33- irrep_dtype = config .irrep_dtype ,
34- weight_dtype = config .weight_dtype ,
35- schedule_type = forward_schedule_type ,
36- warp_size = dp .warpsize ,
37- include_scratch = self .is_uvw ,
38- stream_weights = self .is_uvw )
39-
40- self .backward_schedule = ComputationSchedule (self .config ,
41- smem_limit = dp .maxSharedMemPerBlock , warps_per_block = 6 ,
42- block_count = dp .multiprocessorCount * 2 ,
43- direction = "backward" ,
44- irrep_dtype = config .irrep_dtype ,
45- weight_dtype = config .weight_dtype ,
46- schedule_type = backward_schedule_type ,
47- warp_size = dp .warpsize ,
48- include_scratch = self .is_uvw ,
49- stream_weights = self .is_uvw )
29+ def generate_forward_schedule (warps_per_block ):
30+ self .forward_schedule = ComputationSchedule (self .config ,
31+ smem_limit = dp .maxSharedMemPerBlock // 4 * 3 , warps_per_block = warps_per_block ,
32+ block_count = dp .multiprocessorCount ,
33+ direction = "forward" ,
34+ irrep_dtype = config .irrep_dtype ,
35+ weight_dtype = config .weight_dtype ,
36+ schedule_type = forward_schedule_type ,
37+ warp_size = dp .warpsize ,
38+ include_scratch = self .is_uvw ,
39+ stream_weights = self .is_uvw )
40+
41+ def generate_backward_schedule (warps_per_block ):
42+ self .backward_schedule = ComputationSchedule (self .config ,
43+ smem_limit = dp .maxSharedMemPerBlock , warps_per_block = warps_per_block ,
44+ block_count = dp .multiprocessorCount * 2 ,
45+ direction = "backward" ,
46+ irrep_dtype = config .irrep_dtype ,
47+ weight_dtype = config .weight_dtype ,
48+ schedule_type = backward_schedule_type ,
49+ warp_size = dp .warpsize ,
50+ include_scratch = self .is_uvw ,
51+ stream_weights = self .is_uvw )
52+
53+ scheduler_generators = [generate_forward_schedule , generate_backward_schedule ]
54+
55+ for generate_schedule in scheduler_generators :
56+ warp_count = 6
57+ while warp_count > 0 :
58+ try :
59+ generate_schedule (warp_count )
60+ break
61+ except Exception as e :
62+ warp_count -= 1
63+ if warp_count == 0 :
64+ raise RuntimeError ("Tensor product schedule generation failed, shared memory inadequate!" )
5065
5166 if not deterministic :
5267 for segment in self .forward_schedule .segments :
0 commit comments