@@ -286,6 +286,61 @@ ffi::Error tp_backward_impl(
286286 return ffi::Error::Success ();
287287}
288288
289+
290+ ffi::Error tp_double_backward_impl (
291+ ffi::AnyBuffer L1_in,
292+ ffi::AnyBuffer L2_in,
293+ ffi::AnyBuffer W,
294+ ffi::AnyBuffer L3_grad,
295+ ffi::AnyBuffer L1_dgrad,
296+ ffi::AnyBuffer L2_dgrad,
297+ ffi::AnyBuffer W_dgrad,
298+ ffi::Result<ffi::AnyBuffer> L1_grad,
299+ ffi::Result<ffi::AnyBuffer> L2_grad,
300+ ffi::Result<ffi::AnyBuffer> W_grad,
301+ ffi::Result<ffi::AnyBuffer> L3_dgrad,
302+ cudaStream_t stream,
303+ std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
304+ int64_t hash) {
305+
306+ auto [jit_kernel, k] = compile_kernel_with_caching (
307+ kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false );
308+ const int64_t num_batch = L1_in.dimensions ()[0 ];
309+ check_tensor (L1_in, {num_batch, k.L1_dim }, k.irrep_dtype , " L1_in" );
310+ check_tensor (L2_in, {num_batch, k.L2_dim }, k.irrep_dtype , " L2_in" );
311+ check_tensor (L3_grad, {num_batch, k.L3_dim }, k.irrep_dtype , " L3_grad" );
312+ check_tensor (L1_dgrad, {num_batch, k.L1_dim }, k.irrep_dtype , " L1_dgrad" );
313+ check_tensor (L2_dgrad, {num_batch, k.L2_dim }, k.irrep_dtype , " L2_dgrad" );
314+
315+ if (k.shared_weights ){
316+ check_tensor (W, {k.weight_numel }, k.weight_dtype , " W" );
317+ check_tensor (W_dgrad, {k.weight_numel }, k.weight_dtype , " W_dgrad" );
318+ } else {
319+ check_tensor (W, {num_batch, k.weight_numel }, k.weight_dtype , " W" );
320+ check_tensor (W_dgrad, {num_batch, k.weight_numel }, k.weight_dtype , " W_dgrad" );
321+ }
322+
323+ if (k.shared_weights ) {
324+ zero_buffer (*W_grad);
325+ }
326+
327+ jit_kernel->double_backward (
328+ num_batch,
329+ data_ptr (L1_in),
330+ data_ptr (L2_in),
331+ data_ptr (W),
332+ data_ptr (L3_grad),
333+ data_ptr (L1_dgrad),
334+ data_ptr (L2_dgrad),
335+ data_ptr (W_dgrad),
336+ data_ptr (L1_grad),
337+ data_ptr (L2_grad),
338+ data_ptr (W_grad),
339+ data_ptr (L3_dgrad),
340+ stream);
341+ return ffi::Error::Success ();
342+ }
343+
289344XLA_FFI_DEFINE_HANDLER_SYMBOL (
290345 tp_forward, tp_forward_impl,
291346 ffi::Ffi::Bind ()
@@ -311,13 +366,33 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
311366 .Ctx<ffi::PlatformStream<cudaStream_t>>()
312367 .Attr<std::string_view>(" kernel" ).Attr<ffi::Dictionary>(" forward_config" ).Attr<ffi::Dictionary>(" backward_config" ).Attr<ffi::Dictionary>(" double_backward_config" ).Attr<ffi::Dictionary>(" kernel_prop" )
313368 .Attr<int64_t>(" hash" ),
314- {xla::ffi::Traits::kCmdBufferCompatible }); // cudaGraph enabled
369+ {xla::ffi::Traits::kCmdBufferCompatible });
370+
// Register the double-backward handler with XLA's FFI. The Arg/Ret/Ctx/Attr
// chain below must mirror the parameter order of tp_double_backward_impl
// exactly: 7 input buffers, 4 result buffers, the CUDA platform stream,
// then the kernel/config/hash attributes.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    tp_double_backward, tp_double_backward_impl,
    ffi::Ffi::Bind()
        .Arg<ffi::AnyBuffer>()   // L1_in
        .Arg<ffi::AnyBuffer>()   // L2_in
        .Arg<ffi::AnyBuffer>()   // W
        .Arg<ffi::AnyBuffer>()   // L3_grad
        .Arg<ffi::AnyBuffer>()   // L1_dgrad
        .Arg<ffi::AnyBuffer>()   // L2_dgrad
        .Arg<ffi::AnyBuffer>()   // W_dgrad
        .Ret<ffi::AnyBuffer>()   // L1_grad
        .Ret<ffi::AnyBuffer>()   // L2_grad
        .Ret<ffi::AnyBuffer>()   // W_grad
        .Ret<ffi::AnyBuffer>()   // L3_dgrad
        .Ctx<ffi::PlatformStream<cudaStream_t>>()
        .Attr<std::string_view>("kernel").Attr<ffi::Dictionary>("forward_config").Attr<ffi::Dictionary>("backward_config").Attr<ffi::Dictionary>("double_backward_config").Attr<ffi::Dictionary>("kernel_prop")
        .Attr<int64_t>("hash"),
    {xla::ffi::Traits::kCmdBufferCompatible});  // allow capture into CUDA graphs / command buffers
315389
316390NB_MODULE (openequivariance_extjax, m) {
317391 m.def (" registrations" , []() {
318392 nb::dict registrations;
319393 registrations[" tp_forward" ] = nb::capsule (reinterpret_cast <void *>(tp_forward));
320394 registrations[" tp_backward" ] = nb::capsule (reinterpret_cast <void *>(tp_backward));
395+ registrations[" tp_double_backward" ] = nb::capsule (reinterpret_cast <void *>(tp_double_backward));
321396 return registrations;
322397 });
323398
0 commit comments