@@ -27,6 +27,7 @@ namespace ffi = xla::ffi;
2727#endif
2828
2929#include " tensorproducts.hpp"
30+ #include " convolution.hpp"
3031
3132xla::ffi::DataType enum_to_xla_dtype (int64_t i){
3233 switch (i) {
@@ -122,7 +123,13 @@ std::unordered_map<int64_t,
122123 std::pair<
123124 std::unique_ptr<JITTPImpl<JITKernel>>,
124125 KernelProp
125- >> kernel_cache;
126+ >> tp_cache;
127+
128+ std::unordered_map<int64_t ,
129+ std::pair<
130+ std::unique_ptr<JITConvImpl<JITKernel>>,
131+ KernelProp
132+ >> conv_cache;
126133std::mutex mut;
127134
128135std::vector<std::string> launch_config_keys = {
@@ -148,7 +155,7 @@ std::unordered_map<string, int64_t> parse_ffi_dict(ffi::Dictionary &dict, const
148155}
149156
150157std::pair<JITTPImpl<JITKernel>*, KernelProp>
151- compile_kernel_with_caching (std::string_view kernel,
158+ compile_tp_with_caching (std::string_view kernel,
152159 ffi::Dictionary forward_config,
153160 ffi::Dictionary backward_config,
154161 ffi::Dictionary double_backward_config,
@@ -158,24 +165,52 @@ std::pair<JITTPImpl<JITKernel>*, KernelProp>
158165
159166 {
160167 const std::lock_guard<std::mutex> lock (mut);
161- auto it = kernel_cache .find (hash);
162- if (it == kernel_cache .end ()) {
168+ auto it = tp_cache .find (hash);
169+ if (it == tp_cache .end ()) {
163170 auto kernel_prop_map = parse_ffi_dict (kernel_prop, kernel_prop_keys);
164171 auto jit_tp_impl = std::make_unique<JITTPImpl<JITKernel>>(
165172 std::string (kernel),
166173 parse_ffi_dict (forward_config, launch_config_keys),
167174 parse_ffi_dict (backward_config, launch_config_keys),
168175 parse_ffi_dict (double_backward_config, launch_config_keys),
169176 kernel_prop_map);
170- kernel_cache .insert ({hash,
177+ tp_cache .insert ({hash,
171178 std::make_pair (std::move (jit_tp_impl),
172179 KernelProp (kernel_prop_map, is_convolution))});
173- it = kernel_cache .find (hash);
180+ it = tp_cache .find (hash);
174181 }
175182 return {it->second .first .get (), it->second .second };
176183 }
177184}
178185
186+ std::pair<JITConvImpl<JITKernel>*, KernelProp>
187+ compile_conv_with_caching (std::string_view kernel,
188+ ffi::Dictionary forward_config,
189+ ffi::Dictionary backward_config,
190+ ffi::Dictionary double_backward_config,
191+ ffi::Dictionary kernel_prop,
192+ int64_t hash,
193+ bool is_convolution) {
194+
195+ {
196+ const std::lock_guard<std::mutex> lock (mut);
197+ auto it = conv_cache.find (hash);
198+ if (it == conv_cache.end ()) {
199+ auto kernel_prop_map = parse_ffi_dict (kernel_prop, kernel_prop_keys);
200+ auto jit_conv_impl = std::make_unique<JITConvImpl<JITKernel>>(
201+ std::string (kernel),
202+ parse_ffi_dict (forward_config, launch_config_keys),
203+ parse_ffi_dict (backward_config, launch_config_keys),
204+ parse_ffi_dict (double_backward_config, launch_config_keys),
205+ kernel_prop_map);
206+ conv_cache.insert ({hash,
207+ std::make_pair (std::move (jit_conv_impl),
208+ KernelProp (kernel_prop_map, is_convolution))});
209+ it = conv_cache.find (hash);
210+ }
211+ return {it->second .first .get (), it->second .second };
212+ }
213+ }
179214
180215inline void check_tensor (const ffi::AnyBuffer &buffer,
181216 std::initializer_list<int64_t > expected_shape,
@@ -209,6 +244,7 @@ inline void check_tensor(const ffi::AnyBuffer &buffer,
209244 }
210245}
211246
247+ // --------------------- Tensor Products --------------------------
212248ffi::Error tp_forward_impl (
213249 ffi::AnyBuffer L1_in,
214250 ffi::AnyBuffer L2_in,
@@ -218,7 +254,7 @@ ffi::Error tp_forward_impl(
218254 std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
219255 int64_t hash) {
220256
221- auto [jit_kernel, k] = compile_kernel_with_caching (
257+ auto [jit_kernel, k] = compile_tp_with_caching (
222258 kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false );
223259 const int64_t num_batch = L1_in.dimensions ()[0 ];
224260
@@ -253,7 +289,7 @@ ffi::Error tp_backward_impl(
253289 std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
254290 int64_t hash) {
255291
256- auto [jit_kernel, k] = compile_kernel_with_caching (
292+ auto [jit_kernel, k] = compile_tp_with_caching (
257293 kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false );
258294 const int64_t num_batch = L1_in.dimensions ()[0 ];
259295 check_tensor (L1_in, {num_batch, k.L1_dim }, k.irrep_dtype , " L1_in" );
@@ -303,7 +339,7 @@ ffi::Error tp_double_backward_impl(
303339 std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
304340 int64_t hash) {
305341
306- auto [jit_kernel, k] = compile_kernel_with_caching (
342+ auto [jit_kernel, k] = compile_tp_with_caching (
307343 kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false );
308344 const int64_t num_batch = L1_in.dimensions ()[0 ];
309345 check_tensor (L1_in, {num_batch, k.L1_dim }, k.irrep_dtype , " L1_in" );
@@ -387,12 +423,78 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
387423 .Attr<int64_t>(" hash" ),
388424 {xla::ffi::Traits::kCmdBufferCompatible });
389425
426+ // --------------------- Convolution --------------------------
427+ ffi::Error conv_forward_impl (
428+ ffi::AnyBuffer L1_in,
429+ ffi::AnyBuffer L2_in,
430+ ffi::AnyBuffer W,
431+ ffi::AnyBuffer rows,
432+ ffi::AnyBuffer cols,
433+ ffi::AnyBuffer workspace,
434+ ffi::AnyBuffer transpose_perm,
435+ ffi::Result<ffi::AnyBuffer> L3_out,
436+ cudaStream_t stream,
437+ std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop,
438+ int64_t hash) {
439+
440+ auto [jit_kernel, k] = compile_conv_with_caching (
441+ kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true );
442+ const int64_t nnz = rows.dimensions ()[0 ];
443+ const int64_t node_count = L1_in.dimensions ()[0 ];
444+
445+ check_tensor (L1_in, {node_count, k.L1_dim }, k.irrep_dtype , " L1_in" );
446+ check_tensor (L2_in, {nnz, k.L2_dim }, k.irrep_dtype , " L2_in" );
447+ check_tensor (workspace, {k.workspace_size }, k.workspace_dtype , " workspace" );
448+ check_tensor (rows, {nnz}, k.idx_dtype , " rows" );
449+ check_tensor (cols, {nnz}, k.idx_dtype , " cols" );
450+
451+ if (k.deterministic ){
452+ check_tensor (transpose_perm, {nnz}, k.idx_dtype , " transpose perm" );
453+ }
454+
455+ if (k.shared_weights )
456+ check_tensor (W, {k.weight_numel }, k.weight_dtype , " W" );
457+ else
458+ check_tensor (W, {nnz, k.weight_numel }, k.weight_dtype , " W" );
459+
460+ jit_kernel->exec_conv (
461+ data_ptr (L1_in),
462+ data_ptr (L2_in),
463+ data_ptr (W),
464+ data_ptr (L3_out),
465+ data_ptr (rows),
466+ data_ptr (cols),
467+ nnz, node_count,
468+ data_ptr (workspace),
469+ stream);
470+
471+ return ffi::Error::Success ();
472+ }
473+
474+ XLA_FFI_DEFINE_HANDLER_SYMBOL (
475+ conv_forward, conv_forward_impl,
476+ ffi::Ffi::Bind ()
477+ .Arg<ffi::AnyBuffer>()
478+ .Arg<ffi::AnyBuffer>()
479+ .Arg<ffi::AnyBuffer>()
480+ .Arg<ffi::AnyBuffer>()
481+ .Arg<ffi::AnyBuffer>()
482+ .Arg<ffi::AnyBuffer>()
483+ .Arg<ffi::AnyBuffer>()
484+ .Ret<ffi::AnyBuffer>()
485+ .Ctx<ffi::PlatformStream<cudaStream_t>>()
486+ .Attr<std::string_view>(" kernel" ).Attr<ffi::Dictionary>(" forward_config" ).Attr<ffi::Dictionary>(" backward_config" ).Attr<ffi::Dictionary>(" double_backward_config" ).Attr<ffi::Dictionary>(" kernel_prop" )
487+ .Attr<int64_t>(" hash" ),
488+ {xla::ffi::Traits::kCmdBufferCompatible });
489+
390490NB_MODULE (openequivariance_extjax, m) {
391491 m.def (" registrations" , []() {
392492 nb::dict registrations;
393493 registrations[" tp_forward" ] = nb::capsule (reinterpret_cast <void *>(tp_forward));
394494 registrations[" tp_backward" ] = nb::capsule (reinterpret_cast <void *>(tp_backward));
395495 registrations[" tp_double_backward" ] = nb::capsule (reinterpret_cast <void *>(tp_double_backward));
496+
497+ registrations[" conv_forward" ] = nb::capsule (reinterpret_cast <void *>(conv_forward));
396498 return registrations;
397499 });
398500
0 commit comments