Skip to content

Commit 259ea20

Browse files
committed
JAX example.
1 parent 1bcea33 commit 259ea20

1 file changed

Lines changed: 40 additions & 2 deletions

File tree

openequivariance_extjax/src/libjax_tp_jit.cpp

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,44 @@ xla::ffi::DataType enum_to_xla_dtype(int64_t i){
4545
throw logic_error("Unsupported tensor datatype!");
4646
}
4747

48+
std::string xla_dtype_to_string(xla::ffi::DataType dtype) {
49+
const std::unordered_map<xla::ffi::DataType, std::string> map = {
50+
{xla::ffi::DataType::INVALID, "INVALID"},
51+
{xla::ffi::DataType::PRED, "PRED"},
52+
{xla::ffi::DataType::S1, "S1"},
53+
{xla::ffi::DataType::S2, "S2"},
54+
{xla::ffi::DataType::S4, "S4"},
55+
{xla::ffi::DataType::S8, "S8"},
56+
{xla::ffi::DataType::S16, "S16"},
57+
{xla::ffi::DataType::S32, "S32"},
58+
{xla::ffi::DataType::S64, "S64"},
59+
{xla::ffi::DataType::U1, "U1"},
60+
{xla::ffi::DataType::U2, "U2"},
61+
{xla::ffi::DataType::U4, "U4"},
62+
{xla::ffi::DataType::U8, "U8"},
63+
{xla::ffi::DataType::U16, "U16"},
64+
{xla::ffi::DataType::U32, "U32"},
65+
{xla::ffi::DataType::U64, "U64"},
66+
{xla::ffi::DataType::F16, "F16"},
67+
{xla::ffi::DataType::F32, "F32"},
68+
{xla::ffi::DataType::F64, "F64"},
69+
{xla::ffi::DataType::BF16, "BF16"},
70+
{xla::ffi::DataType::C64, "C64"},
71+
{xla::ffi::DataType::C128, "C128"},
72+
{xla::ffi::DataType::TOKEN, "TOKEN"},
73+
{xla::ffi::DataType::F8E5M2, "F8E5M2"},
74+
{xla::ffi::DataType::F8E4M3, "F8E4M3"},
75+
{xla::ffi::DataType::F8E4M3FN, "F8E4M3FN"},
76+
{xla::ffi::DataType::F8E4M3B11FNUZ, "F8E4M3B11FNUZ"},
77+
{xla::ffi::DataType::F8E5M2FNUZ, "F8E5M2FNUZ"},
78+
{xla::ffi::DataType::F8E4M3FNUZ, "F8E4M3FNUZ"},
79+
{xla::ffi::DataType::F8E3M4, "F8E3M4"},
80+
{xla::ffi::DataType::F4E2M1FN, "F4E2M1FN"},
81+
{xla::ffi::DataType::F8E8M0FNU, "F8E8M0FNU"},
82+
};
83+
return map.at(dtype);
84+
}
85+
4886
inline void* data_ptr(ffi::AnyBuffer &buffer) {
4987
return buffer.untyped_data();
5088
}
@@ -237,8 +275,8 @@ inline void check_tensor(const ffi::AnyBuffer &buffer,
237275

238276
if (buffer.element_type() != expected_dtype) {
239277
throw std::logic_error("Datatype mismatch for tensor " + tensor_name +
240-
". Expected datatype " + std::to_string(static_cast<int64_t>(expected_dtype)) +
241-
", got " + std::to_string(static_cast<int64_t>(buffer.element_type())));
278+
". Expected datatype " + xla_dtype_to_string(expected_dtype) +
279+
", got " + xla_dtype_to_string(buffer.element_type()));
242280
}
243281
}
244282

0 commit comments

Comments
 (0)