Skip to content

Commit 4dfd46e

Browse files
committed
feat(compute): add fused PatchTST encoder layer CUDA kernels
Implement fused encoder forward and backward pass orchestrators that replace
~78 discrete Engine operations per layer with a single C function call. The
orchestrator launches cuBLAS GEMMs for matrix multiplications and custom CUDA
sub-kernels for LayerNorm, head transpose, GELU, softmax, and residual
operations.

Forward kernel (fused_encoder_fwd.cu):
- 7 sub-kernels: layernorm_fwd, bias_add, head_split, head_merge, softmax_fwd, bias_gelu_fwd, bias_residual_add
- 8 cuBLAS Sgemm/SgemmStridedBatched calls
- Caches 16 intermediate buffers (FEB_*) for backward use

Backward kernel (fused_encoder_bwd.cu):
- 6 sub-kernels: layernorm_bwd, gelu_bwd, softmax_bwd, bias_grad_reduce, add, head_split/merge
- 14 cuBLAS calls for weight/input gradient computation
- Reads forward cache; accumulates into gradient buffers

Go bindings:
- Purego and CGo wrappers for both kernels
- KernelLib symbol registration (optional, non-fatal if absent)
- KernelRunner interface methods + all backend stubs
- FusedEncoderProvider optional interface on GPUEngine
- cublas.Handle.Ptr() for passing raw handle to C

Build: Makefile adds -lcublas to libkernels.so link step.

Closes zerfoo/zerfoo E55 tasks T55.1.1, T55.1.2, T55.2.1, T55.2.2.
1 parent 34bfe35 commit 4dfd46e

21 files changed

Lines changed: 2036 additions & 2 deletions

compute/fused_encoder.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package compute
2+
3+
import "unsafe"
4+
5+
// FusedEncoderProvider is the optional interface satisfied by engines that can
// run a fused PatchTST encoder layer on the GPU. A single orchestrated call
// stands in for roughly 78 discrete engine operations per layer: cuBLAS
// handles the GEMMs while custom CUDA sub-kernels cover LayerNorm, GELU,
// softmax, head transpose, and residual operations.
//
// All buffer arrays must be allocated by the caller and populated with device
// pointers before either method is invoked. The index constants naming each
// slot (FEW_*, FEB_*, FEG_*, etc.) live in
// internal/cuda/kernels/fused_encoder_fwd_purego.go and fused_encoder_bwd_purego.go.
//
// This API is not covered by the v1 stability guarantee.
type FusedEncoderProvider interface {
	// FusedEncoderAvailable returns true if the fused encoder kernel is loaded.
	FusedEncoderAvailable() bool

	// FusedEncoderForward executes one encoder layer forward pass.
	// weights: [16]unsafe.Pointer to layer weights.
	// bufs: [16]unsafe.Pointer to pre-allocated forward cache buffers.
	// input/output: [totalRows, dModel] device pointers.
	FusedEncoderForward(
		weights *[16]unsafe.Pointer,
		bufs *[16]unsafe.Pointer,
		input, output unsafe.Pointer,
		totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int,
	) error

	// FusedEncoderBackward computes all gradients for one encoder layer.
	// weights: [16]unsafe.Pointer to layer weights.
	// weightT: [6]unsafe.Pointer to pre-transposed weights.
	// fwdBufs: [16]unsafe.Pointer to forward cache (from FusedEncoderForward).
	// bwdBufs: [15]unsafe.Pointer to backward scratch buffers.
	// grads: [16]unsafe.Pointer to gradient accumulators (accumulated, not zeroed).
	// dOutput: upstream gradient; dInput: output gradient; input: original layer input.
	FusedEncoderBackward(
		weights *[16]unsafe.Pointer,
		weightT *[6]unsafe.Pointer,
		fwdBufs *[16]unsafe.Pointer,
		bwdBufs *[15]unsafe.Pointer,
		grads *[16]unsafe.Pointer,
		dOutput, dInput, input unsafe.Pointer,
		totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int,
	) error
}

