Skip to content

Commit 16a7b4b

Browse files
committed
remove old matmul implementations
1 parent 7a1eb34 commit 16a7b4b

1 file changed

Lines changed: 0 additions & 69 deletions

File tree

experimental/kernels/kernels.h

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -118,75 +118,6 @@ fn main(@builtin(global_invocation_id) GlobalInvocationID: vec3<u32>,
118118
}
119119
)";
120120

121-
// matrix multiplication (naive implementation)
122-
static const char *kShaderMatMul1 = R"(
123-
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
124-
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
125-
@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
126-
@compute @workgroup_size({{workgroupSize}})
127-
fn main(
128-
@builtin(global_invocation_id) GlobalInvocationID: vec3<u32>) {
129-
let i: u32 = GlobalInvocationID.x / {{N}};
130-
let j: u32 = GlobalInvocationID.x % {{N}};
131-
if (i < {{M}} && j < {{N}}) {
132-
var sum: f32 = 0.0;
133-
for (var k: u32 = 0; k < {{K}}; k = k + 1) {
134-
sum = sum + A[i * {{K}} + k] * B[k * {{N}} + j];
135-
}
136-
C[i * {{N}} + j] = sum;
137-
}
138-
}
139-
)";
140-
141-
static const char *kShaderMatMul2 = R"(
142-
@group(0) @binding(0) var<storage, read_write> A: array<f32>;
143-
@group(0) @binding(1) var<storage, read_write> B: array<f32>;
144-
@group(0) @binding(2) var<storage, read_write> C: array<f32>;
145-
var<workgroup> tileA: array<f32, workgroupSizeY * workgroupSizeX>;
146-
var<workgroup> tileB: array<f32, workgroupSizeY * workgroupSizeX>;
147-
@compute @workgroup_size(workgroupSizeX, workgroupSizeY, 1)
148-
fn matmul(
149-
@builtin(global_invocation_id) global_id : vec3<u32>,
150-
@builtin(local_invocation_id) local_id : vec3<u32>,
151-
@builtin(workgroup_id) workgroup_id : vec3<u32>
152-
) {
153-
let row = global_id.x;
154-
let col = global_id.y;
155-
if (row >= {{M}} || col >= {{N}}) {
156-
return;
157-
}
158-
var result: f32 = 0.0;
159-
for (var i = 0u; i < {{K}}; i = i + workgroupSizeX) {
160-
// Load tiles into shared memory
161-
tileA[local_id.y][local_id.x] = A[row][i + local_id.x];
162-
tileB[local_id.y][local_id.x] = B[i + local_id.y][col];
163-
// Synchronize to make sure the tile is loaded
164-
workgroupBarrier();
165-
// Perform partial dot product for the current tile
166-
for (var k = 0u; k < workgroupSizeX; k = k + 1u) {
167-
result = result + tileA[local_id.y][k] * tileB[k][local_id.x];
168-
}
169-
// Synchronize before loading the next tile
170-
workgroupBarrier();
171-
}
172-
C[row][col] = result;
173-
}
174-
)";
175-
176-
/* Generates KernelCode instance for all matmul kernels - pass in
177-
* the template code via `shaderRaw`.
178-
*
179-
* This is intended to be run ahead of time, so is not performance critical.
180-
* */
181-
KernelCode MatmulShader(size_t workgroupSize, const char *shaderRaw,
182-
NumType precision, size_t M, size_t K, size_t N) {
183-
KernelCode shader = {shaderRaw, workgroupSize, precision};
184-
replaceAll(shader.data, "{{M}}", std::to_string(M));
185-
replaceAll(shader.data, "{{K}}", std::to_string(K));
186-
replaceAll(shader.data, "{{N}}", std::to_string(N));
187-
return shader;
188-
}
189-
190121
/* Softmax
191122
* v1:
192123
* - equivalent to naive softmax with one thread per row

0 commit comments

Comments
 (0)