149 changes: 149 additions & 0 deletions src/inplace_abn_kernels.cuh
@@ -260,3 +260,152 @@ __global__ void backward_kernel(
}
}
}
// Vectorized forward kernel: uses float4 loads/stores when possible (fp32)
template<typename scalar_t, typename accscalar_t, typename prmscalar_t, typename index_t, Activation activation>
__global__ void forward_kernel_vec4(
at::PackedTensorAccessor<scalar_t, 3, at::RestrictPtrTraits, index_t> x,
const at::PackedTensorAccessor<accscalar_t, 1, at::RestrictPtrTraits, index_t> mean_,
const at::PackedTensorAccessor<accscalar_t, 1, at::RestrictPtrTraits, index_t> var,
const at::PackedTensorAccessor<prmscalar_t, 1, at::RestrictPtrTraits, index_t> weight_,
const at::PackedTensorAccessor<prmscalar_t, 1, at::RestrictPtrTraits, index_t> bias_,
float eps_, float activation_param) {

index_t c = blockIdx.x;
if (c >= x.size(1)) return;

// requires scalar_t == float for vectorized path
static_assert(sizeof(float) == sizeof(scalar_t), "vec4 forward assumes float");

accscalar_t eps = static_cast<accscalar_t>(eps_);
accscalar_t mean = mean_[c];
accscalar_t inv_std = accscalar_t(1) / ::sqrt(var[c] + eps);
accscalar_t weight = weight_.size(0) > 0 ? ::abs(static_cast<accscalar_t>(weight_[c])) + eps : accscalar_t(1);
accscalar_t bias = bias_.size(0) > 0 ? static_cast<accscalar_t>(bias_[c]) : accscalar_t(0);

index_t num = x.size(0);
index_t sp = x.size(2);

// The launch wrapper must only dispatch to this kernel when sp % 4 == 0 and the data
// pointer is 16-byte aligned. Threads grid-stride over the batch dimension and
// vectorize across the spatial dimension in groups of 4.
index_t step = blockDim.y * gridDim.y;
for (index_t n = threadIdx.y + blockIdx.y * blockDim.y; n < num; n += step) {
// reinterpret row as float4*
scalar_t* row_ptr = &x[n][c][0];
float4* row_vec = reinterpret_cast<float4*>(row_ptr);
index_t vec_len = sp / 4; // ensured divisible by 4

for (index_t vi = threadIdx.x; vi < vec_len; vi += blockDim.x) {
float4 v = row_vec[vi];
// Unpack, convert to accscalar_t, apply BN and activation, then store back
float vx0 = v.x, vx1 = v.y, vx2 = v.z, vx3 = v.w;

accscalar_t t0 = weight * (static_cast<accscalar_t>(vx0) - mean) * inv_std + bias;
accscalar_t t1 = weight * (static_cast<accscalar_t>(vx1) - mean) * inv_std + bias;
accscalar_t t2 = weight * (static_cast<accscalar_t>(vx2) - mean) * inv_std + bias;
accscalar_t t3 = weight * (static_cast<accscalar_t>(vx3) - mean) * inv_std + bias;

// apply activation in-place (ActivationFn specialized for scalar_t)
scalar_t out0 = static_cast<scalar_t>(t0);
scalar_t out1 = static_cast<scalar_t>(t1);
scalar_t out2 = static_cast<scalar_t>(t2);
scalar_t out3 = static_cast<scalar_t>(t3);

ActivationFn<scalar_t, activation>::forward(out0, activation_param);
ActivationFn<scalar_t, activation>::forward(out1, activation_param);
ActivationFn<scalar_t, activation>::forward(out2, activation_param);
ActivationFn<scalar_t, activation>::forward(out3, activation_param);

row_vec[vi] = make_float4(out0, out1, out2, out3);
}
}
}
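// A minimal optional guard (an assumption, not part of this diff) that the launch wrapper
// or a debug build could use to double-check the vec4 preconditions described above;
// 'row_is_vec4_safe' is a hypothetical helper name.
template<typename scalar_t, typename index_t>
__host__ __device__ __forceinline__ bool row_is_vec4_safe(const scalar_t* row_ptr, index_t sp) {
  // the spatial extent must split into float4 groups and the row must be 16-byte aligned
  return (sp % 4 == 0) &&
         (reinterpret_cast<uintptr_t>(row_ptr) % alignof(float4) == 0);
}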
// first, a warp-level reduction: combine the per-thread partial statistics into one value per warp
for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
accscalar_t o_avg = __shfl_xor_sync(FULL_WARP_MASK, avg, 1 << i, WARP_SIZE);
int o_n = __shfl_xor_sync(FULL_WARP_MASK, n, 1 << i, WARP_SIZE);
accscalar_t o_var = __shfl_xor_sync(FULL_WARP_MASK, var_n, 1 << i, WARP_SIZE);

accscalar_t factor = 1.0 / fmaxf(1.0, n + o_n);
var_n += o_var + (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
avg = (n * avg + o_n * o_avg) * factor;
n += o_n;
}
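// The combine above follows the parallel (Chan-style) update for count, mean and the
// unnormalized variance. A minimal standalone sketch of that merge step, with a
// hypothetical PartialStats struct (an illustration, not a type used elsewhere in this PR):
struct PartialStats { float avg; float var_n; int n; };

__host__ __device__ inline PartialStats merge_stats(PartialStats a, PartialStats b) {
  float factor = 1.0f / fmaxf(1.0f, float(a.n + b.n));
  PartialStats out;
  // pooled unnormalized variance picks up a correction term from the difference of means
  out.var_n = a.var_n + b.var_n + (a.avg - b.avg) * (a.avg - b.avg) * float(a.n) * float(b.n) * factor;
  // pooled mean is the count-weighted average of the two means
  out.avg = (float(a.n) * a.avg + float(b.n) * b.avg) * factor;
  out.n = a.n + b.n;
  return out;
}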
// Launch with blockDim.x = 256, gridDim.x = ceil(chn / blockDim.x)
// Each block reduces a contiguous range of 'chn' channels (e.g. a tile)
template<typename scalar_t, typename index_t>
__global__ void reduce_statistics_kernel_tile(
const at::PackedTensorAccessor<scalar_t, 2, at::RestrictPtrTraits, index_t> all_mean,
const at::PackedTensorAccessor<scalar_t, 2, at::RestrictPtrTraits, index_t> all_var,
const at::PackedTensorAccessor<int64_t, 2, at::RestrictPtrTraits, index_t> all_count,
at::PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits, index_t> mean,
at::PackedTensorAccessor<scalar_t, 1, at::RestrictPtrTraits, index_t> var) {

// Each thread aggregates a single channel independently, so no shared memory is needed here.

int num = all_mean.size(0), chn = all_mean.size(1);
int tid = threadIdx.x;
int c = blockIdx.x * blockDim.x + tid;

// Each thread aggregates the statistics for its channel c across all 'num' chunks.
// Channels are tiled across blocks, so each block handles a contiguous range of blockDim.x channels.
if (c < chn) {
double mean_c = 0;
double var_c = 0;
int64_t count_c = 0;
for (int n = 0; n < num; ++n) {
auto count_n = (int64_t)all_count[n][0];
auto mean_n = (double)all_mean[n][c];
auto var_n_term = (double)all_var[n][c] * count_n;
auto delta = mean_n - mean_c;
auto new_count = count_c + count_n;
if (new_count == 0) continue;
var_c += var_n_term + delta * delta * count_c * count_n / (double)new_count;
mean_c = (count_c * mean_c + count_n * mean_n) / (double)new_count;
count_c = new_count;
}
mean[c] = (scalar_t)mean_c;
var[c] = count_c > 0 ? (scalar_t)(var_c / (double)count_c) : (scalar_t)0;
}
}
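// Host-side launch sketch for the tile reduction above (an assumption, mirroring the
// blockDim.x = 256 / gridDim.x = ceil(chn / blockDim.x) configuration noted before the
// kernel; 'launch_reduce_statistics_tile' is an illustrative name, not an existing entry point):
template<typename scalar_t>
void launch_reduce_statistics_tile(
    const at::Tensor& all_mean, const at::Tensor& all_var, const at::Tensor& all_count,
    at::Tensor& mean, at::Tensor& var, cudaStream_t stream) {
  const int chn = static_cast<int>(all_mean.size(1));
  const int threads = 256;
  const int blocks = (chn + threads - 1) / threads; // ceil(chn / 256)
  reduce_statistics_kernel_tile<scalar_t, int32_t><<<blocks, threads, 0, stream>>>(
      all_mean.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(),
      all_var.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(),
      all_count.packed_accessor32<int64_t, 2, at::RestrictPtrTraits>(),
      mean.packed_accessor32<scalar_t, 1, at::RestrictPtrTraits>(),
      var.packed_accessor32<scalar_t, 1, at::RestrictPtrTraits>());
}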
// Host-side launch wrapper (sketch)
void launch_forward_kernel(
at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
float eps, float act_param, cudaStream_t stream) {

int c = x.size(1);
int sp = x.size(2);

// Example configuration (tune as needed): 256 threads across the spatial dim, 1 across batch.
dim3 block(256, 1);
// One block per channel in x; grid.y can be raised to tile large batches.
dim3 grid(c, 1);

// Prefer the vec4 path when the tensor is fp32, contiguous, sp is a multiple of 4,
// and the base pointer is 16-byte aligned (contiguity plus sp % 4 == 0 then keeps every row aligned).
bool can_vec4 = x.scalar_type() == at::kFloat && x.is_contiguous() && (sp % 4 == 0) &&
reinterpret_cast<uintptr_t>(x.data_ptr<float>()) % 16 == 0;

if (can_vec4) {
// NOTE: the activation is fixed at compile time; dispatch over the Activation enum as needed.
auto func = forward_kernel_vec4<float, float, float, int32_t, Activation::RELU>;
func<<<grid, block, 0, stream>>>(
x.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
mean.packed_accessor32<float, 1, at::RestrictPtrTraits>(),
var.packed_accessor32<float, 1, at::RestrictPtrTraits>(),
weight.packed_accessor32<float, 1, at::RestrictPtrTraits>(),
bias.packed_accessor32<float, 1, at::RestrictPtrTraits>(),
eps, act_param);
} else {
// fallback to the scalar kernel (existing forward_kernel)
}
}
// Helper for timing a single kernel launch with CUDA events (requires <functional>).
// Events are recorded on the default stream and only one launch is measured, so do a
// warm-up launch beforehand for stable numbers.
inline float time_kernel_once(const std::function<void()>& launch_fn) {
cudaEvent_t s, e;
cudaEventCreate(&s); cudaEventCreate(&e);
cudaEventRecord(s);
launch_fn();
cudaEventRecord(e);
cudaEventSynchronize(e);
float ms = 0.0f;
cudaEventElapsedTime(&ms, s, e);
cudaEventDestroy(s); cudaEventDestroy(e);
return ms;
}
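// Usage sketch (an illustration; 'benchmark_forward_once' is a hypothetical helper, not part
// of this PR): do one untimed warm-up launch, then time a single launch on the default
// stream, which is where time_kernel_once records its events.
inline float benchmark_forward_once(
    at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
    float eps, float act_param) {
  auto launch = [&] {
    launch_forward_kernel(x, mean, var, weight, bias, eps, act_param, /*stream=*/0);
  };
  launch();                 // warm-up, excluded from the measurement
  cudaDeviceSynchronize();  // make sure the warm-up has finished before timing
  return time_kernel_once(launch);  // elapsed time in milliseconds
}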