#include <stdexcept>

#include "nanobind/nanobind.h"
#include "xla/ffi/api/ffi.h"
1313
14+ namespace nb = nanobind;
15+ namespace ffi = xla::ffi;
16+
1417#define CUDA_BACKEND // Stick to CUDA for now
1518
1619#ifdef CUDA_BACKEND
2023 using GPU_Allocator = CUDA_Allocator;
2124
2225 template <typename T>
23- using GroupMM = GroupMMCUDA<T>;
26+ using GroupMM = GroupMMCUDA<T>;
2427#endif
2528
2629#include " tensorproducts.hpp"
2730
28- namespace nb = nanobind;
29- namespace ffi = xla::ffi;
30-
3131xla::ffi::DataType enum_to_xla_dtype (int64_t i){
3232 switch (i) {
3333 case 1 :
@@ -45,22 +45,48 @@ xla::ffi::DataType enum_to_xla_dtype(int64_t i){
4545}
4646
4747inline void * data_ptr (ffi::AnyBuffer &buffer) {
48- if (buffer.element_type () == xla::ffi::DataType::F32)
49- return reinterpret_cast <void *>(buffer.typed_data <float >());
50- else if (buffer.element_type () == xla::ffi::DataType::F64)
51- return reinterpret_cast <void *>(buffer.typed_data <double >());
52- else if (buffer.element_type () == xla::ffi::DataType::S64)
53- return reinterpret_cast <void *>(buffer.typed_data <int64_t >());
54- else if (buffer.element_type () == xla::ffi::DataType::U8)
55- return reinterpret_cast <void *>(buffer.typed_data <uint8_t >());
56- else
57- throw logic_error (" Unsupported tensor datatype!" );
48+ switch (buffer.element_type ()) {
49+ case xla::ffi::DataType::F32:
50+ return reinterpret_cast <void *>(buffer.typed_data <float >());
51+ case xla::ffi::DataType::F64:
52+ return reinterpret_cast <void *>(buffer.typed_data <double >());
53+ case xla::ffi::DataType::S64:
54+ return reinterpret_cast <void *>(buffer.typed_data <int64_t >());
55+ case xla::ffi::DataType::U8:
56+ return reinterpret_cast <void *>(buffer.typed_data <uint8_t >());
57+ default :
58+ throw logic_error (" Unsupported tensor datatype!" );
59+ }
60+ }
61+
62+ inline int byte_count (ffi::AnyBuffer &buffer) {
63+ switch (buffer.element_type ()) {
64+ case xla::ffi::DataType::F32:
65+ return 4 ;
66+ case xla::ffi::DataType::F64:
67+ return 8 ;
68+ case xla::ffi::DataType::S64:
69+ return 8 ;
70+ case xla::ffi::DataType::U8:
71+ return 1 ;
72+ default :
73+ throw logic_error (" Unsupported tensor datatype!" );
74+ }
5875}
5976
6077inline void * data_ptr (ffi::Result<ffi::AnyBuffer> &buffer) {
6178 return data_ptr (*buffer);
6279}
6380
#ifdef CUDA_BACKEND
// Fills the whole device buffer with zero bytes.
//
// The original call discarded cudaMemset's cudaError_t: a failed memset
// (e.g. invalid pointer or device error) would silently leave the buffer
// full of garbage — particularly bad for the shared-weights gradient
// accumulation path that depends on starting from zero. Surface the
// error instead.
void zero_buffer(ffi::AnyBuffer &buffer) {
  const cudaError_t status = cudaMemset(
      data_ptr(buffer),
      0,
      buffer.element_count() * byte_count(buffer));
  if (status != cudaSuccess)
    throw std::logic_error(cudaGetErrorString(status));
}
#endif
89+
6490struct KernelProp {
6591 int64_t L1_dim, L2_dim, L3_dim, weight_numel;
6692 bool shared_weights;
@@ -244,7 +270,7 @@ ffi::Error tp_backward_impl(
244270 }
245271
246272 if (k.shared_weights ) {
247- // Need to zero out W_grad
273+ zero_buffer (* W_grad);
248274 }
249275
250276 jit_kernel->backward (
0 commit comments