Skip to content

Commit c78d48f

Browse files
committed
Avoided transposing irreps once the shared memory load is complete.
1 parent 33ed045 commit c78d48f

4 files changed

Lines changed: 53 additions & 29 deletions

File tree

openequivariance/openequivariance/core/LoopUnrollTP.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1-
import numpy as np
21
import json
32

4-
from openequivariance.templates.jinja_utils import get_jinja_environment
5-
from openequivariance.core.ComputationSchedule import ComputationSchedule
6-
from openequivariance.core.TensorProductBase import TensorProductBase
7-
from openequivariance.benchmark.logging_utils import getLogger
8-
from openequivariance.core.utils import dtype_to_enum, hash_str_64
3+
import numpy as np
94

5+
from openequivariance.benchmark.logging_utils import getLogger
6+
from openequivariance.core.ComputationSchedule import (
7+
ComputationSchedule,
8+
SMEMCapacityException,
9+
)
10+
from openequivariance.core.TensorProductBase import TensorProductBase
1011
from openequivariance.core.utils import (
11-
filter_and_analyze_problem,
1212
count_cg_non_zero,
13+
dtype_to_enum,
14+
filter_and_analyze_problem,
15+
hash_str_64,
1316
)
17+
from openequivariance.templates.jinja_utils import get_jinja_environment
1418

1519
logger = getLogger()
1620

@@ -80,12 +84,14 @@ def generate_double_backward_schedule(warps_per_block):
8084
try:
8185
generate_schedule(warp_count)
8286
break
83-
except Exception:
87+
except SMEMCapacityException:
8488
warp_count -= 2
8589
if warp_count == 0:
8690
raise RuntimeError(
8791
"Tensor product schedule generation failed, shared memory inadequate!"
8892
)
93+
except Exception:
94+
raise
8995

