Skip to content

Commit bdba854

Browse files
Fix the data alignment of softmax_forward
1 parent 4aab602 commit bdba854

3 files changed

Lines changed: 18 additions & 11 deletions

File tree

experimental/kernels/kernels.h

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -165,15 +165,17 @@ static const char *kShaderSoftmax1 = R"(
165165
struct Params {
166166
N: u32,
167167
C: u32,
168+
Cp: u32,
168169
};
169170
const NEG_INFINITY: f32 = -3.0e38; // WGSL has problem representing -3.4028235e+38
170171
@compute @workgroup_size({{workgroupSize}})
171172
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
172173
let N : u32 = params.N;
173174
let C : u32 = params.C;
175+
let Cp : u32 = params.Cp;
174176
let i : u32 = global_id.x;
175177
if (i < N) {
176-
let inp_row_start : u32 = i * C;
178+
let inp_row_start : u32 = i * Cp;
177179
var maxval : f32 = NEG_INFINITY;
178180
// Find the maximum value in the row
179181
for (var j : u32 = 0u; j < C; j++) {
@@ -194,6 +196,9 @@ fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
194196
for (var j : u32 = 0u; j < C; j++) {
195197
out[inp_row_start + j] /= sum;
196198
}
199+
for (var j : u32 = C; j < Cp; j++) {
200+
out[inp_row_start + j] = 0;
201+
}
197202
}
198203
}
199204
)";

experimental/kernels/kernels_c.cpp

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -379,25 +379,28 @@ void RESIDUAL_BACKWARD_GPU(float* dinp1, float* dinp2, float* dout, int N){
379379
toCPU(ctx, dinp2_o, dinp2, n * sizeof(float));
380380
}
381381

382-
void SOFTMAX_FORWARD_GPU(float* probs, float* logits, int b, int t, int v, int vp) {
382+
void SOFTMAX_FORWARD_GPU(float* probs, float* logits, int B, int T, int V, int Vp) {
383383
struct SoftmaxParam {
384384
uint32_t N;
385385
uint32_t C;
386+
uint32_t Cp;
386387
};
387-
uint32_t B = static_cast<uint32_t>(b);
388-
uint32_t T = static_cast<uint32_t>(t);
389-
uint32_t C = static_cast<uint32_t>(v);
388+
uint32_t b = static_cast<uint32_t>(B);
389+
uint32_t t = static_cast<uint32_t>(T);
390+
uint32_t c = static_cast<uint32_t>(V);
391+
uint32_t cp = static_cast<uint32_t>(Vp);
390392
Context ctx = createContext();
391-
Tensor input = createTensor(ctx, {B * T, C}, kf32, logits);
392-
Tensor output = createTensor(ctx, {B * T, C}, kf32);
393+
Tensor input = createTensor(ctx, {b * t, cp}, kf32, logits);
394+
Tensor output = createTensor(ctx, {b * t, cp}, kf32);
393395
std::promise<void> promise;
394396
std::future<void> future = promise.get_future();
397+
assert( (B*T) % 256 == 0);
395398
Kernel op = createKernel(
396399
ctx, {kShaderSoftmax1, 256, kf32}, Bindings{input, output},
397-
Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{B * T, C});
400+
Shape{cdiv(B * T, 256), 1, 1}, SoftmaxParam{b * t, c, cp});
398401
dispatchKernel(ctx, op, promise);
399402
wait(ctx, future);
400-
toCPU(ctx, output, probs, sizeof(float)*B*T*C);
403+
toCPU(ctx, output, probs, sizeof(float)*b*t*cp);
401404
}
402405

403406
void CROSSENTROPY_FORWARD_GPU(float* losses,

experimental/kernels/kernels_c.h

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,8 +18,7 @@ extern "C" {
1818
#define USE_GPU_FOR_GELU_BACKWARD 1
1919
#define USE_GPU_FOR_RESIDUAL_FORWARD 1
2020
#define USE_GPU_FOR_RESIDUAL_BACKWARD 1
21-
// -- Note: softmax_forward_gpu is implemented, but there are some mismatches when enabled. --
22-
// #define USE_GPU_FOR_SOFTMAX_FORWARD 1
21+
#define USE_GPU_FOR_SOFTMAX_FORWARD 1
2322
#define USE_GPU_FOR_CROSSENTROPY_FORWARD 1
2423
// #define USE_GPU_FOR_CROSSENTROPY_SOFTMAX_BACKWARD 1
2524

0 commit comments

Comments
 (0)