@@ -471,6 +471,64 @@ ffi::Error conv_forward_impl(
     return ffi::Error::Success();
 }
 
+ffi::Error conv_backward_impl(
+    ffi::AnyBuffer L1_in,
+    ffi::AnyBuffer L2_in,
+    ffi::AnyBuffer W,
+    ffi::AnyBuffer L3_grad,
+    ffi::Result<ffi::AnyBuffer> L1_grad,
+    ffi::Result<ffi::AnyBuffer> L2_grad,
+    ffi::Result<ffi::AnyBuffer> W_grad,
+    ffi::AnyBuffer rows,
+    ffi::AnyBuffer cols,
+    ffi::AnyBuffer workspace,
+    ffi::AnyBuffer transpose_perm,
+    cudaStream_t stream,
+    std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
+    int64_t hash) {
+
+    auto [jit_kernel, k] = compile_conv_with_caching(
+        kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true);
+    const int64_t nnz = rows.dimensions()[0];
+    const int64_t node_count = L1_in.dimensions()[0];
+    check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in");
+    check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in");
+    check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad");
+    check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace");
+    check_tensor(rows, {nnz}, k.idx_dtype, "rows");
+    check_tensor(cols, {nnz}, k.idx_dtype, "cols");
+
+    if (k.deterministic)  // the deterministic algorithm also consumes a transpose permutation
+        check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm");
+
+    if (k.shared_weights) {
+        check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
+        check_tensor(*W_grad, {k.weight_numel}, k.weight_dtype, "W_grad");
+    }
+    else {
+        check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W");
+        check_tensor(*W_grad, {nnz, k.weight_numel}, k.weight_dtype, "W_grad");
+    }
+    if (k.shared_weights)  // shared W_grad accumulates over all edges, so it must start zeroed
+        zero_buffer(*W_grad);
+
+    jit_kernel->backward(
+        data_ptr(L1_in),
+        data_ptr(L1_grad),
+        data_ptr(L2_in),
+        data_ptr(L2_grad),
+        data_ptr(W),
+        data_ptr(W_grad),
+        data_ptr(L3_grad),
+        data_ptr(rows),
+        data_ptr(cols),
+        nnz, node_count,
+        data_ptr(workspace),
+        data_ptr(transpose_perm),
+        stream);
+    return ffi::Error::Success();
+}
+
 XLA_FFI_DEFINE_HANDLER_SYMBOL(
     conv_forward, conv_forward_impl,
     ffi::Ffi::Bind()
@@ -487,6 +545,25 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
         .Attr<int64_t>("hash"),
     {xla::ffi::Traits::kCmdBufferCompatible});
 
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    conv_backward, conv_backward_impl,
+    ffi::Ffi::Bind()
+        .Arg<ffi::AnyBuffer>()  // L1_in
+        .Arg<ffi::AnyBuffer>()  // L2_in
+        .Arg<ffi::AnyBuffer>()  // W
+        .Arg<ffi::AnyBuffer>()  // L3_grad
+        .Ret<ffi::AnyBuffer>()  // L1_grad
+        .Ret<ffi::AnyBuffer>()  // L2_grad
+        .Ret<ffi::AnyBuffer>()  // W_grad
+        .Arg<ffi::AnyBuffer>()  // rows
+        .Arg<ffi::AnyBuffer>()  // cols
+        .Arg<ffi::AnyBuffer>()  // workspace
+        .Arg<ffi::AnyBuffer>()  // transpose_perm
+        .Ctx<ffi::PlatformStream<cudaStream_t>>()
+        .Attr<std::string_view>("kernel").Attr<ffi::Dictionary>("forward_config").Attr<ffi::Dictionary>("backward_config").Attr<ffi::Dictionary>("double_backward_config").Attr<ffi::Dictionary>("kernel_prop")
+        .Attr<int64_t>("hash"),
+    {xla::ffi::Traits::kCmdBufferCompatible});
+
 NB_MODULE(openequivariance_extjax, m) {
     m.def("registrations", []() {
         nb::dict registrations;
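
For context, a minimal sketch of how the Python side might register and invoke the new conv_backward target through JAX's FFI layer. The wrapper function, buffer shapes, and workspace handling are illustrative assumptions, not part of this commit; only the target name, the operand order, and the attribute names (kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash) come from the handler bound above.

    import jax
    import openequivariance_extjax as ext

    # Register every XLA FFI target the extension exports via registrations().
    for name, capsule in ext.registrations().items():
        jax.ffi.register_ffi_target(name, capsule, platform="CUDA")

    def conv_backward(L1_in, L2_in, W, L3_grad, rows, cols, workspace,
                      transpose_perm, **attrs):
        # One ShapeDtypeStruct per .Ret<ffi::AnyBuffer>() slot:
        # L1_grad, L2_grad, W_grad mirror their corresponding inputs.
        out_types = (
            jax.ShapeDtypeStruct(L1_in.shape, L1_in.dtype),
            jax.ShapeDtypeStruct(L2_in.shape, L2_in.dtype),
            jax.ShapeDtypeStruct(W.shape, W.dtype),
        )
        # Operands follow the .Arg<ffi::AnyBuffer>() order of the binding;
        # attrs must carry kernel, the three config dicts, kernel_prop, and hash.
        return jax.ffi.ffi_call("conv_backward", out_types)(
            L1_in, L2_in, W, L3_grad, rows, cols, workspace, transpose_perm,
            **attrs)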
0 commit comments
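
And one plausible consumer, assuming a conv_forward wrapper built the same way: a jax.custom_vjp rule that routes the cotangent of L3 into this backward handler. The make_conv factory, the residual choice, and closing over the non-differentiable buffers are illustrative, not taken from the commit.

    import jax

    def make_conv(rows, cols, workspace, transpose_perm, attrs):
        # Close over the index/workspace buffers and static attrs so the
        # differentiated function only carries the differentiable operands.
        @jax.custom_vjp
        def conv(L1_in, L2_in, W):
            return conv_forward(L1_in, L2_in, W, rows, cols, workspace,
                                transpose_perm, **attrs)

        def fwd(L1_in, L2_in, W):
            return conv(L1_in, L2_in, W), (L1_in, L2_in, W)

        def bwd(residuals, L3_grad):
            L1_in, L2_in, W = residuals
            # conv_backward returns (L1_grad, L2_grad, W_grad), one cotangent
            # per primal input of conv.
            return conv_backward(L1_in, L2_in, W, L3_grad, rows, cols,
                                 workspace, transpose_perm, **attrs)

        conv.defvjp(fwd, bwd)
        return conv

Under this wiring, jax.grad through conv dispatches to the conv_backward FFI target, and the zero_buffer call in the handler is what lets the shared-weights path accumulate W_grad across edges from a clean buffer.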