Skip to content

Commit e78f705

Browse files
committed
Added the forward convolution implementation.
1 parent 2dadb0f commit e78f705

2 files changed

Lines changed: 111 additions & 11 deletions

File tree

openequivariance/openequivariance/extension/convolution.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22

33
#include <stdexcept>
44
#include <iostream>
5-
#include <pybind11/pybind11.h>
6-
#include <pybind11/numpy.h>
75
#include <cstdint>
86

97
struct ConvData {

openequivariance_extjax/src/libjax_tp_jit.cpp

Lines changed: 111 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ namespace ffi = xla::ffi;
2727
#endif
2828

2929
#include "tensorproducts.hpp"
30+
#include "convolution.hpp"
3031

3132
xla::ffi::DataType enum_to_xla_dtype(int64_t i){
3233
switch(i) {
@@ -122,7 +123,13 @@ std::unordered_map<int64_t,
122123
std::pair<
123124
std::unique_ptr<JITTPImpl<JITKernel>>,
124125
KernelProp
125-
>> kernel_cache;
126+
>> tp_cache;
127+
128+
// Cache of JIT-compiled convolution kernels and their parsed properties,
// keyed by the hash supplied from the Python side. Access is serialized
// through `mut` (see compile_conv_with_caching).
std::unordered_map<int64_t,
    std::pair<
        std::unique_ptr<JITConvImpl<JITKernel>>,
        KernelProp
    >> conv_cache;
126133
std::mutex mut;
127134

128135
std::vector<std::string> launch_config_keys = {
@@ -148,7 +155,7 @@ std::unordered_map<string, int64_t> parse_ffi_dict(ffi::Dictionary &dict, const
148155
}
149156

150157
std::pair<JITTPImpl<JITKernel>*, KernelProp>
151-
compile_kernel_with_caching(std::string_view kernel,
158+
compile_tp_with_caching(std::string_view kernel,
152159
ffi::Dictionary forward_config,
153160
ffi::Dictionary backward_config,
154161
ffi::Dictionary double_backward_config,
@@ -158,24 +165,52 @@ std::pair<JITTPImpl<JITKernel>*, KernelProp>
158165

159166
{
160167
const std::lock_guard<std::mutex> lock(mut);
161-
auto it = kernel_cache.find(hash);
162-
if (it == kernel_cache.end()) {
168+
auto it = tp_cache.find(hash);
169+
if (it == tp_cache.end()) {
163170
auto kernel_prop_map = parse_ffi_dict(kernel_prop, kernel_prop_keys);
164171
auto jit_tp_impl = std::make_unique<JITTPImpl<JITKernel>>(
165172
std::string(kernel),
166173
parse_ffi_dict(forward_config, launch_config_keys),
167174
parse_ffi_dict(backward_config, launch_config_keys),
168175
parse_ffi_dict(double_backward_config, launch_config_keys),
169176
kernel_prop_map);
170-
kernel_cache.insert({hash,
177+
tp_cache.insert({hash,
171178
std::make_pair(std::move(jit_tp_impl),
172179
KernelProp(kernel_prop_map, is_convolution))});
173-
it = kernel_cache.find(hash);
180+
it = tp_cache.find(hash);
174181
}
175182
return {it->second.first.get(), it->second.second};
176183
}
177184
}
178185

186+
/// Compile a JIT convolution kernel, or fetch it from `conv_cache` if a
/// kernel with the same `hash` was already built.
///
/// @param kernel                  CUDA/HIP source text of the kernel.
/// @param forward_config          Launch configuration for the forward pass.
/// @param backward_config         Launch configuration for the backward pass.
/// @param double_backward_config  Launch configuration for double-backward.
/// @param kernel_prop             Kernel metadata (dims, dtypes, flags).
/// @param hash                    Cache key computed by the caller.
/// @param is_convolution          Forwarded into the returned KernelProp.
/// @return Non-owning pointer to the cached implementation plus a copy of
///         its KernelProp. The pointer stays valid for the process lifetime
///         because entries are never erased from conv_cache.
///
/// Thread-safe: the whole lookup/compile/insert sequence holds `mut`.
std::pair<JITConvImpl<JITKernel>*, KernelProp>
compile_conv_with_caching(std::string_view kernel,
        ffi::Dictionary forward_config,
        ffi::Dictionary backward_config,
        ffi::Dictionary double_backward_config,
        ffi::Dictionary kernel_prop,
        int64_t hash,
        bool is_convolution) {
    const std::lock_guard<std::mutex> lock(mut);
    auto it = conv_cache.find(hash);
    if (it == conv_cache.end()) {
        auto kernel_prop_map = parse_ffi_dict(kernel_prop, kernel_prop_keys);
        auto jit_conv_impl = std::make_unique<JITConvImpl<JITKernel>>(
            std::string(kernel),
            parse_ffi_dict(forward_config, launch_config_keys),
            parse_ffi_dict(backward_config, launch_config_keys),
            parse_ffi_dict(double_backward_config, launch_config_keys),
            kernel_prop_map);
        // Reuse the iterator returned by insert() instead of performing a
        // second find() on the same key (the original did find/insert/find).
        it = conv_cache.insert({hash,
                std::make_pair(std::move(jit_conv_impl),
                               KernelProp(kernel_prop_map, is_convolution))}).first;
    }
    return {it->second.first.get(), it->second.second};
}
179214

180215
inline void check_tensor(const ffi::AnyBuffer &buffer,
181216
std::initializer_list<int64_t> expected_shape,
@@ -209,6 +244,7 @@ inline void check_tensor(const ffi::AnyBuffer &buffer,
209244
}
210245
}
211246

247+
// --------------------- Tensor Products --------------------------
212248
ffi::Error tp_forward_impl(
213249
ffi::AnyBuffer L1_in,
214250
ffi::AnyBuffer L2_in,
@@ -218,7 +254,7 @@ ffi::Error tp_forward_impl(
218254
std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
219255
int64_t hash) {
220256

221-
auto [jit_kernel, k] = compile_kernel_with_caching(
257+
auto [jit_kernel, k] = compile_tp_with_caching(
222258
kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false);
223259
const int64_t num_batch = L1_in.dimensions()[0];
224260

@@ -253,7 +289,7 @@ ffi::Error tp_backward_impl(
253289
std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
254290
int64_t hash) {
255291

256-
auto [jit_kernel, k] = compile_kernel_with_caching(
292+
auto [jit_kernel, k] = compile_tp_with_caching(
257293
kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false);
258294
const int64_t num_batch = L1_in.dimensions()[0];
259295
check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in");
@@ -303,7 +339,7 @@ ffi::Error tp_double_backward_impl(
303339
std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
304340
int64_t hash) {
305341

306-
auto [jit_kernel, k] = compile_kernel_with_caching(
342+
auto [jit_kernel, k] = compile_tp_with_caching(
307343
kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false);
308344
const int64_t num_batch = L1_in.dimensions()[0];
309345
check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in");
@@ -387,12 +423,78 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
387423
.Attr<int64_t>("hash"),
388424
{xla::ffi::Traits::kCmdBufferCompatible});
389425

426+
// --------------------- Convolution --------------------------
427+
// Forward pass of the fused tensor-product + sparse (COO) convolution,
// exposed as an XLA FFI custom call. Validates buffer shapes/dtypes against
// the compiled kernel's metadata, then launches the kernel on `stream`.
ffi::Error conv_forward_impl(
    ffi::AnyBuffer L1_in,          // node features, (node_count, k.L1_dim)
    ffi::AnyBuffer L2_in,          // per-edge features, (nnz, k.L2_dim)
    ffi::AnyBuffer W,              // (k.weight_numel,) if shared, else (nnz, k.weight_numel)
    ffi::AnyBuffer rows,           // COO row indices, (nnz,)
    ffi::AnyBuffer cols,           // COO column indices, (nnz,)
    ffi::AnyBuffer workspace,      // scratch buffer, (k.workspace_size,)
    ffi::AnyBuffer transpose_perm, // only validated when k.deterministic is set
    ffi::Result<ffi::AnyBuffer> L3_out,
    cudaStream_t stream,
    std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
    int64_t hash) {

    // Compile or fetch the cached convolution kernel; `true` flags it as a
    // convolution when constructing the KernelProp.
    auto [jit_kernel, k] = compile_conv_with_caching(
        kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true);
    // nnz = edge count (length of the index arrays); node_count = rows of L1_in.
    const int64_t nnz = rows.dimensions()[0];
    const int64_t node_count = L1_in.dimensions()[0];

    check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in");
    check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in");
    check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace");
    check_tensor(rows, {nnz}, k.idx_dtype, "rows");
    check_tensor(cols, {nnz}, k.idx_dtype, "cols");
    // NOTE(review): unlike the inputs, L3_out is never shape-checked here —
    // presumably its shape is guaranteed by the JAX-side caller; confirm.

    if (k.deterministic){
        check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm");
    }

    // Shared weights are one flat vector; otherwise one weight row per edge.
    if (k.shared_weights)
        check_tensor(W, {k.weight_numel}, k.weight_dtype, "W");
    else
        check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W");

    // NOTE(review): transpose_perm is validated above but not passed to
    // exec_conv — looks like it is only consumed by the backward pass; verify.
    jit_kernel->exec_conv(
        data_ptr(L1_in),
        data_ptr(L2_in),
        data_ptr(W),
        data_ptr(L3_out),
        data_ptr(rows),
        data_ptr(cols),
        nnz, node_count,
        data_ptr(workspace),
        stream);

    return ffi::Error::Success();
}
473+
474+
// Bind conv_forward_impl as the XLA custom-call handler `conv_forward`.
// The positional Arg<> slots must stay in the same order as the buffer
// parameters of conv_forward_impl, and the single Ret<> is L3_out.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    conv_forward, conv_forward_impl,
    ffi::Ffi::Bind()
        .Arg<ffi::AnyBuffer>()  // L1_in
        .Arg<ffi::AnyBuffer>()  // L2_in
        .Arg<ffi::AnyBuffer>()  // W
        .Arg<ffi::AnyBuffer>()  // rows
        .Arg<ffi::AnyBuffer>()  // cols
        .Arg<ffi::AnyBuffer>()  // workspace
        .Arg<ffi::AnyBuffer>()  // transpose_perm
        .Ret<ffi::AnyBuffer>()  // L3_out
        .Ctx<ffi::PlatformStream<cudaStream_t>>()
        .Attr<std::string_view>("kernel").Attr<ffi::Dictionary>("forward_config").Attr<ffi::Dictionary>("backward_config").Attr<ffi::Dictionary>("double_backward_config").Attr<ffi::Dictionary>("kernel_prop")
        .Attr<int64_t>("hash"),
    // Allows XLA to capture the call inside a CUDA graph / command buffer.
    {xla::ffi::Traits::kCmdBufferCompatible});
489+
390490
// nanobind module: exposes the XLA FFI handler symbols to Python as capsules
// so the JAX side can register them as custom-call targets.
// (Closing brace of NB_MODULE falls outside the visible diff context.)
NB_MODULE(openequivariance_extjax, m) {
    m.def("registrations", []() {
        nb::dict registrations;
        registrations["tp_forward"] = nb::capsule(reinterpret_cast<void *>(tp_forward));
        registrations["tp_backward"] = nb::capsule(reinterpret_cast<void *>(tp_backward));
        registrations["tp_double_backward"] = nb::capsule(reinterpret_cast<void *>(tp_double_backward));

        registrations["conv_forward"] = nb::capsule(reinterpret_cast<void *>(conv_forward));
        return registrations;
    });
398500

0 commit comments

Comments
 (0)