Skip to content

Commit 716bbd6

Browse files
committed
fix(kernels): rename kernel_add in fused_encoder_bwd to avoid symbol clash
The fused encoder backward file defined kernel_add which clashes with the same symbol in elementwise.cu when linking into libkernels.so. Rename to kernel_enc_bwd_add to avoid the duplicate symbol error.
1 parent d7bb08e commit 716bbd6

1 file changed

Lines changed: 5 additions & 5 deletions

File tree

internal/cuda/kernels/fused_encoder_bwd.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
* kernel_gelu_bwd GELU derivative * upstream gradient
1010
* kernel_softmax_bwd Softmax backward (Jacobian-vector product)
1111
* kernel_bias_grad_reduce Sum rows to compute bias gradients
12-
* kernel_add_elementwise Element-wise addition for residual gradients
12+
* kernel_enc_bwd_add_elementwise Element-wise addition for residual gradients
1313
* kernel_matmul_grad_accum Accumulate weight gradient: dW += A^T @ B
1414
*
1515
* cuBLAS calls (~14 total per layer):
@@ -313,7 +313,7 @@ __global__ void kernel_bias_grad_reduce(
313313
/* out[i] = a[i] + b[i] */
314314
/* ------------------------------------------------------------------ */
315315

316-
__global__ void kernel_add(
316+
__global__ void kernel_enc_bwd_add(
317317
const float* __restrict__ a,
318318
const float* __restrict__ b,
319319
float* __restrict__ out,
@@ -330,7 +330,7 @@ __global__ void kernel_add(
330330
/* out[i] = a[i] + b[i] + c[i] */
331331
/* ------------------------------------------------------------------ */
332332

333-
__global__ void kernel_add3(
333+
__global__ void kernel_enc_bwd_add3(
334334
const float* __restrict__ a,
335335
const float* __restrict__ b,
336336
const float* __restrict__ c,
@@ -558,7 +558,7 @@ cudaError_t fused_encoder_bwd_f32(
558558
dXRes1, dg_norm2W, dg_norm2B, dModel);
559559

560560
/* Add residual skip: dXRes1 += dOutput */
561-
kernel_add<<<elemGridTrDm, block256, 0, stream>>>(
561+
kernel_enc_bwd_add<<<elemGridTrDm, block256, 0, stream>>>(
562562
dXRes1, dOutput, dXRes1, trDm);
563563

564564
/* ------------------------------------------------------------ */
@@ -722,7 +722,7 @@ cudaError_t fused_encoder_bwd_f32(
722722
temp, dg_norm1W, dg_norm1B, dModel);
723723

724724
/* dInput = temp + dXRes1 */
725-
kernel_add<<<elemGridTrDm, block256, 0, stream>>>(
725+
kernel_enc_bwd_add<<<elemGridTrDm, block256, 0, stream>>>(
726726
temp, dXRes1, dInput, trDm);
727727

728728
return cudaGetLastError();

0 commit comments

Comments
 (0)