Commit 136c9f6

Zero'd buffer.
1 parent d1131fa commit 136c9f6

1 file changed: openequivariance_extjax/src/libjax_tp_jit.cpp (41 additions, 15 deletions)

--- a/openequivariance_extjax/src/libjax_tp_jit.cpp
+++ b/openequivariance_extjax/src/libjax_tp_jit.cpp
@@ -11,6 +11,9 @@
 #include "nanobind/nanobind.h"
 #include "xla/ffi/api/ffi.h"
 
+namespace nb = nanobind;
+namespace ffi = xla::ffi;
+
 #define CUDA_BACKEND // Stick to CUDA for now
 
 #ifdef CUDA_BACKEND
@@ -20,14 +23,11 @@
 using GPU_Allocator = CUDA_Allocator;
 
 template<typename T>
-using GroupMM = GroupMMCUDA<T>;
+using GroupMM = GroupMMCUDA<T>;
 #endif
 
 #include "tensorproducts.hpp"
 
-namespace nb = nanobind;
-namespace ffi = xla::ffi;
-
 xla::ffi::DataType enum_to_xla_dtype(int64_t i){
     switch(i) {
         case 1:
@@ -45,22 +45,48 @@ xla::ffi::DataType enum_to_xla_dtype(int64_t i){
 }
 
 inline void* data_ptr(ffi::AnyBuffer &buffer) {
-    if(buffer.element_type() == xla::ffi::DataType::F32)
-        return reinterpret_cast<void*>(buffer.typed_data<float>());
-    else if(buffer.element_type() == xla::ffi::DataType::F64)
-        return reinterpret_cast<void*>(buffer.typed_data<double>());
-    else if(buffer.element_type() == xla::ffi::DataType::S64)
-        return reinterpret_cast<void*>(buffer.typed_data<int64_t>());
-    else if(buffer.element_type() == xla::ffi::DataType::U8)
-        return reinterpret_cast<void*>(buffer.typed_data<uint8_t>());
-    else
-        throw logic_error("Unsupported tensor datatype!");
+    switch (buffer.element_type()) {
+        case xla::ffi::DataType::F32:
+            return reinterpret_cast<void*>(buffer.typed_data<float>());
+        case xla::ffi::DataType::F64:
+            return reinterpret_cast<void*>(buffer.typed_data<double>());
+        case xla::ffi::DataType::S64:
+            return reinterpret_cast<void*>(buffer.typed_data<int64_t>());
+        case xla::ffi::DataType::U8:
+            return reinterpret_cast<void*>(buffer.typed_data<uint8_t>());
+        default:
+            throw logic_error("Unsupported tensor datatype!");
+    }
+}
+
+inline int byte_count(ffi::AnyBuffer &buffer) {
+    switch (buffer.element_type()) {
+        case xla::ffi::DataType::F32:
+            return 4;
+        case xla::ffi::DataType::F64:
+            return 8;
+        case xla::ffi::DataType::S64:
+            return 8;
+        case xla::ffi::DataType::U8:
+            return 1;
+        default:
+            throw logic_error("Unsupported tensor datatype!");
+    }
 }
 
 inline void* data_ptr(ffi::Result<ffi::AnyBuffer> &buffer) {
     return data_ptr(*buffer);
 }
 
+#ifdef CUDA_BACKEND
+void zero_buffer(ffi::AnyBuffer &buffer) {
+    cudaMemset(
+        data_ptr(buffer),
+        0,
+        buffer.element_count() * byte_count(buffer));
+}
+#endif
+
 struct KernelProp {
     int64_t L1_dim, L2_dim, L3_dim, weight_numel;
     bool shared_weights;
@@ -244,7 +270,7 @@ ffi::Error tp_backward_impl(
     }
 
     if (k.shared_weights) {
-        // Need to zero out W_grad
+        zero_buffer(*W_grad);
     }
 
     jit_kernel->backward(
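Why the zeroing matters: XLA does not guarantee that the result buffers handed to an FFI custom call are zero-initialized. When weights are shared across the batch, the backward pass reduces every batch element's weight-gradient contribution into the same W_grad slots, and that "+="-style accumulation silently produces garbage if the buffer starts out holding stale memory. The actual reduction lives inside jit_kernel->backward and is not part of this diff; the CUDA sketch below is only a hypothetical illustration of the accumulation pattern (the kernel name and parameters are invented for the example).

    // Hypothetical sketch of a shared-weight gradient reduction; NOT the
    // project's kernel. Every batch element adds into the same w_grad slot,
    // so the sum is correct only if w_grad held zeros before the first add,
    // which is exactly what the commit's zero_buffer(*W_grad) call ensures.
    __global__ void accumulate_w_grad(const float *contrib, float *w_grad,
                                      long long batch, long long weight_numel) {
        long long idx = blockIdx.x * (long long)blockDim.x + threadIdx.x;
        if (idx >= batch * weight_numel) return;
        atomicAdd(&w_grad[idx % weight_numel], contrib[idx]);
    }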

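A possible refinement, not part of this commit: cudaMemset issues the fill on the default stream, while XLA launches the custom call's kernels on a stream of its own, so the zeroing is at best serialized against unrelated work and can, with non-blocking streams, race with the backward kernel. The XLA FFI API can pass the execution stream to a handler as a context argument; assuming it is available as a cudaStream_t (the binding is not shown in this diff), a stream-ordered variant would look like this:

    // Hypothetical stream-aware variant of the commit's zero_buffer().
    // data_ptr() and byte_count() are the helpers added in this diff; the
    // stream is assumed to come from the XLA FFI handler's context.
    void zero_buffer_async(ffi::AnyBuffer &buffer, cudaStream_t stream) {
        cudaMemsetAsync(
            data_ptr(buffer),
            0,
            buffer.element_count() * byte_count(buffer),
            stream);  // ordered with the kernels XLA enqueues on this stream
    }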