@@ -215,6 +215,51 @@ ffi::Error tp_forward_impl(
     return ffi::Error::Success();
 }
 
+ffi::Error tp_backward_impl(
+    ffi::AnyBuffer L1_in,
+    ffi::AnyBuffer L2_in,
+    ffi::AnyBuffer W,
+    ffi::AnyBuffer L3_grad,
+    ffi::Result<ffi::AnyBuffer> L1_grad,
+    ffi::Result<ffi::AnyBuffer> L2_grad,
+    ffi::Result<ffi::AnyBuffer> W_grad,
+    cudaStream_t stream,
+    std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
+    int64_t hash) {
+
+    auto [jit_kernel, k] = compile_kernel_with_caching(
+        kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false);
+    const int64_t num_batch = L1_in.dimensions()[0];
+    check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in");
+    check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in");
+    check_tensor(L3_grad, {num_batch, k.L3_dim}, k.irrep_dtype, "L3_grad");
+
+    if (k.shared_weights) {
+        check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
+        check_tensor(*W_grad, {k.weight_numel}, k.weight_dtype, "W_grad");
+    }
+    else {
+        check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W");
+        check_tensor(*W_grad, {num_batch, k.weight_numel}, k.weight_dtype, "W_grad");
+    }
+
+    if (k.shared_weights) {
+        // Need to zero out W_grad
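+        // Illustrative sketch, not part of this commit: with shared weights the
+        // backward kernel accumulates every batch element's contribution into
+        // W_grad, so the buffer must start from zero on the same stream.
+        // Assuming ffi::AnyBuffer exposes size_bytes() (otherwise derive the
+        // byte count from k.weight_numel and the element size of k.weight_dtype):
+        //   cudaMemsetAsync(data_ptr(W_grad), 0, W_grad->size_bytes(), stream);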
+    }
+
+    jit_kernel->exec_tensor_product_backward(
+        num_batch,
+        data_ptr(L1_in),
+        data_ptr(L1_grad),
+        data_ptr(L2_in),
+        data_ptr(L2_grad),
+        data_ptr(W),
+        data_ptr(W_grad),
+        data_ptr(L3_grad),
+        stream);
+    return ffi::Error::Success();
+}
+
 XLA_FFI_DEFINE_HANDLER_SYMBOL(
     tp_forward, tp_forward_impl,
     ffi::Ffi::Bind()
@@ -227,10 +272,26 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
         .Attr<int64_t>("hash"),
     {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled
 
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    tp_backward, tp_backward_impl,
+    ffi::Ffi::Bind()
+        .Arg<ffi::AnyBuffer>()
+        .Arg<ffi::AnyBuffer>()
+        .Arg<ffi::AnyBuffer>()
+        .Arg<ffi::AnyBuffer>()
+        .Ret<ffi::AnyBuffer>()
+        .Ret<ffi::AnyBuffer>()
+        .Ret<ffi::AnyBuffer>()
+        .Ctx<ffi::PlatformStream<cudaStream_t>>()
+        .Attr<std::string_view>("kernel").Attr<ffi::Dictionary>("forward_config").Attr<ffi::Dictionary>("backward_config").Attr<ffi::Dictionary>("double_backward_config").Attr<ffi::Dictionary>("kernel_prop")
+        .Attr<int64_t>("hash"),
+    {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled
+
 NB_MODULE(openequivariance_extjax, m) {
     m.def("registrations", []() {
         nb::dict registrations;
         registrations["tp_forward"] = nb::capsule(reinterpret_cast<void *>(tp_forward));
+        registrations["tp_backward"] = nb::capsule(reinterpret_cast<void *>(tp_backward));
         return registrations;
     });
 
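A consumer-side sketch, not part of this commit: the capsules returned by registrations() are handed to JAX as CUDA FFI targets, after which tp_backward can be invoked through jax.ffi.ffi_call. This assumes a recent JAX release that exposes jax.ffi (older releases use jax.extend.ffi) and that the extension module imports as openequivariance_extjax, matching the NB_MODULE name above.

    import jax
    import openequivariance_extjax as ext

    # Register every handler exported by the extension (tp_forward, tp_backward)
    # for the CUDA platform; names match the keys of the registrations() dict.
    for name, capsule in ext.registrations().items():
        jax.ffi.register_ffi_target(name, capsule, platform="CUDA")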