Skip to content

Commit 1b0deb0

Browse files
committed
Zerod gradient buffer.
1 parent 617d996 commit 1b0deb0

1 file changed

Lines changed: 16 additions & 3 deletions

File tree

openequivariance_extjax/src/libjax_tp_jit.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,9 @@ ffi::Error conv_forward_impl(
448448
if (k.deterministic){
449449
check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm");
450450
}
451+
else {
452+
zero_buffer(*L3_out);
453+
}
451454

452455
if (k.shared_weights)
453456
check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
@@ -495,9 +498,13 @@ ffi::Error conv_backward_impl(
495498
check_tensor(rows, {nnz}, k.idx_dtype, "rows");
496499
check_tensor(cols, {nnz}, k.idx_dtype, "cols");
497500

498-
if (k.deterministic)
501+
if (k.deterministic) {
499502
check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm");
500-
503+
}
504+
else {
505+
zero_buffer(*L1_grad);
506+
}
507+
501508
if (k.shared_weights) {
502509
check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
503510
check_tensor(*W_grad, {k.weight_numel}, k.weight_dtype, "W_grad");
@@ -559,8 +566,13 @@ ffi::Error conv_double_backward_impl(
559566
check_tensor(rows, {nnz}, k.idx_dtype, "rows");
560567
check_tensor(cols, {nnz}, k.idx_dtype, "cols");
561568

562-
if (k.deterministic)
569+
if (k.deterministic) {
563570
check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm");
571+
}
572+
else {
573+
zero_buffer(*L1_grad);
574+
zero_buffer(*L3_dgrad);
575+
}
564576

565577
if (k.shared_weights) {
566578
check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
@@ -571,6 +583,7 @@ ffi::Error conv_double_backward_impl(
571583
}
572584
if(k.shared_weights)
573585
zero_buffer(*W_grad);
586+
574587
jit_kernel->double_backward(
575588
data_ptr(L1_in),
576589
data_ptr(L2_in),

0 commit comments

Comments
 (0)