diff --git a/src/inplace_abn_kernels.cuh b/src/inplace_abn_kernels.cuh
index 6545164..5cb6d57 100644
--- a/src/inplace_abn_kernels.cuh
+++ b/src/inplace_abn_kernels.cuh
@@ -260,3 +260,152 @@ __global__ void backward_kernel(
     }
   }
 }
+// Vectorized forward kernel: uses float4 loads/stores when possible (fp32)
+template <typename scalar_t, typename accscalar_t, typename index_t, typename ActivationFn>
+__global__ void forward_kernel_vec4(
+    at::PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits, index_t> x,
+    const at::PackedTensorAccessor<accscalar_t, 1, at::RestrictPtrTraits, index_t> mean_,
+    const at::PackedTensorAccessor<accscalar_t, 1, at::RestrictPtrTraits, index_t> var,
+    const at::PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits, index_t> weight_,
+    const at::PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits, index_t> bias_,
+    float eps_, float activation_param) {
+
+  index_t c = blockIdx.x;
+  if (c >= x.size(1)) return;
+
+  // requires scalar_t == float for the vectorized path
+  static_assert(sizeof(float) == sizeof(scalar_t), "vec4 forward assumes float");
+
+  accscalar_t eps = static_cast<accscalar_t>(eps_);
+  accscalar_t mean = mean_[c];
+  accscalar_t inv_std = accscalar_t(1) / ::sqrt(var[c] + eps);
+  accscalar_t weight = weight_.size(0) > 0
+      ? ::abs(static_cast<accscalar_t>(weight_[c])) + eps
+      : accscalar_t(1);
+  accscalar_t bias = bias_.size(0) > 0 ? static_cast<accscalar_t>(bias_[c]) : accscalar_t(0);
+
+  index_t num = x.size(0);
+  index_t sp = x.size(2);
+
+  // Only the launch wrapper should call this kernel, and only when sp % 4 == 0 and
+  // pointers are 16-byte aligned. Grid-stride over the batch dimension first and
+  // vectorize across the spatial dimension by 4.
+  index_t step = blockDim.y * gridDim.y;
+  for (index_t n = threadIdx.y + blockIdx.y * blockDim.y; n < num; n += step) {
+    // reinterpret the row as float4*
+    scalar_t* row_ptr = &x[n][c][0];
+    float4* row_vec = reinterpret_cast<float4*>(row_ptr);
+    index_t vec_len = sp / 4;  // divisibility by 4 is guaranteed by the launcher
+
+    for (index_t vi = threadIdx.x; vi < vec_len; vi += blockDim.x) {
+      float4 v = row_vec[vi];
+      // Unpack, convert to accscalar_t, apply BN and the activation, then store back
+      float vx0 = v.x, vx1 = v.y, vx2 = v.z, vx3 = v.w;
+
+      accscalar_t t0 = weight * (static_cast<accscalar_t>(vx0) - mean) * inv_std + bias;
+      accscalar_t t1 = weight * (static_cast<accscalar_t>(vx1) - mean) * inv_std + bias;
+      accscalar_t t2 = weight * (static_cast<accscalar_t>(vx2) - mean) * inv_std + bias;
+      accscalar_t t3 = weight * (static_cast<accscalar_t>(vx3) - mean) * inv_std + bias;
+
+      // apply activation in-place (ActivationFn is specialized for scalar_t)
+      scalar_t out0 = static_cast<scalar_t>(t0);
+      scalar_t out1 = static_cast<scalar_t>(t1);
+      scalar_t out2 = static_cast<scalar_t>(t2);
+      scalar_t out3 = static_cast<scalar_t>(t3);
+
+      ActivationFn::forward(out0, activation_param);
+      ActivationFn::forward(out1, activation_param);
+      ActivationFn::forward(out2, activation_param);
+      ActivationFn::forward(out3, activation_param);
+
+      row_vec[vi] = make_float4(out0, out1, out2, out3);
+    }
+  }
+}
+// Warp-level merge of per-thread (mean, M2, count) running statistics.
+template <typename accscalar_t>
+__device__ __forceinline__ void warp_reduce_stats(accscalar_t& avg, accscalar_t& var_n, int& n) {
+  // first warpSum to get one value per thread to one value per warp
+  for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
+    accscalar_t o_avg = __shfl_xor_sync(FULL_WARP_MASK, avg, 1 << i, WARP_SIZE);
+    int o_n = __shfl_xor_sync(FULL_WARP_MASK, n, 1 << i, WARP_SIZE);
+    accscalar_t o_var = __shfl_xor_sync(FULL_WARP_MASK, var_n, 1 << i, WARP_SIZE);
+
+    accscalar_t factor = 1.0 / fmaxf(1.0, n + o_n);
+    var_n += o_var + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
+    avg = (n * avg + o_n * o_avg) * factor;
+    n += o_n;
+  }
+}
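The merge step inside the warp loop is Chan's parallel update for combining (mean, M2, count) triples, the same rule the per-channel reduction below applies sequentially. A minimal host-side sanity check, not part of the patch, with hand-computed values for two small partitions:

#include <algorithm>
#include <cassert>
#include <cmath>

struct Stats { double avg; double m2; long n; };

// Same update as the __shfl_xor_sync loop above, applied to two partitions.
Stats merge(Stats a, Stats b) {
  double factor = 1.0 / std::max(1.0, double(a.n + b.n));
  double m2  = a.m2 + b.m2 + (a.avg - b.avg) * (a.avg - b.avg) * a.n * b.n * factor;
  double avg = (a.n * a.avg + b.n * b.avg) * factor;
  return Stats{avg, m2, a.n + b.n};
}

int main() {
  Stats a{2.0, 2.0, 3};    // partition {1,2,3}:  mean 2,  M2 = 1 + 0 + 1 = 2
  Stats b{15.0, 50.0, 2};  // partition {10,20}:  mean 15, M2 = 25 + 25 = 50
  Stats m = merge(a, b);
  // Direct stats of {1,2,3,10,20}: mean = 36/5 = 7.2, M2 = 254.8
  assert(std::fabs(m.avg - 7.2) < 1e-12 && std::fabs(m.m2 - 254.8) < 1e-12);
}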
+// Launch with blockDim.x = 256, gridDim.x = ceil(chn / blockDim.x).
+// Each block reduces a contiguous range of 'chn' channels (e.g. a tile).
+template <typename scalar_t>
+__global__ void reduce_statistics_kernel_tile(
+    const at::PackedTensorAccessor<scalar_t, 2, at::RestrictPtrTraits> all_mean,
+    const at::PackedTensorAccessor<scalar_t, 2, at::RestrictPtrTraits> all_var,
+    const at::PackedTensorAccessor<int64_t, 2, at::RestrictPtrTraits> all_count,
+    at::PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits> mean,
+    at::PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits> var) {
+
+  int num = all_mean.size(0), chn = all_mean.size(1);
+  int tid = threadIdx.x;
+  int c = blockIdx.x * blockDim.x + tid;
+
+  // Each thread aggregates one channel c over all chunks (num); channels are tiled
+  // across blocks, so if chn >> blockDim.x each block handles one tile of channels.
+  // No shared memory is needed: every thread reduces an independent channel.
+  if (c < chn) {
+    double mean_c = 0;
+    double var_c = 0;
+    int64_t count_c = 0;
+    for (int n = 0; n < num; ++n) {
+      auto count_n = (int64_t)all_count[n][0];
+      auto mean_n = (double)all_mean[n][c];
+      auto var_n_term = (double)all_var[n][c] * count_n;
+      auto delta = mean_n - mean_c;
+      auto new_count = count_c + count_n;
+      if (new_count == 0) continue;
+      // Chan's pairwise update, as in the warp reduction above
+      var_c += var_n_term + delta * delta * count_c * count_n / (double)new_count;
+      mean_c = (count_c * mean_c + count_n * mean_n) / (double)new_count;
+      count_c = new_count;
+    }
+    mean[c] = (scalar_t)mean_c;
+    var[c] = (scalar_t)(count_c > 0 ? var_c / (double)count_c : 0.0);
+  }
+}
+// host-side launch function (sketch; block/grid sizes are defaults to tune)
+void launch_forward_kernel(
+    at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
+    float eps, float act_param, cudaStream_t stream) {
+
+  int c = x.size(1);
+  int sp = x.size(2);
+
+  // One block column per channel; the batch dimension is tiled over blockDim.y.
+  // Clamp grid.y to 65535 for very large batches if needed.
+  dim3 block(32, 8);
+  dim3 grid(c, (unsigned)((x.size(0) + block.y - 1) / block.y));
+
+  // prefer vec4 if conditions are met (fp32, sp % 4 == 0, 16-byte aligned pointer)
+  bool can_vec4 = x.scalar_type() == at::kFloat && (sp % 4 == 0) &&
+                  reinterpret_cast<uintptr_t>(x.data_ptr()) % 16 == 0;
+
+  if (can_vec4) {
+    // IdentityActivation is a placeholder; dispatch over the real activation enum here.
+    auto func = forward_kernel_vec4<float, float, int32_t, IdentityActivation>;
+    func<<<grid, block, 0, stream>>>(
+        x.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
+        mean.packed_accessor32<float, 1, at::RestrictPtrTraits>(),
+        var.packed_accessor32<float, 1, at::RestrictPtrTraits>(),
+        weight.packed_accessor32<float, 1, at::RestrictPtrTraits>(),
+        bias.packed_accessor32<float, 1, at::RestrictPtrTraits>(),
+        eps, act_param);
+  } else {
+    // fallback to the scalar kernel (existing forward_kernel)
+  }
+}
+// Helper for timing a single kernel launch (requires <functional>).
+// Events go on the legacy default stream, so the wrapped launch should target the
+// default stream (or a blocking stream) for the measurement to be meaningful.
+float time_kernel_once(std::function<void()> launch_fn) {
+  cudaEvent_t s, e;
+  cudaEventCreate(&s); cudaEventCreate(&e);
+  cudaEventRecord(s);
+  launch_fn();
+  cudaEventRecord(e);
+  cudaEventSynchronize(e);
+  float ms = 0.0f;
+  cudaEventElapsedTime(&ms, s, e);
+  cudaEventDestroy(s); cudaEventDestroy(e);
+  return ms;
+}
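A usage sketch for the timing helper (benchmark_forward and the zero-stream choice are illustrative, not part of the patch): wrapping the launch in a lambda and running it on the default stream keeps the events and the kernel on the same stream.

#include <cstdio>

// Hypothetical driver: times the vec4 path set up by launch_forward_kernel.
void benchmark_forward(at::Tensor x, at::Tensor mean, at::Tensor var,
                       at::Tensor weight, at::Tensor bias,
                       float eps, float act_param) {
  float ms = time_kernel_once([&] {
    // stream 0 = legacy default stream, matching the events in time_kernel_once
    launch_forward_kernel(x, mean, var, weight, bias, eps, act_param, /*stream=*/0);
  });
  std::printf("forward_kernel_vec4: %.3f ms\n", ms);
}

In practice, do a warm-up launch before timing and average several runs; the first launch can include one-time initialization cost.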