
Commit 8784dd4

Backward convolution implemented.
1 parent: e78f705

1 file changed: 77 additions, 0 deletions

openequivariance_extjax/src/libjax_tp_jit.cpp

@@ -471,6 +471,64 @@ ffi::Error conv_forward_impl(
     return ffi::Error::Success();
 }
 
+ffi::Error conv_backward_impl(
+    ffi::AnyBuffer L1_in,
+    ffi::AnyBuffer L2_in,
+    ffi::AnyBuffer W,
+    ffi::AnyBuffer L3_grad,
+    ffi::Result<ffi::AnyBuffer> L1_grad,
+    ffi::Result<ffi::AnyBuffer> L2_grad,
+    ffi::Result<ffi::AnyBuffer> W_grad,
+    ffi::AnyBuffer rows,
+    ffi::AnyBuffer cols,
+    ffi::AnyBuffer workspace,
+    ffi::AnyBuffer transpose_perm,
+    cudaStream_t stream,
+    std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
+    int64_t hash) {
+
+    auto [jit_kernel, k] = compile_conv_with_caching(
+        kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true);
+    const int64_t nnz = rows.dimensions()[0];
+    const int64_t node_count = L1_in.dimensions()[0];
+    check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in");
+    check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in");
+    check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad");
+    check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace");
+    check_tensor(rows, {nnz}, k.idx_dtype, "rows");
+    check_tensor(cols, {nnz}, k.idx_dtype, "cols");
+
+    if (k.deterministic)
+        check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm");
+
+    if (k.shared_weights) {
+        check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
+        check_tensor(*W_grad, {k.weight_numel}, k.weight_dtype, "W_grad");
+    }
+    else {
+        check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W");
+        check_tensor(*W_grad, {nnz, k.weight_numel}, k.weight_dtype, "W_grad");
+    }
+    if (k.shared_weights)
+        zero_buffer(*W_grad);
+
+    jit_kernel->backward(
+        data_ptr(L1_in),
+        data_ptr(L1_grad),
+        data_ptr(L2_in),
+        data_ptr(L2_grad),
+        data_ptr(W),
+        data_ptr(W_grad),
+        data_ptr(L3_grad),
+        data_ptr(rows),
+        data_ptr(cols),
+        nnz, node_count,
+        data_ptr(workspace),
+        data_ptr(transpose_perm),
+        stream);
+    return ffi::Error::Success();
+}
+
 XLA_FFI_DEFINE_HANDLER_SYMBOL(
     conv_forward, conv_forward_impl,
     ffi::Ffi::Bind()
@@ -487,6 +545,25 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
     .Attr<int64_t>("hash"),
     {xla::ffi::Traits::kCmdBufferCompatible});
 
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    conv_backward, conv_backward_impl,
+    ffi::Ffi::Bind()
+    .Arg<ffi::AnyBuffer>()
+    .Arg<ffi::AnyBuffer>()
+    .Arg<ffi::AnyBuffer>()
+    .Arg<ffi::AnyBuffer>()
+    .Ret<ffi::AnyBuffer>()
+    .Ret<ffi::AnyBuffer>()
+    .Ret<ffi::AnyBuffer>()
+    .Arg<ffi::AnyBuffer>()
+    .Arg<ffi::AnyBuffer>()
+    .Arg<ffi::AnyBuffer>()
+    .Arg<ffi::AnyBuffer>()
+    .Ctx<ffi::PlatformStream<cudaStream_t>>()
+    .Attr<std::string_view>("kernel").Attr<ffi::Dictionary>("forward_config").Attr<ffi::Dictionary>("backward_config").Attr<ffi::Dictionary>("double_backward_config").Attr<ffi::Dictionary>("kernel_prop")
+    .Attr<int64_t>("hash"),
+    {xla::ffi::Traits::kCmdBufferCompatible});
+
 NB_MODULE(openequivariance_extjax, m) {
     m.def("registrations", []() {
         nb::dict registrations;
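
Note: the new "conv_backward" handler is exported alongside "conv_forward" through the module's registrations() dictionary (see NB_MODULE above). For orientation, here is a minimal, hypothetical sketch of how the target could be registered and invoked from Python through JAX's jax.ffi API. The module name openequivariance_extjax and the registrations() entry point come from this diff; everything else (function and variable names, the config dictionaries, the cache hash) is an illustrative assumption, not part of this commit.

    # Hypothetical usage sketch; assumes registrations() maps target
    # names to PyCapsule handlers, as is typical for XLA FFI extensions.
    import jax
    import numpy as np
    import openequivariance_extjax as ext

    # Register every handler the extension exports, including "conv_backward".
    for name, capsule in ext.registrations().items():
        jax.ffi.register_ffi_target(name, capsule, platform="CUDA")

    def conv_backward(L1_in, L2_in, W, L3_grad, rows, cols, workspace,
                      transpose_perm, *, kernel_src, configs, cache_hash):
        # Gradient buffers mirror the shapes the C++ side checks:
        # L1_grad like L1_in, L2_grad like L2_in, W_grad like W.
        out_types = (
            jax.ShapeDtypeStruct(L1_in.shape, L1_in.dtype),  # L1_grad
            jax.ShapeDtypeStruct(L2_in.shape, L2_in.dtype),  # L2_grad
            jax.ShapeDtypeStruct(W.shape, W.dtype),          # W_grad
        )
        call = jax.ffi.ffi_call("conv_backward", out_types)
        # Positional operands follow the .Arg<> order in the binding;
        # keyword attributes follow the .Attr<> declarations.
        return call(
            L1_in, L2_in, W, L3_grad, rows, cols, workspace, transpose_perm,
            kernel=kernel_src,
            forward_config=configs["forward"],
            backward_config=configs["backward"],
            double_backward_config=configs["double_backward"],
            kernel_prop=configs["kernel_prop"],
            hash=np.int64(cache_hash),
        )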
