#ifndef KERNELS_H
#define KERNELS_H

#include "gpu.h"

namespace gpu {


static const char *kShaderGelu = R"(
const GELU_SCALING_FACTOR: f32 = 0.7978845608028654; // sqrt(2.0 / PI)
@group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> out: array<{{precision}}>;
@compute @workgroup_size({{workgroupSize}})
fn main(
    @builtin(global_invocation_id) GlobalInvocationID: vec3<u32>) {
  let i: u32 = GlobalInvocationID.x;
  if (i < arrayLength(&inp)) {
    let x: f32 = inp[i];
    // select is more stable for larger values of x
    out[i] = select(0.5 * x * (1.0 + tanh(GELU_SCALING_FACTOR
             * (x + .044715 * x * x * x))), x, x > 10.0);
  }
}
)";
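
/* Reference: the kernel above implements the tanh approximation of GELU
 * (Hendrycks & Gimpel, 2016):
 *
 *   GELU(x) ~= 0.5 * x * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3) ))
 *
 * For x > 10 the tanh term saturates to 1 and GELU(x) ~= x, which is why the
 * select() passes x through directly in that range.
 */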

static const char *kTanh = R"(
@group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> out: array<{{precision}}>;
@compute @workgroup_size({{workgroupSize}})
fn main(
    @builtin(global_invocation_id) GlobalInvocationID: vec3<u32>) {
  let i: u32 = GlobalInvocationID.x;
  if (i < arrayLength(&inp)) {
    let x: f32 = inp[i];
    out[i] = tanh(x);
  }
}
)";

static const char *kShaderHadamard = R"(
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
@compute @workgroup_size({{workgroupSize}})
fn main(
    @builtin(global_invocation_id) GlobalInvocationID: vec3<u32>) {
  let idx = GlobalInvocationID.x;
  if (idx < arrayLength(&A)) {
    C[idx] = A[idx] * B[idx];
  }
}
)";

static const char *kShaderResidual = R"(
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
@compute @workgroup_size({{workgroupSize}})
fn main(
    @builtin(global_invocation_id) GlobalInvocationID: vec3<u32>) {
  let idx = GlobalInvocationID.x;
  if (idx < arrayLength(&A)) {
    C[idx] = A[idx] + B[idx];
  }
}
)";

/* LayerNorm
 * v1:
 * - No caching of mean/std for the backward pass
 * - No parallel reduction
 * - One thread per row, for rows 1..N
 */
// TODO(avh): Allow larger virtual 1D workgroups by making use of y / z
// dimensions and calculating the threadID accordingly.
static const char *kShaderLayerNorm1 = R"(
@group(0) @binding(0) var<storage, read_write> inp: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> weight: array<{{precision}}>;
@group(0) @binding(2) var<storage, read_write> bias: array<{{precision}}>;
@group(0) @binding(3) var<storage, read_write> out: array<{{precision}}>;
@group(0) @binding(4) var<uniform> params: Params;

struct Params {
  N: u32,
  C: u32,
};

@compute @workgroup_size({{workgroupSize}})
fn main(@builtin(global_invocation_id) GlobalInvocationID: vec3<u32>,
        @builtin(local_invocation_id) LocalInvocationID: vec3<u32>,
        @builtin(workgroup_id) WorkgroupID: vec3<u32>) {
  let idx: u32 = GlobalInvocationID.x;

  if (idx >= params.N) { return; }

  let C: u32 = params.C;

  // Calculate mean
  var sum: f32 = 0.0;
  for (var i: u32 = 0; i < C; i = i + 1) {
    sum += inp[idx * C + i];
  }
  let mean_val: f32 = sum / f32(C);

  // Calculate rstd (reciprocal standard deviation)
  sum = 0.0;
  for (var i: u32 = 0; i < C; i = i + 1) {
    let diff: f32 = inp[idx * C + i] - mean_val;
    sum += diff * diff;
  }
  let rstd_val: f32 = 1.0 / sqrt(sum / f32(C) + 1e-5);

  // Normalize, then apply the learned scale and shift
  for (var i: u32 = 0; i < C; i = i + 1) {
    let n: f32 = rstd_val * (inp[idx * C + i] - mean_val);
    out[idx * C + i] = n * weight[i] + bias[i];
  }
}
)";
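
/* Reference: for each row x of length C, the kernel above computes standard
 * LayerNorm with a learned elementwise scale and shift:
 *
 *   mean  = (1/C) * sum_i x_i
 *   rstd  = 1 / sqrt((1/C) * sum_i (x_i - mean)^2 + 1e-5)
 *   out_i = rstd * (x_i - mean) * weight_i + bias_i
 */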

// matrix multiplication (naive implementation)
static const char *kShaderMatMul1 = R"(
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
@compute @workgroup_size({{workgroupSize}})
fn main(
    @builtin(global_invocation_id) GlobalInvocationID: vec3<u32>) {
  let i: u32 = GlobalInvocationID.x / {{N}};
  let j: u32 = GlobalInvocationID.x % {{N}};
  if (i < {{M}} && j < {{N}}) {
    var sum: f32 = 0.0;
    for (var k: u32 = 0; k < {{K}}; k = k + 1) {
      sum = sum + A[i * {{K}} + k] * B[k * {{N}} + j];
    }
    C[i * {{N}} + j] = sum;
  }
}
)";
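
/* Reference: both matmul kernels compute C = A * B for row-major matrices,
 *
 *   C[i][j] = sum_k A[i][k] * B[k][j],   A: M x K,  B: K x N,  C: M x N,
 *
 * with {{M}}, {{K}}, {{N}} substituted at shader-generation time by
 * MatmulShader below. The naive version assigns one thread per output
 * element; the tiled version that follows stages blocks of A and B in
 * workgroup memory to reuse each global memory load.
 */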

/* Tiled matrix multiplication: each workgroup cooperatively loads square
 * tiles of A and B into workgroup (shared) memory so that every value read
 * from global memory is reused workgroupSizeX times. The 16x16 tile size is
 * an assumed fixed choice; it could also be made a template parameter like
 * {{workgroupSize}}.
 */
static const char *kShaderMatMul2 = R"(
const workgroupSizeX: u32 = 16u;
const workgroupSizeY: u32 = 16u;
@group(0) @binding(0) var<storage, read_write> A: array<f32>;
@group(0) @binding(1) var<storage, read_write> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;
var<workgroup> tileA: array<f32, workgroupSizeY * workgroupSizeX>;
var<workgroup> tileB: array<f32, workgroupSizeY * workgroupSizeX>;
@compute @workgroup_size(workgroupSizeX, workgroupSizeY, 1)
fn main(
    @builtin(global_invocation_id) global_id : vec3<u32>,
    @builtin(local_invocation_id) local_id : vec3<u32>
) {
  // Each thread computes one element of C. Rows map to the y dimension so
  // that threads within a workgroup access consecutive columns.
  let row = global_id.y;
  let col = global_id.x;
  var result: f32 = 0.0;
  for (var i = 0u; i < {{K}}; i = i + workgroupSizeX) {
    // Load tiles into shared memory. The workgroup arrays are 1D, so 2D tile
    // coordinates are flattened manually. Out-of-bounds lanes load 0.0 rather
    // than returning early, keeping the barriers in uniform control flow.
    tileA[local_id.y * workgroupSizeX + local_id.x] =
        select(0.0, A[row * {{K}} + i + local_id.x],
               row < {{M}} && i + local_id.x < {{K}});
    tileB[local_id.y * workgroupSizeX + local_id.x] =
        select(0.0, B[(i + local_id.y) * {{N}} + col],
               i + local_id.y < {{K}} && col < {{N}});
    // Synchronize to make sure the tile is loaded
    workgroupBarrier();
    // Perform partial dot product for the current tile
    for (var k = 0u; k < workgroupSizeX; k = k + 1u) {
      result = result + tileA[local_id.y * workgroupSizeX + k]
                      * tileB[k * workgroupSizeX + local_id.x];
    }
    // Synchronize before loading the next tile
    workgroupBarrier();
  }
  if (row < {{M}} && col < {{N}}) {
    C[row * {{N}} + col] = result;
  }
}
)";

/* Generates a KernelCode instance for all matmul kernels - pass in
 * the template code via `shaderRaw`.
 *
 * This is intended to be run ahead of time, so it is not performance
 * critical.
 */
KernelCode MatmulShader(size_t workgroupSize, const char *shaderRaw,
                        NumType precision, size_t M, size_t K, size_t N) {
  KernelCode shader = {shaderRaw, workgroupSize, precision};
  replaceAll(shader.data, "{{M}}", std::to_string(M));
  replaceAll(shader.data, "{{K}}", std::to_string(K));
  replaceAll(shader.data, "{{N}}", std::to_string(N));
  return shader;
}
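
/* Example (hypothetical usage; assumes the KernelCode constructor fills in
 * {{workgroupSize}} and {{precision}} from its arguments, and that kf32 is
 * the f32 NumType defined in gpu.h):
 *
 *   KernelCode code = MatmulShader(256, kShaderMatMul1, kf32, 1024, 1024, 1024);
 *   // code.data now contains WGSL with all template placeholders resolved.
 */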

/* Softmax
 * v1:
 * - equivalent to naive softmax, with one thread per row
 */
static const char *kShaderSoftmax1 = R"(
@group(0) @binding(0) var<storage, read_write> inp : array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> out : array<{{precision}}>;
@group(0) @binding(2) var<uniform> params : Params;
struct Params {
  N: u32,
  C: u32,
};
const NEG_INFINITY: f32 = -3.0e38; // WGSL has trouble representing -3.4028235e+38
@compute @workgroup_size({{workgroupSize}})
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
  let N : u32 = params.N;
  let C : u32 = params.C;
  let i : u32 = global_id.x;
  if (i < N) {
    let inp_row_start : u32 = i * C;
    var maxval : f32 = NEG_INFINITY;
    // Find the maximum value in the row
    for (var j : u32 = 0u; j < C; j++) {
      let val : f32 = inp[inp_row_start + j];
      if (val > maxval) {
        maxval = val;
      }
    }
    var sum : f32 = 0.0;
    // Compute the exponentials and sum them
    for (var j : u32 = 0u; j < C; j++) {
      let exp_val : f32 = exp(inp[inp_row_start + j] - maxval);
      out[inp_row_start + j] = exp_val;
      sum += exp_val;
    }
    // Normalize the row to get probabilities, multiplying by the reciprocal
    // rather than dividing inside the loop
    let norm : f32 = 1.0f / sum;
    for (var j : u32 = 0u; j < C; j++) {
      out[inp_row_start + j] *= norm;
    }
  }
}
)";
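
/* Reference: the kernel above computes the numerically stable softmax per
 * row, subtracting the row maximum before exponentiating so exp() cannot
 * overflow:
 *
 *   softmax(x)_j = exp(x_j - max_k x_k) / sum_j exp(x_j - max_k x_k)
 */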

} // namespace gpu

#endif // KERNELS_H