1- {%- from ' macros.jinja' import transpose_load, transpose_store, reg_store with context %}
1+ {%- from ' macros.jinja' import layout_load, layout_store with context %}
22{%- from ' wmm.cuh' import generate_matmul %}
33
44{%- macro generate_segment_kernel_forward (id, segment, warp_size) %}
@@ -36,7 +36,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__
3636
3737 {%- if k == 0 or interactions[k][0 ] != interactions[k-1 ][0 ] %}
3838 offset = {{ L1.slices ()[u].start }};
39- {{transpose_load ( L1[u].mul , L1[u].ir .dim , ' L1_smem' , ' offset' , ' l1_vec' )}}
39+ {{layout_load (problem. layout , L1[u].mul , L1[u].ir .dim , ' L1_smem' , ' offset' , ' l1_vec' )}}
4040 {%- endif %}
4141
4242 #pragma unroll
@@ -72,7 +72,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__
7272 // ----------------- CORE CALCULATION -----------------
7373
7474 {%- if problem.instructions [k].connection_mode == " uvw" %}
75- {{transpose_store ( L1[u].mul , L3[w].ir .dim , ' scratch' , ' 0' , ' l3_vec' , ' =' , ' 1.0' )}}
75+ {{layout_store (problem. layout , L1[u].mul , L3[w].ir .dim , ' scratch' , ' 0' , ' l3_vec' , ' =' , ' 1.0' )}}
7676 __syncwarp ();
7777 offset = {{ L3.slices ()[w].start }};
7878 matmul_fwd_{{id}}_{{k}}(weights_smem, scratch, L3_smem + offset);
@@ -85,7 +85,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__
8585
8686 {%- if problem.instructions [k].connection_mode != " uvw" %}
8787 offset = {{ L3.slices ()[w].start }};
88- {{transpose_store ( L3[w].mul , L3[w].ir .dim , ' L3_smem' , ' offset' , ' l3_vec' , ' +=' , ' 1.0' )}}
88+ {{layout_store (problem. layout , L3[w].mul , L3[w].ir .dim , ' L3_smem' , ' offset' , ' l3_vec' , ' +=' , ' 1.0' )}}
8989
9090 {%- if L2[v].mul > 1 %}
9191 #pragma unroll
@@ -168,15 +168,15 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__
168168
169169 {%- if k == 0 or interactions[k][0 ] != interactions[k-1 ][0 ] %}
170170 offset = {{ L1.slices ()[u].start }};
171- {{transpose_load ( L1[u].mul , L1[u].ir .dim , ' L1_smem' , ' offset' , ' l1_vec' )}}
172- {{transpose_load ( L1[u].mul , L1[u].ir .dim , ' L1_grad_smem' , ' offset' , ' l1_grad' )}}
171+ {{layout_load (problem. layout , L1[u].mul , L1[u].ir .dim , ' L1_smem' , ' offset' , ' l1_vec' )}}
172+ {{layout_load (problem. layout , L1[u].mul , L1[u].ir .dim , ' L1_grad_smem' , ' offset' , ' l1_grad' )}}
173173 {%- endif %}
174174
175175
176176 {%- if problem.instructions [k].connection_mode != " uvw" %}
177177 {%- if k == 0 or interactions[k][2 ] != interactions[k-1 ][2 ] %}
178178 offset = {{ L3.slices ()[w].start }};
179- {{transpose_load ( L3[w].mul , L3[w].ir .dim , ' L3_grad_smem' , ' offset' , ' l3_grad' )}}
179+ {{layout_load (problem. layout , L3[w].mul , L3[w].ir .dim , ' L3_grad_smem' , ' offset' , ' l3_grad' )}}
180180 {%- endif %}
181181 {%- endif %}
182182
@@ -225,7 +225,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__
225225 {{matmul_basename}}A_{{id}}_{{k}}(weights_smem, L3_grad_smem + offset, scratch);
226226 __syncwarp ();
227227
228- {{transpose_load ( L1[u].mul , L3[w].ir .dim , ' scratch' , ' 0' , ' l3_grad' )}}
228+ {{layout_load (problem. layout , L1[u].mul , L3[w].ir .dim , ' scratch' , ' 0' , ' l3_grad' )}}
229229
230230 {%- for i in range (tensor.nnz ) %}
231231 {%- set coord1, coord2, coord3, value = tensor.tuples [i] %}
@@ -248,7 +248,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__
248248 {%- endif %}
249249 {%- endfor %}
250250
251- {{ reg_store ( L1[u].mul , L3[w].ir .dim , " scratch" , " 0" , " l3_grad" , " =" , 1.0 ) }}
251+ {{ layout_store (problem. layout , L1[u].mul , L3[w].ir .dim , " scratch" , " 0" , " l3_grad" , " =" , 1.0 ) }}
252252
253253 __syncwarp ();
254254 {{matmul_basename}}B_{{id}}_{{k}}(L3_grad_smem + offset, scratch, weights_smem);
@@ -305,7 +305,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__
305305 // Storeback
306306 {%- if k == num_interact - 1 or interactions[k][0 ] != interactions[k+1 ][0 ] %}
307307 offset = {{ L1.slices ()[u].start }};
308- {{transpose_store ( L1[u].mul , L1[u].ir .dim , ' L1_grad_smem' , ' offset' , ' l1_grad' , ' =' , ' 1.0' )}}
308+ {{layout_store (problem. layout , L1[u].mul , L1[u].ir .dim , ' L1_grad_smem' , ' offset' , ' l1_grad' , ' =' , ' 1.0' )}}
309309 {%- endif %}
310310
311311 {%- endfor %}
0 commit comments