@@ -529,6 +529,72 @@ ffi::Error conv_backward_impl(
     return ffi::Error::Success();
 }
 
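+// Double-backward pass for the sparse convolution. Given the original inputs
+// (L1_in, L2_in, W), the upstream gradient L3_grad, and the incoming
+// differentials (L1_dgrad, L2_dgrad, W_dgrad), this emits second-order
+// gradients for the inputs and weights plus the output differential L3_dgrad.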
+ffi::Error conv_double_backward_impl(
+    ffi::AnyBuffer L1_in,
+    ffi::AnyBuffer L2_in,
+    ffi::AnyBuffer W,
+    ffi::AnyBuffer L3_grad,
+    ffi::AnyBuffer L1_dgrad,
+    ffi::AnyBuffer L2_dgrad,
+    ffi::AnyBuffer W_dgrad,
+    ffi::Result<ffi::AnyBuffer> L1_grad,
+    ffi::Result<ffi::AnyBuffer> L2_grad,
+    ffi::Result<ffi::AnyBuffer> W_grad,
+    ffi::Result<ffi::AnyBuffer> L3_dgrad,
+    ffi::AnyBuffer rows,
+    ffi::AnyBuffer cols,
+    ffi::AnyBuffer workspace,
+    ffi::AnyBuffer transpose_perm,
+    cudaStream_t stream,
+    std::string_view kernel,
+    ffi::Dictionary forward_config,
+    ffi::Dictionary backward_config,
+    ffi::Dictionary double_backward_config,
+    ffi::Dictionary kernel_prop,
+    int64_t hash) {
+
+    auto [jit_kernel, k] = compile_conv_with_caching(
+        kernel, forward_config, backward_config, double_backward_config,
+        kernel_prop, hash, true);
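+    // compile_conv_with_caching returns the JIT-compiled kernel along with its
+    // metadata `k` (dimensions, dtypes, flags); the hash keys the cache so
+    // repeated calls skip recompilation.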
+    const int64_t nnz = rows.dimensions()[0];
+    const int64_t node_count = L1_in.dimensions()[0];
+    check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in");
+    check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in");
+    check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad");
+    check_tensor(L1_dgrad, {node_count, k.L1_dim}, k.irrep_dtype, "L1_dgrad");
+    check_tensor(L2_dgrad, {nnz, k.L2_dim}, k.irrep_dtype, "L2_dgrad");
+    check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace");
+    check_tensor(rows, {nnz}, k.idx_dtype, "rows");
+    check_tensor(cols, {nnz}, k.idx_dtype, "cols");
+
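+    // Only deterministic kernels consume the transpose permutation, so its
+    // shape/dtype check applies in that mode alone.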
+    if (k.deterministic)
+        check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm");
+
+    if (k.shared_weights) {
+        check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
+        check_tensor(W_dgrad, {k.weight_numel}, k.weight_dtype, "W_dgrad");
+    } else {
+        check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W");
+        check_tensor(W_dgrad, {nnz, k.weight_numel}, k.weight_dtype, "W_dgrad");
+    }
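+    // Shared weights collapse the per-edge weight axis, so every nonzero
+    // accumulates into the same gradient buffer; it must start zeroed.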
+    if (k.shared_weights)
+        zero_buffer(*W_grad);
+
+    jit_kernel->double_backward(
+        data_ptr(L1_in),
+        data_ptr(L2_in),
+        data_ptr(W),
+        data_ptr(L3_grad),
+        data_ptr(L1_dgrad),
+        data_ptr(L2_dgrad),
+        data_ptr(W_dgrad),
+        data_ptr(L1_grad),
+        data_ptr(L2_grad),
+        data_ptr(W_grad),
+        data_ptr(L3_dgrad),
+        data_ptr(rows),
+        data_ptr(cols),
+        nnz, node_count,
+        data_ptr(workspace),
+        data_ptr(transpose_perm),
+        stream);
+    return ffi::Error::Success();
+}
+
 XLA_FFI_DEFINE_HANDLER_SYMBOL(
     conv_forward, conv_forward_impl,
     ffi::Ffi::Bind()
@@ -564,6 +630,30 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
         .Attr<int64_t>("hash"),
     {xla::ffi::Traits::kCmdBufferCompatible});
 
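+// The binding below must mirror conv_double_backward_impl's parameter order:
+// seven input buffers, four results, the four trailing index/workspace
+// buffers, the CUDA stream, and finally the kernel attributes.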
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    conv_double_backward, conv_double_backward_impl,
+    ffi::Ffi::Bind()
+        .Arg<ffi::AnyBuffer>()   // L1_in
+        .Arg<ffi::AnyBuffer>()   // L2_in
+        .Arg<ffi::AnyBuffer>()   // W
+        .Arg<ffi::AnyBuffer>()   // L3_grad
+        .Arg<ffi::AnyBuffer>()   // L1_dgrad
+        .Arg<ffi::AnyBuffer>()   // L2_dgrad
+        .Arg<ffi::AnyBuffer>()   // W_dgrad
+        .Ret<ffi::AnyBuffer>()   // L1_grad
+        .Ret<ffi::AnyBuffer>()   // L2_grad
+        .Ret<ffi::AnyBuffer>()   // W_grad
+        .Ret<ffi::AnyBuffer>()   // L3_dgrad
+        .Arg<ffi::AnyBuffer>()   // rows
+        .Arg<ffi::AnyBuffer>()   // cols
+        .Arg<ffi::AnyBuffer>()   // workspace
+        .Arg<ffi::AnyBuffer>()   // transpose_perm
+        .Ctx<ffi::PlatformStream<cudaStream_t>>()
+        .Attr<std::string_view>("kernel")
+        .Attr<ffi::Dictionary>("forward_config")
+        .Attr<ffi::Dictionary>("backward_config")
+        .Attr<ffi::Dictionary>("double_backward_config")
+        .Attr<ffi::Dictionary>("kernel_prop")
+        .Attr<int64_t>("hash"),
+    {xla::ffi::Traits::kCmdBufferCompatible});
+
+// --------------------- NB Module --------------------------
 NB_MODULE(openequivariance_extjax, m) {
     m.def("registrations", []() {
         nb::dict registrations;
@@ -572,6 +662,8 @@ NB_MODULE(openequivariance_extjax, m) {
         registrations["tp_double_backward"] = nb::capsule(reinterpret_cast<void *>(tp_double_backward));
 
         registrations["conv_forward"] = nb::capsule(reinterpret_cast<void *>(conv_forward));
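+        // The Python side presumably passes each capsule to
+        // jax.ffi.register_ffi_target() when the extension loads.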
+        registrations["conv_backward"] = nb::capsule(reinterpret_cast<void *>(conv_backward));
+        registrations["conv_double_backward"] = nb::capsule(reinterpret_cast<void *>(conv_double_backward));
         return registrations;
     });
 