@@ -448,6 +448,9 @@ ffi::Error conv_forward_impl(
448448 if (k.deterministic ){
449449 check_tensor (transpose_perm, {nnz}, k.idx_dtype , " transpose perm" );
450450 }
451+ else {
452+ zero_buffer (*L3_out);
453+ }
451454
452455 if (k.shared_weights )
453456 check_tensor (W, {k.weight_numel }, k.weight_dtype , " W" );
@@ -495,9 +498,13 @@ ffi::Error conv_backward_impl(
495498 check_tensor (rows, {nnz}, k.idx_dtype , " rows" );
496499 check_tensor (cols, {nnz}, k.idx_dtype , " cols" );
497500
498- if (k.deterministic )
501+ if (k.deterministic ) {
499502 check_tensor (transpose_perm, {nnz}, k.idx_dtype , " transpose perm" );
500-
503+ }
504+ else {
505+ zero_buffer (*L1_grad);
506+ }
507+
501508 if (k.shared_weights ) {
502509 check_tensor (W, {k.weight_numel }, k.weight_dtype , " W" );
503510 check_tensor (*W_grad, {k.weight_numel }, k.weight_dtype , " W_grad" );
@@ -559,8 +566,13 @@ ffi::Error conv_double_backward_impl(
559566 check_tensor (rows, {nnz}, k.idx_dtype , " rows" );
560567 check_tensor (cols, {nnz}, k.idx_dtype , " cols" );
561568
562- if (k.deterministic )
569+ if (k.deterministic ) {
563570 check_tensor (transpose_perm, {nnz}, k.idx_dtype , " transpose perm" );
571+ }
572+ else {
573+ zero_buffer (*L1_grad);
574+ zero_buffer (*L3_dgrad);
575+ }
564576
565577 if (k.shared_weights ) {
566578 check_tensor (W, {k.weight_numel }, k.weight_dtype , " W" );
@@ -571,6 +583,7 @@ ffi::Error conv_double_backward_impl(
571583 }
572584 if (k.shared_weights )
573585 zero_buffer (*W_grad);
586+
574587 jit_kernel->double_backward (
575588 data_ptr (L1_in),
576589 data_ptr (L2_in),
0 commit comments