@@ -118,75 +118,6 @@ fn main(@builtin(global_invocation_id) GlobalInvocationID: vec3<u32>,
118118}
119119)" ;
120120
// matrix multiplication (naive implementation)
// One thread per output element: C[i, j] = dot(row i of A, column j of B).
// {{precision}} and {{workgroupSize}} are substituted when the KernelCode is
// built; {{M}}, {{K}}, {{N}} are substituted by MatmulShader().
static const char *kShaderMatMul1 = R"(
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
@compute @workgroup_size({{workgroupSize}})
fn main(
    @builtin(global_invocation_id) GlobalInvocationID: vec3<u32>) {
    let i: u32 = GlobalInvocationID.x / {{N}};
    let j: u32 = GlobalInvocationID.x % {{N}};
    if (i < {{M}} && j < {{N}}) {
        // Accumulator must match the storage precision: with f16 buffers a
        // hard-coded f32 accumulator would not type-check in WGSL.
        var sum: {{precision}} = 0.0;
        for (var k: u32 = 0; k < {{K}}; k = k + 1) {
            sum = sum + A[i * {{K}} + k] * B[k * {{N}} + j];
        }
        C[i * {{N}} + j] = sum;
    }
}
)";
140-
// matrix multiplication (workgroup-tiled implementation)
//
// workgroupSizeX / workgroupSizeY must be defined (or textually substituted)
// before this shader compiles; the tiling below assumes square tiles, i.e.
// workgroupSizeX == workgroupSizeY — TODO confirm against the dispatch site.
// {{M}}, {{K}}, {{N}} are substituted by MatmulShader().
static const char *kShaderMatMul2 = R"(
@group(0) @binding(0) var<storage, read_write> A: array<f32>;
@group(0) @binding(1) var<storage, read_write> B: array<f32>;
@group(0) @binding(2) var<storage, read_write> C: array<f32>;
var<workgroup> tileA: array<f32, workgroupSizeY * workgroupSizeX>;
var<workgroup> tileB: array<f32, workgroupSizeY * workgroupSizeX>;
@compute @workgroup_size(workgroupSizeX, workgroupSizeY, 1)
fn matmul(
    @builtin(global_invocation_id) global_id : vec3<u32>,
    @builtin(local_invocation_id) local_id : vec3<u32>,
    @builtin(workgroup_id) workgroup_id : vec3<u32>
) {
    let row = global_id.x;
    let col = global_id.y;
    var result: f32 = 0.0;
    // March over K one tile at a time. No early return here: every thread of
    // the workgroup must reach both workgroupBarrier() calls (WGSL requires
    // uniform control flow at barriers), so out-of-range threads load zeros
    // and merely skip the final store.
    for (var t = 0u; t < {{K}}; t = t + workgroupSizeX) {
        // Cooperative load of one tile of A and one tile of B into workgroup
        // memory. The arrays are 1-D, so index them as base + offset.
        let aCol = t + local_id.y;
        let bRow = t + local_id.x;
        var aVal: f32 = 0.0;
        if (row < {{M}} && aCol < {{K}}) {
            aVal = A[row * {{K}} + aCol];
        }
        var bVal: f32 = 0.0;
        if (bRow < {{K}} && col < {{N}}) {
            bVal = B[bRow * {{N}} + col];
        }
        tileA[local_id.x * workgroupSizeY + local_id.y] = aVal;
        tileB[local_id.x * workgroupSizeY + local_id.y] = bVal;
        // Make the freshly loaded tile visible to the whole workgroup.
        workgroupBarrier();
        // Partial dot product over the current tile.
        for (var k = 0u; k < workgroupSizeX; k = k + 1u) {
            result = result + tileA[local_id.x * workgroupSizeY + k]
                            * tileB[k * workgroupSizeY + local_id.y];
        }
        // Don't overwrite the tile while other threads are still reading it.
        workgroupBarrier();
    }
    if (row < {{M}} && col < {{N}}) {
        C[row * {{N}} + col] = result;
    }
}
)";
175-
176- /* Generates KernelCode instance for all matmul kernels - pass in
177- * the template code via `shaderRaw`.
178- *
179- * This is intended to be run ahead of time, so is not performance critical.
180- * */
181- KernelCode MatmulShader (size_t workgroupSize, const char *shaderRaw,
182- NumType precision, size_t M, size_t K, size_t N) {
183- KernelCode shader = {shaderRaw, workgroupSize, precision};
184- replaceAll (shader.data , " {{M}}" , std::to_string (M));
185- replaceAll (shader.data , " {{K}}" , std::to_string (K));
186- replaceAll (shader.data , " {{N}}" , std::to_string (N));
187- return shader;
188- }
189-
190121/* Softmax
191122 * v1:
192123 * - equivalent to naive softmax with one thread per row
0 commit comments