@@ -70,11 +70,12 @@ inline int byte_count(ffi::AnyBuffer &buffer) {
 }
 
 #ifdef CUDA_BACKEND
-void zero_buffer(ffi::AnyBuffer &buffer) {
-    cudaMemset(
+void zero_buffer(ffi::AnyBuffer &buffer, cudaStream_t stream) {
+    cudaMemsetAsync(
         data_ptr(buffer),
         0,
-        buffer.element_count() * byte_count(buffer));
+        buffer.element_count() * byte_count(buffer),
+        stream);
 }
 #endif
 
@@ -303,7 +304,7 @@ ffi::Error tp_backward_impl(
     }
 
     if (k.shared_weights) {
-        zero_buffer(*W_grad);
+        zero_buffer(*W_grad, stream);
     }
 
     jit_kernel->backward(
@@ -354,7 +355,7 @@ ffi::Error tp_double_backward_impl(
     }
 
    if (k.shared_weights) {
-        zero_buffer(*W_grad);
+        zero_buffer(*W_grad, stream);
    }
 
    jit_kernel->double_backward(
@@ -438,6 +439,7 @@ ffi::Error conv_forward_impl(
         kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true);
     const int64_t nnz = rows.dimensions()[0];
     const int64_t node_count = L1_in.dimensions()[0];
+    void *workspace_ptr = data_ptr(workspace);
 
     check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in");
     check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in");
@@ -449,8 +451,9 @@ ffi::Error conv_forward_impl(
         check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm");
     }
     else {
-        zero_buffer(*L3_out);
+        workspace_ptr = nullptr;
     }
+    zero_buffer(*L3_out, stream);
 
     if (k.shared_weights)
         check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
@@ -465,7 +468,7 @@ ffi::Error conv_forward_impl(
         data_ptr(rows),
         data_ptr(cols),
         nnz, node_count,
-        data_ptr(workspace),
+        workspace_ptr,
         stream);
 
     return ffi::Error::Success();
@@ -491,6 +494,8 @@ ffi::Error conv_backward_impl(
         kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true);
     const int64_t nnz = rows.dimensions()[0];
     const int64_t node_count = L1_in.dimensions()[0];
+    void *workspace_ptr = data_ptr(workspace);
+
     check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in");
     check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in");
     check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad");
@@ -502,8 +507,9 @@ ffi::Error conv_backward_impl(
         check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm");
     }
     else {
-        zero_buffer(*L1_grad);
-    }
+        workspace_ptr = nullptr;
+    }
+    zero_buffer(*L1_grad, stream);
 
     if (k.shared_weights) {
         check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
@@ -514,7 +520,7 @@ ffi::Error conv_backward_impl(
         check_tensor(*W_grad, {nnz, k.weight_numel}, k.weight_dtype, "W_grad");
     }
     if (k.shared_weights)
-        zero_buffer(*W_grad);
+        zero_buffer(*W_grad, stream);
 
     jit_kernel->backward(
         data_ptr(L1_in),
@@ -527,7 +533,7 @@ ffi::Error conv_backward_impl(
         data_ptr(rows),
         data_ptr(cols),
         nnz, node_count,
-        data_ptr(workspace),
+        workspace_ptr,
         data_ptr(transpose_perm),
         stream);
     return ffi::Error::Success();
@@ -557,6 +563,8 @@ ffi::Error conv_double_backward_impl(
         kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true);
     const int64_t nnz = rows.dimensions()[0];
     const int64_t node_count = L1_in.dimensions()[0];
+    void *workspace_ptr = data_ptr(workspace);
+
     check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in");
     check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in");
     check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad");
@@ -570,9 +578,11 @@ ffi::Error conv_double_backward_impl(
         check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm");
     }
     else {
-        zero_buffer(*L1_grad);
-        zero_buffer(*L3_dgrad);
+        workspace_ptr = nullptr;
     }
+    zero_buffer(*L1_grad, stream);
+    zero_buffer(*L3_dgrad, stream);
+
 
     if (k.shared_weights) {
         check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
@@ -582,7 +592,7 @@ ffi::Error conv_double_backward_impl(
         check_tensor(W_dgrad, {nnz, k.weight_numel}, k.weight_dtype, "W_dgrad");
     }
     if (k.shared_weights)
-        zero_buffer(*W_grad);
+        zero_buffer(*W_grad, stream);
 
     jit_kernel->double_backward(
         data_ptr(L1_in),
@@ -599,7 +609,7 @@ ffi::Error conv_double_backward_impl(
         data_ptr(rows),
         data_ptr(cols),
         nnz, node_count,
-        data_ptr(workspace),
+        workspace_ptr,
         data_ptr(transpose_perm),
         stream);
     return ffi::Error::Success();
0 commit comments