Skip to content

Commit 2dadb0f

Browse files
committed
Wrapped the double-backward pass.
1 parent 136c9f6 commit 2dadb0f

1 file changed

Lines changed: 76 additions & 1 deletion

File tree

openequivariance_extjax/src/libjax_tp_jit.cpp

Lines changed: 76 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -286,6 +286,61 @@ ffi::Error tp_backward_impl(
286286
return ffi::Error::Success();
287287
}
288288

289+
290+
// Double-backward (second-order) pass of the tensor-product kernel.
//
// Inputs:
//   L1_in, L2_in, W                : operands of the original forward pass
//   L3_grad                        : upstream gradient w.r.t. the forward output
//   L1_dgrad, L2_dgrad, W_dgrad    : incoming gradients of the first backward
//                                    pass's outputs ("double" gradients)
// Outputs:
//   L1_grad, L2_grad, W_grad       : gradients for the forward inputs
//   L3_dgrad                       : gradient propagated back to L3_grad
//
// Attrs (kernel/configs/hash) select and parameterize the cached JIT kernel.
// Returns ffi::Error::Success() on completion of the kernel launch.
ffi::Error tp_double_backward_impl(
        ffi::AnyBuffer L1_in,
        ffi::AnyBuffer L2_in,
        ffi::AnyBuffer W,
        ffi::AnyBuffer L3_grad,
        ffi::AnyBuffer L1_dgrad,
        ffi::AnyBuffer L2_dgrad,
        ffi::AnyBuffer W_dgrad,
        ffi::Result<ffi::AnyBuffer> L1_grad,
        ffi::Result<ffi::AnyBuffer> L2_grad,
        ffi::Result<ffi::AnyBuffer> W_grad,
        ffi::Result<ffi::AnyBuffer> L3_dgrad,
        cudaStream_t stream,
        std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
        int64_t hash) {

    // Look up (or compile) the JIT kernel; `spec` carries the expected
    // tensor dimensions and dtypes for validation below.
    auto [compiled, spec] = compile_kernel_with_caching(
        kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false);

    const int64_t n_batch = L1_in.dimensions()[0];

    // Validate every batched irrep tensor against the kernel's layout.
    check_tensor(L1_in, {n_batch, spec.L1_dim}, spec.irrep_dtype, "L1_in");
    check_tensor(L2_in, {n_batch, spec.L2_dim}, spec.irrep_dtype, "L2_in");
    check_tensor(L3_grad, {n_batch, spec.L3_dim}, spec.irrep_dtype, "L3_grad");
    check_tensor(L1_dgrad, {n_batch, spec.L1_dim}, spec.irrep_dtype, "L1_dgrad");
    check_tensor(L2_dgrad, {n_batch, spec.L2_dim}, spec.irrep_dtype, "L2_dgrad");

    if (spec.shared_weights) {
        // Shared weights: one weight vector for the whole batch. W_grad is
        // zeroed before launch — presumably the kernel accumulates into it
        // across the batch (mirrors the zeroing policy used elsewhere in
        // this file; confirm against the kernel contract).
        check_tensor(W, {spec.weight_numel}, spec.weight_dtype, "W");
        check_tensor(W_dgrad, {spec.weight_numel}, spec.weight_dtype, "W_dgrad");
        zero_buffer(*W_grad);
    } else {
        // Per-sample weights: batch-leading weight tensors, no zeroing needed.
        check_tensor(W, {n_batch, spec.weight_numel}, spec.weight_dtype, "W");
        check_tensor(W_dgrad, {n_batch, spec.weight_numel}, spec.weight_dtype, "W_dgrad");
    }

    // Launch on the XLA-provided stream; argument order matches the
    // jit kernel's double_backward signature.
    compiled->double_backward(
        n_batch,
        data_ptr(L1_in),
        data_ptr(L2_in),
        data_ptr(W),
        data_ptr(L3_grad),
        data_ptr(L1_dgrad),
        data_ptr(L2_dgrad),
        data_ptr(W_dgrad),
        data_ptr(L1_grad),
        data_ptr(L2_grad),
        data_ptr(W_grad),
        data_ptr(L3_dgrad),
        stream);
    return ffi::Error::Success();
}
343+
289344
XLA_FFI_DEFINE_HANDLER_SYMBOL(
290345
tp_forward, tp_forward_impl,
291346
ffi::Ffi::Bind()
@@ -311,13 +366,33 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
311366
.Ctx<ffi::PlatformStream<cudaStream_t>>()
312367
.Attr<std::string_view>("kernel").Attr<ffi::Dictionary>("forward_config").Attr<ffi::Dictionary>("backward_config").Attr<ffi::Dictionary>("double_backward_config").Attr<ffi::Dictionary>("kernel_prop")
313368
.Attr<int64_t>("hash"),
314-
{xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled
369+
{xla::ffi::Traits::kCmdBufferCompatible});
370+
371+
// FFI handler symbol for the double-backward pass. The Arg/Ret/Ctx/Attr
// order below must match the parameter order of tp_double_backward_impl
// exactly, since XLA binds them positionally.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    tp_double_backward, tp_double_backward_impl,
    ffi::Ffi::Bind()
        .Arg<ffi::AnyBuffer>()   // L1_in
        .Arg<ffi::AnyBuffer>()   // L2_in
        .Arg<ffi::AnyBuffer>()   // W
        .Arg<ffi::AnyBuffer>()   // L3_grad
        .Arg<ffi::AnyBuffer>()   // L1_dgrad
        .Arg<ffi::AnyBuffer>()   // L2_dgrad
        .Arg<ffi::AnyBuffer>()   // W_dgrad
        .Ret<ffi::AnyBuffer>()   // L1_grad
        .Ret<ffi::AnyBuffer>()   // L2_grad
        .Ret<ffi::AnyBuffer>()   // W_grad
        .Ret<ffi::AnyBuffer>()   // L3_dgrad
        .Ctx<ffi::PlatformStream<cudaStream_t>>()
        .Attr<std::string_view>("kernel")
        .Attr<ffi::Dictionary>("forward_config")
        .Attr<ffi::Dictionary>("backward_config")
        .Attr<ffi::Dictionary>("double_backward_config")
        .Attr<ffi::Dictionary>("kernel_prop")
        .Attr<int64_t>("hash"),
    {xla::ffi::Traits::kCmdBufferCompatible});
315389

316390
NB_MODULE(openequivariance_extjax, m) {
317391
m.def("registrations", []() {
318392
nb::dict registrations;
319393
registrations["tp_forward"] = nb::capsule(reinterpret_cast<void *>(tp_forward));
320394
registrations["tp_backward"] = nb::capsule(reinterpret_cast<void *>(tp_backward));
395+
registrations["tp_double_backward"] = nb::capsule(reinterpret_cast<void *>(tp_double_backward));
321396
return registrations;
322397
});
323398

0 commit comments

Comments (0)