
Commit 6790bd0
Added the backward pass.
1 parent 5c7a828

1 file changed: openequivariance_extjax/src/libjax_tp_jit.cpp
Lines changed: 61 additions & 0 deletions
@@ -215,6 +215,51 @@ ffi::Error tp_forward_impl(
     return ffi::Error::Success();
 }
 
+ffi::Error tp_backward_impl(
+    ffi::AnyBuffer L1_in,
+    ffi::AnyBuffer L2_in,
+    ffi::AnyBuffer W,
+    ffi::AnyBuffer L3_grad,
+    ffi::Result<ffi::AnyBuffer> L1_grad,
+    ffi::Result<ffi::AnyBuffer> L2_grad,
+    ffi::Result<ffi::AnyBuffer> W_grad,
+    cudaStream_t stream,
+    std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
+    int64_t hash) {
+
+    auto [jit_kernel, k] = compile_kernel_with_caching(
+        kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false);
+    const int64_t num_batch = L1_in.dimensions()[0];
+    check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in");
+    check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in");
+    check_tensor(L3_grad, {num_batch, k.L3_dim}, k.irrep_dtype, "L3_grad");
+
+    if (k.shared_weights) {
+        check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
+        check_tensor(*W_grad, {k.weight_numel}, k.weight_dtype, "W_grad");
+    }
+    else {
+        check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W");
+        check_tensor(*W_grad, {num_batch, k.weight_numel}, k.weight_dtype, "W_grad");
+    }
+
+    if (k.shared_weights) {
+        // Need to zero out W_grad
+    }
+
+    jit_kernel->exec_tensor_product_backward(
+        num_batch,
+        data_ptr(L1_in),
+        data_ptr(L1_grad),
+        data_ptr(L2_in),
+        data_ptr(L2_grad),
+        data_ptr(W),
+        data_ptr(W_grad),
+        data_ptr(L3_grad),
+        stream);
+    return ffi::Error::Success();
+}
+
 XLA_FFI_DEFINE_HANDLER_SYMBOL(
     tp_forward, tp_forward_impl,
     ffi::Ffi::Bind()
@@ -227,10 +272,26 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
         .Attr<int64_t>("hash"),
     {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled
 
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    tp_backward, tp_backward_impl,
+    ffi::Ffi::Bind()
+        .Arg<ffi::AnyBuffer>()
+        .Arg<ffi::AnyBuffer>()
+        .Arg<ffi::AnyBuffer>()
+        .Arg<ffi::AnyBuffer>()
+        .Ret<ffi::AnyBuffer>()
+        .Ret<ffi::AnyBuffer>()
+        .Ret<ffi::AnyBuffer>()
+        .Ctx<ffi::PlatformStream<cudaStream_t>>()
+        .Attr<std::string_view>("kernel").Attr<ffi::Dictionary>("forward_config").Attr<ffi::Dictionary>("backward_config").Attr<ffi::Dictionary>("double_backward_config").Attr<ffi::Dictionary>("kernel_prop")
+        .Attr<int64_t>("hash"),
+    {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled
+
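For context, here is a minimal sketch (not the package's actual Python wrapper) of how a target bound this way might be invoked from JAX with jax.ffi.ffi_call, assuming a recent JAX, that the target has already been registered under the name "tp_backward", and that dictionary-valued keyword attributes map onto the ffi::Dictionary attrs above. The config dicts, kernel string, and hash are placeholders supplied by the caller.

import jax

def tp_backward(L1_in, L2_in, W, L3_grad, *, kernel, forward_config,
                backward_config, double_backward_config, kernel_prop,
                kernel_hash):
    # One ShapeDtypeStruct per .Ret<ffi::AnyBuffer>() in the binding; the
    # shapes/dtypes mirror what the C++ side validates with check_tensor.
    out_types = (
        jax.ShapeDtypeStruct(L1_in.shape, L1_in.dtype),  # L1_grad
        jax.ShapeDtypeStruct(L2_in.shape, L2_in.dtype),  # L2_grad
        jax.ShapeDtypeStruct(W.shape, W.dtype),          # W_grad
    )
    call = jax.ffi.ffi_call("tp_backward", out_types)
    # Operands in .Arg<> order, attributes under the names used in .Attr<>.
    return call(L1_in, L2_in, W, L3_grad,
                kernel=kernel,
                forward_config=forward_config,
                backward_config=backward_config,
                double_backward_config=double_backward_config,
                kernel_prop=kernel_prop,
                hash=kernel_hash)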
 NB_MODULE(openequivariance_extjax, m) {
     m.def("registrations", []() {
         nb::dict registrations;
         registrations["tp_forward"] = nb::capsule(reinterpret_cast<void *>(tp_forward));
+        registrations["tp_backward"] = nb::capsule(reinterpret_cast<void *>(tp_backward));
         return registrations;
     });
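On the Python side, the capsules exposed by registrations() still have to be registered with XLA's FFI registry before either target can be lowered. A minimal sketch, assuming the extension module imports under the name used in NB_MODULE and that the GPU platform string is "CUDA"; registration like this would typically run once at package import time.

import jax
import openequivariance_extjax  # the nanobind module defined above

# Register both the forward and the newly added backward target so that
# jax.ffi.ffi_call("tp_forward", ...) / ("tp_backward", ...) can resolve them.
for name, capsule in openequivariance_extjax.registrations().items():
    jax.ffi.register_ffi_target(name, capsule, platform="CUDA")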