9096
self.jit_kernel = postprocess_kernel(
9197
template.render(

openequivariance/openequivariance/core/e3nn_lite.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,15 @@
3636
SOFTWARE.
3737
"""
3838

39-
import itertools
40-
from typing import Tuple, NamedTuple, Union, List, Any, Optional
41-
from math import sqrt, prod
4239
import collections
40+
import copy
41+
import functools
42+
import itertools
43+
from math import prod, sqrt
44+
from typing import Any, List, NamedTuple, Optional, Tuple, Union
45+
4346
import numpy as np
4447
import numpy.linalg as la
45-
import functools
46-
import copy
4748

4849

4950
def perm_inverse(p):
@@ -412,7 +413,7 @@ def __init__(
412413
label: Optional[str] = None,
413414
irrep_dtype: type[np.generic] = np.float32,
414415
weight_dtype: type[np.generic] = np.float32,
415-
layout: str = "mul_ir"
416+
layout: str = "mul_ir",
416417
) -> None:
417418
# === Setup ===
418419
super().__init__()
@@ -434,6 +435,7 @@ def __init__(
434435
self.irrep_normalization = irrep_normalization
435436
self.path_normalization = path_normalization
436437
self.label = label if label is not None else ""
438+
self.layout = layout
437439
del irreps_in1, irreps_in2, irreps_out
438440

439441
instructions = [x if len(x) == 6 else x + (1.0,) for x in instructions]

openequivariance/openequivariance/templates/loop_unroll_tp.cuh

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{%- from 'macros.jinja' import transpose_load, transpose_store, reg_store with context %}
1+
{%- from 'macros.jinja' import layout_load, layout_store with context %}
22
{%- from 'wmm.cuh' import generate_matmul %}
33

44
{%- macro generate_segment_kernel_forward(id, segment, warp_size) %}
@@ -36,7 +36,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__
3636

3737
{%- if k == 0 or interactions[k][0] != interactions[k-1][0] %}
3838
offset = {{ L1.slices()[u].start}};
39-
{{transpose_load(L1[u].mul, L1[u].ir.dim, 'L1_smem', 'offset', 'l1_vec')}}
39+
{{layout_load(problem.layout, L1[u].mul, L1[u].ir.dim, 'L1_smem', 'offset', 'l1_vec')}}
4040
{%- endif %}
4141

4242
#pragma unroll
@@ -72,7 +72,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__
7272
// ----------------- CORE CALCULATION -----------------
7373

7474
{%- if problem.instructions[k].connection_mode == "uvw" %}
75-
{{transpose_store(L1[u].mul, L3[w].ir.dim, 'scratch', '0', 'l3_vec', '=', '1.0')}}
75+
{{layout_store(problem.layout, L1[u].mul, L3[w].ir.dim, 'scratch', '0', 'l3_vec', '=', '1.0')}}
7676
__syncwarp();
7777
offset = {{ L3.slices()[w].start}};
7878
matmul_fwd_{{id}}_{{k}}(weights_smem, scratch, L3_smem + offset);
@@ -85,7 +85,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__
8585

8686
{%- if problem.instructions[k].connection_mode != "uvw" %}
8787
offset = {{ L3.slices()[w].start}};
88-
{{transpose_store(L3[w].mul, L3[w].ir.dim, 'L3_smem', 'offset', 'l3_vec', '+=', '1.0')}}
88+
{{layout_store(problem.layout, L3[w].mul, L3[w].ir.dim, 'L3_smem', 'offset', 'l3_vec', '+=', '1.0')}}
8989

9090
{%- if L2[v].mul > 1%}
9191
#pragma unroll
@@ -168,15 +168,15 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__
168168

169169
{%- if k == 0 or interactions[k][0] != interactions[k-1][0] %}
170170
offset = {{ L1.slices()[u].start}};
171-
{{transpose_load(L1[u].mul, L1[u].ir.dim, 'L1_smem', 'offset', 'l1_vec')}}
172-
{{transpose_load(L1[u].mul, L1[u].ir.dim, 'L1_grad_smem', 'offset', 'l1_grad')}}
171+
{{layout_load(problem.layout, L1[u].mul, L1[u].ir.dim, 'L1_smem', 'offset', 'l1_vec')}}
172+
{{layout_load(problem.layout, L1[u].mul, L1[u].ir.dim, 'L1_grad_smem', 'offset', 'l1_grad')}}
173173
{%- endif %}
174174

175175

176176
{%- if problem.instructions[k].connection_mode != "uvw" %}
177177
{%- if k == 0 or interactions[k][2] != interactions[k-1][2] %}
178178
offset = {{ L3.slices()[w].start}};
179-
{{transpose_load(L3[w].mul, L3[w].ir.dim, 'L3_grad_smem', 'offset', 'l3_grad')}}
179+
{{layout_load(problem.layout, L3[w].mul, L3[w].ir.dim, 'L3_grad_smem', 'offset', 'l3_grad')}}
180180
{%- endif %}
181181
{%- endif %}
182182

@@ -225,7 +225,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__
225225
{{matmul_basename}}A_{{id}}_{{k}}(weights_smem, L3_grad_smem + offset, scratch);
226226
__syncwarp();
227227

228-
{{transpose_load(L1[u].mul, L3[w].ir.dim, 'scratch', '0', 'l3_grad')}}
228+
{{layout_load(problem.layout, L1[u].mul, L3[w].ir.dim, 'scratch', '0', 'l3_grad')}}
229229

230230
{%- for i in range(tensor.nnz) %}
231231
{%- set coord1, coord2, coord3, value = tensor.tuples[i] %}
@@ -248,7 +248,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__
248248
{%- endif %}
249249
{%- endfor %}
250250

251-
{{ reg_store(L1[u].mul, L3[w].ir.dim, "scratch", "0", "l3_grad", "=", 1.0) }}
251+
{{ layout_store(problem.layout, L1[u].mul, L3[w].ir.dim, "scratch", "0", "l3_grad", "=", 1.0) }}
252252

253253
__syncwarp();
254254
{{matmul_basename}}B_{{id}}_{{k}}(L3_grad_smem + offset, scratch, weights_smem);
@@ -305,7 +305,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__
305305
// Storeback
306306
{%- if k == num_interact - 1 or interactions[k][0] != interactions[k+1][0] %}
307307
offset = {{ L1.slices()[u].start}};
308-
{{transpose_store(L1[u].mul, L1[u].ir.dim, 'L1_grad_smem', 'offset', 'l1_grad', '=', '1.0')}}
308+
{{layout_store(problem.layout, L1[u].mul, L1[u].ir.dim, 'L1_grad_smem', 'offset', 'l1_grad', '=', '1.0')}}
309309
{%- endif %}
310310

311311
{%- endfor %}

openequivariance/openequivariance/templates/macros.jinja

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,22 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray.
5050
}
5151
{%- endmacro %}
5252

53+
{%- macro layout_load(layout, mul, dim, smem, offset, reg) %}
54+
{%- if layout == "ir_mul" %}
55+
{{ reg_load(mul, dim, smem, offset, reg) }}
56+
{%- else %}
57+
{{ transpose_load(mul, dim, smem, offset, reg) }}
58+
{%- endif %}
59+
{%- endmacro %}
60+
61+
{%- macro layout_store(layout, mul, dim, smem, offset, reg, op, coeff) %}
62+
{%- if layout == "ir_mul" %}
63+
{{ reg_store(mul, dim, smem, offset, reg, op, coeff) }}
64+
{%- else %}
65+
{{ transpose_store(mul, dim, smem, offset, reg, op, coeff) }}
66+
{%- endif %}
67+
{%- endmacro %}
68+
5369
{%- macro declare_smem_variables(segment, smem_base) %}
5470
{%- for name in segment.smem %}
5571
{%- if name != "total" %}
@@ -75,7 +91,7 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray.
7591
{%- set dim = src_mul_ir.ir.dim %}
7692
{%- set mul = src_mul_ir.mul %}
7793
{%- for i in range(dim) %}
78-
ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];)
94+
ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];)
7995
{%- endfor %}
8096
{%- endfor %}
8197
{%- endif %}
@@ -97,7 +113,7 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray.
97113
{%- set dim = src_mul_ir.ir.dim %}
98114
{%- set mul = src_mul_ir.mul %}
99115
{%- for i in range(dim) %}
100-
ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];)
116+
ROW_OPERATION({{mul}}, {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id] = {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}];)
101117
{%- endfor %}
102118
{%- endfor %}
103119
{%- endif %}
@@ -128,15 +144,15 @@ Keys map to lists of tuples with (name, dtype, num_elements) of each subarray.
128144
{%- set mul = src_mul_ir.mul %}
129145
{%- if map.storeback_procedure[idx] == "write" %}
130146
{%- for i in range(dim) %}
131-
ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] = {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id];)
147+
ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] = {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id];)
132148
{%- endfor %}
133149
{%- elif map.storeback_procedure[idx] == "accumulate" %}
134150
{%- for i in range(dim) %}
135-
ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] += {{smem_ptr}}[{{dst_rng.start + i * mul}} + {{loop_var}} + lane_id];)
151+
ROW_OPERATION({{mul}}, {{loop_var}}, {{glb_ptr_shft}}[{{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}] += {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id];)
136152
{%- endfor %}
137153
{%- elif map.storeback_procedure[idx] == "atomic_accumulate" %}
138154
{%- for i in range(dim) %}
139-
ROW_OPERATION({{mul}}, {{loop_var}}, atomicAdd({{glb_ptr_shft}} + {{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + i * mul}} + lane_id + {{loop_var}}]);)
155+
ROW_OPERATION({{mul}}, {{loop_var}}, atomicAdd({{glb_ptr_shft}} + {{src_view.ir_mul_offset + i * src_view.ir_mul_stride}} + {{loop_var}}, {{smem_ptr}}[{{dst_rng.start + loop_var + i * mul}} + lane_id]);)
140156
{%- endfor %}
141157
{%- endif %}
142158
{%- endfor %}

0 commit comments

Comments (0)