Skip to content

Commit 7fd7ab6

Browse files
committed
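Fixed a regression: the l3_grad store into scratch in loop_unroll_tp.cuh now uses reg_store instead of layout_store, and reg_store is added to the macros.jinja import.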
1 parent 161a1b6 commit 7fd7ab6

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

openequivariance/openequivariance/templates/loop_unroll_tp.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{%- from 'macros.jinja' import layout_load, layout_store with context %}
1+
{%- from 'macros.jinja' import layout_load, layout_store, reg_store with context %}
22
{%- from 'wmm.cuh' import generate_matmul %}
33

44
{%- macro generate_segment_kernel_forward(id, segment, warp_size) %}
@@ -248,7 +248,7 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__
248248
{%- endif %}
249249
{%- endfor %}
250250

251-
{{ layout_store(problem.layout, L1[u].mul, L3[w].ir.dim, "scratch", "0", "l3_grad", "=", 1.0) }}
251+
{{ reg_store(L1[u].mul, L3[w].ir.dim, "scratch", "0", "l3_grad", "=", 1.0) }}
252252

253253
__syncwarp();
254254
{{matmul_basename}}B_{{id}}_{{k}}(L3_grad_smem + offset, scratch, weights_smem);

0 commit comments

Comments
 (0)