Skip to content

Commit ac9b3db

Browse files
committed
Convolution double backward registered.
1 parent 8784dd4 commit ac9b3db

1 file changed

Lines changed: 92 additions & 0 deletions

File tree

openequivariance_extjax/src/libjax_tp_jit.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,72 @@ ffi::Error conv_backward_impl(
529529
return ffi::Error::Success();
530530
}
531531

532+
ffi::Error conv_double_backward_impl(
533+
ffi::AnyBuffer L1_in,
534+
ffi::AnyBuffer L2_in,
535+
ffi::AnyBuffer W,
536+
ffi::AnyBuffer L3_grad,
537+
ffi::AnyBuffer L1_dgrad,
538+
ffi::AnyBuffer L2_dgrad,
539+
ffi::AnyBuffer W_dgrad,
540+
ffi::Result<ffi::AnyBuffer> L1_grad,
541+
ffi::Result<ffi::AnyBuffer> L2_grad,
542+
ffi::Result<ffi::AnyBuffer> W_grad,
543+
ffi::Result<ffi::AnyBuffer> L3_dgrad,
544+
ffi::AnyBuffer rows,
545+
ffi::AnyBuffer cols,
546+
ffi::AnyBuffer workspace,
547+
ffi::AnyBuffer transpose_perm,
548+
cudaStream_t stream,
549+
std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
550+
int64_t hash) {
551+
552+
auto [jit_kernel, k] = compile_conv_with_caching(
553+
kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true);
554+
const int64_t nnz = rows.dimensions()[0];
555+
const int64_t node_count = L1_in.dimensions()[0];
556+
check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in");
557+
check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in");
558+
check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad");
559+
check_tensor(L1_dgrad, {node_count, k.L1_dim}, k.irrep_dtype, "L1_dgrad");
560+
check_tensor(L2_dgrad, {nnz, k.L2_dim}, k.irrep_dtype, "L2_dgrad");
561+
check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace");
562+
check_tensor(rows, {nnz}, k.idx_dtype, "rows");
563+
check_tensor(cols, {nnz}, k.idx_dtype, "cols");
564+
565+
if (k.deterministic)
566+
check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm");
567+
568+
if (k.shared_weights) {
569+
check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
570+
check_tensor(W_dgrad, {k.weight_numel}, k.weight_dtype, "W_dgrad");
571+
} else {
572+
check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W");
573+
check_tensor(W_dgrad, {nnz, k.weight_numel}, k.weight_dtype, "W_dgrad");
574+
}
575+
if(k.shared_weights)
576+
zero_buffer(*W_grad);
577+
jit_kernel->double_backward(
578+
data_ptr(L1_in),
579+
data_ptr(L2_in),
580+
data_ptr(W),
581+
data_ptr(L3_grad),
582+
data_ptr(L1_dgrad),
583+
data_ptr(L2_dgrad),
584+
data_ptr(W_dgrad),
585+
data_ptr(L1_grad),
586+
data_ptr(L2_grad),
587+
data_ptr(W_grad),
588+
data_ptr(L3_dgrad),
589+
data_ptr(rows),
590+
data_ptr(cols),
591+
nnz, node_count,
592+
data_ptr(workspace),
593+
data_ptr(transpose_perm),
594+
stream);
595+
return ffi::Error::Success();
596+
}
597+
532598
XLA_FFI_DEFINE_HANDLER_SYMBOL(
533599
conv_forward, conv_forward_impl,
534600
ffi::Ffi::Bind()
@@ -564,6 +630,30 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
564630
.Attr<int64_t>("hash"),
565631
{xla::ffi::Traits::kCmdBufferCompatible});
566632

633+
XLA_FFI_DEFINE_HANDLER_SYMBOL(
634+
conv_double_backward, conv_double_backward_impl,
635+
ffi::Ffi::Bind()
636+
.Arg<ffi::AnyBuffer>()
637+
.Arg<ffi::AnyBuffer>()
638+
.Arg<ffi::AnyBuffer>()
639+
.Arg<ffi::AnyBuffer>()
640+
.Arg<ffi::AnyBuffer>()
641+
.Arg<ffi::AnyBuffer>()
642+
.Arg<ffi::AnyBuffer>()
643+
.Ret<ffi::AnyBuffer>()
644+
.Ret<ffi::AnyBuffer>()
645+
.Ret<ffi::AnyBuffer>()
646+
.Ret<ffi::AnyBuffer>()
647+
.Arg<ffi::AnyBuffer>()
648+
.Arg<ffi::AnyBuffer>()
649+
.Arg<ffi::AnyBuffer>()
650+
.Arg<ffi::AnyBuffer>()
651+
.Ctx<ffi::PlatformStream<cudaStream_t>>()
652+
.Attr<std::string_view>("kernel").Attr<ffi::Dictionary>("forward_config").Attr<ffi::Dictionary>("backward_config").Attr<ffi::Dictionary>("double_backward_config").Attr<ffi::Dictionary>("kernel_prop")
653+
.Attr<int64_t>("hash"),
654+
{xla::ffi::Traits::kCmdBufferCompatible});
655+
656+
// --------------------- NB Module --------------------------
567657
NB_MODULE(openequivariance_extjax, m) {
568658
m.def("registrations", []() {
569659
nb::dict registrations;
@@ -572,6 +662,8 @@ NB_MODULE(openequivariance_extjax, m) {
572662
registrations["tp_double_backward"] = nb::capsule(reinterpret_cast<void *>(tp_double_backward));
573663

574664
registrations["conv_forward"] = nb::capsule(reinterpret_cast<void *>(conv_forward));
665+
registrations["conv_backward"] = nb::capsule(reinterpret_cast<void *>(conv_backward));
666+
registrations["conv_double_backward"] = nb::capsule(reinterpret_cast<void *>(conv_double_backward));
575667
return registrations;
576668
});
577669

0 commit comments

Comments
 (0)