compute/gpu_fused_encoder.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package compute
2+
3+
import (
4+
"fmt"
5+
"unsafe"
6+
7+
"github.com/zerfoo/ztensor/internal/cublas"
8+
)
9+
10+
// blasHandlePtr extracts the raw cuBLAS handle pointer from the BLAS interface.
11+
// Returns nil if the BLAS is not backed by cuBLAS.
12+
func blasHandlePtr(b interface{}) unsafe.Pointer {
13+
type handleProvider interface {
14+
Handle() *cublas.Handle
15+
}
16+
if hp, ok := b.(handleProvider); ok {
17+
h := hp.Handle()
18+
if h != nil {
19+
return h.Ptr()
20+
}
21+
}
22+
return nil
23+
}
24+
25+
// FusedEncoderAvailable returns true if the fused encoder kernel is loaded
26+
// and the engine has a cuBLAS handle to pass to it.
27+
func (e *GPUEngine[T]) FusedEncoderAvailable() bool {
28+
return e.kernels.FusedEncoderFwdAvailable() && blasHandlePtr(e.blas) != nil
29+
}
30+
31+
// FusedEncoderForward executes one fused encoder layer forward pass.
32+
func (e *GPUEngine[T]) FusedEncoderForward(
33+
weights *[16]unsafe.Pointer,
34+
bufs *[16]unsafe.Pointer,
35+
input, output unsafe.Pointer,
36+
totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int,
37+
) error {
38+
h := blasHandlePtr(e.blas)
39+
if h == nil {
40+
return fmt.Errorf("FusedEncoderForward: cuBLAS handle not available")
41+
}
42+
e.setDevice()
43+
return e.kernels.FusedEncoderFwdF32(h, weights, bufs, input, output,
44+
totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches, e.stream)
45+
}
46+
47+
// FusedEncoderBackward computes all gradients for one fused encoder layer.
48+
func (e *GPUEngine[T]) FusedEncoderBackward(
49+
weights *[16]unsafe.Pointer,
50+
weightT *[6]unsafe.Pointer,
51+
fwdBufs *[16]unsafe.Pointer,
52+
bwdBufs *[15]unsafe.Pointer,
53+
grads *[16]unsafe.Pointer,
54+
dOutput, dInput, input unsafe.Pointer,
55+
totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int,
56+
) error {
57+
h := blasHandlePtr(e.blas)
58+
if h == nil {
59+
return fmt.Errorf("FusedEncoderBackward: cuBLAS handle not available")
60+
}
61+
e.setDevice()
62+
// The KernelRunner interface uses *[16] for weightT, but we have *[6].
63+
// Convert via unsafe pointer.
64+
var wt16 [16]unsafe.Pointer
65+
copy(wt16[:6], weightT[:])
66+
return e.kernels.FusedEncoderBwdF32(h, weights, &wt16, fwdBufs, bwdBufs, grads,
67+
dOutput, dInput, input,
68+
totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches, e.stream)
69+
}
70+
71+
// Compile-time check that GPUEngine implements FusedEncoderProvider.
72+
var _ FusedEncoderProvider = (*GPUEngine[float32])(nil)

internal/cublas/cublas_purego.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ type Handle struct {
107107
ptr uintptr // cublasHandle_t is a pointer
108108
}
109109

110+
// Ptr returns the raw cuBLAS handle pointer for passing to C functions
111+
// (e.g., the fused encoder kernel orchestrator).
112+
func (h *Handle) Ptr() unsafe.Pointer { return unsafe.Pointer(h.ptr) }
113+
110114
// CreateHandle creates a new cuBLAS context handle.
111115
func CreateHandle() (*Handle, error) {
112116
lib, err := getCublasLib()

internal/cuda/kernels/Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ ifeq ($(CUDA_ARCH),sm_121)
1111
NVCC_FLAGS += -DFLASH_BLOCK_SIZE=64
1212
endif
1313

14-
SRCS = counter.cu dequant_q4k.cu dequant_q5_0.cu dequant_q5k.cu dequant_q6k.cu elementwise.cu elementwise_fp16.cu flash_attention.cu flash_attention2.cu flash_decode.cu fp4_gemv.cu fp8_gemm.cu fp8_ops.cu fused_add_rmsnorm.cu fused_norm_add.cu fused_qk_norm_rope.cu fused_repeat_interleave.cu fused_rope.cu fused_softmax_vmul.cu fused_swiglu.cu gather.cu gather_q8.cu gemm_int8.cu gemm_int4.cu gemm_q4.cu gemm_q8.cu gemv_q4k.cu gemv_q4k_sm121.cu gemv_q5k.cu gemv_q5_0.cu gemv_q6k.cu gemv_warp.cu megakernel_ops.cu offset_memcpy.cu paged_attention.cu ragged_attention.cu rope_select.cu scaled_softmax.cu selective_scan.cu sgemv_m1.cu ternary_gemv.cu transpose.cu rmsnorm.cu argmax.cu
14+
SRCS = counter.cu dequant_q4k.cu dequant_q5_0.cu dequant_q5k.cu dequant_q6k.cu elementwise.cu elementwise_fp16.cu flash_attention.cu flash_attention2.cu flash_decode.cu fp4_gemv.cu fp8_gemm.cu fp8_ops.cu fused_add_rmsnorm.cu fused_encoder_fwd.cu fused_encoder_bwd.cu fused_norm_add.cu fused_qk_norm_rope.cu fused_repeat_interleave.cu fused_rope.cu fused_softmax_vmul.cu fused_swiglu.cu gather.cu gather_q8.cu gemm_int8.cu gemm_int4.cu gemm_q4.cu gemm_q8.cu gemv_q4k.cu gemv_q4k_sm121.cu gemv_q5k.cu gemv_q5_0.cu gemv_q6k.cu gemv_warp.cu megakernel_ops.cu offset_memcpy.cu paged_attention.cu ragged_attention.cu rope_select.cu scaled_softmax.cu selective_scan.cu sgemv_m1.cu ternary_gemv.cu transpose.cu rmsnorm.cu argmax.cu
1515
OBJS = $(SRCS:.cu=.o)
1616
PIC_OBJS = $(SRCS:.cu=.pic.o)
1717
LIB = libkernels.a
@@ -27,7 +27,7 @@ $(LIB): $(OBJS)
2727
ar rcs $@ $^
2828

2929
$(SO): $(PIC_OBJS)
30-
$(NVCC) -shared -o $@ $^
30+
$(NVCC) -shared -o $@ $^ -lcublas
3131

3232
# Limit register pressure for kernels that benefit from higher occupancy.
3333
# gemm_q4: 40->32 regs/thread, no spills, occupancy 75%->100% (256-thread blocks).

0 commit comments

Comments
 (0)