@@ -45,6 +45,44 @@ xla::ffi::DataType enum_to_xla_dtype(int64_t i){
4545 throw logic_error (" Unsupported tensor datatype!" );
4646}
4747
48+ std::string xla_dtype_to_string (xla::ffi::DataType dtype) {
49+ const std::unordered_map<xla::ffi::DataType, std::string> map = {
50+ {xla::ffi::DataType::INVALID, " INVALID" },
51+ {xla::ffi::DataType::PRED, " PRED" },
52+ {xla::ffi::DataType::S1, " S1" },
53+ {xla::ffi::DataType::S2, " S2" },
54+ {xla::ffi::DataType::S4, " S4" },
55+ {xla::ffi::DataType::S8, " S8" },
56+ {xla::ffi::DataType::S16, " S16" },
57+ {xla::ffi::DataType::S32, " S32" },
58+ {xla::ffi::DataType::S64, " S64" },
59+ {xla::ffi::DataType::U1, " U1" },
60+ {xla::ffi::DataType::U2, " U2" },
61+ {xla::ffi::DataType::U4, " U4" },
62+ {xla::ffi::DataType::U8, " U8" },
63+ {xla::ffi::DataType::U16, " U16" },
64+ {xla::ffi::DataType::U32, " U32" },
65+ {xla::ffi::DataType::U64, " U64" },
66+ {xla::ffi::DataType::F16, " F16" },
67+ {xla::ffi::DataType::F32, " F32" },
68+ {xla::ffi::DataType::F64, " F64" },
69+ {xla::ffi::DataType::BF16, " BF16" },
70+ {xla::ffi::DataType::C64, " C64" },
71+ {xla::ffi::DataType::C128, " C128" },
72+ {xla::ffi::DataType::TOKEN, " TOKEN" },
73+ {xla::ffi::DataType::F8E5M2, " F8E5M2" },
74+ {xla::ffi::DataType::F8E4M3, " F8E4M3" },
75+ {xla::ffi::DataType::F8E4M3FN, " F8E4M3FN" },
76+ {xla::ffi::DataType::F8E4M3B11FNUZ, " F8E4M3B11FNUZ" },
77+ {xla::ffi::DataType::F8E5M2FNUZ, " F8E5M2FNUZ" },
78+ {xla::ffi::DataType::F8E4M3FNUZ, " F8E4M3FNUZ" },
79+ {xla::ffi::DataType::F8E3M4, " F8E3M4" },
80+ {xla::ffi::DataType::F4E2M1FN, " F4E2M1FN" },
81+ {xla::ffi::DataType::F8E8M0FNU, " F8E8M0FNU" },
82+ };
83+ return map.at (dtype);
84+ }
85+
4886inline void * data_ptr (ffi::AnyBuffer &buffer) {
4987 return buffer.untyped_data ();
5088}
@@ -237,8 +275,8 @@ inline void check_tensor(const ffi::AnyBuffer &buffer,
237275
238276 if (buffer.element_type () != expected_dtype) {
239277 throw std::logic_error (" Datatype mismatch for tensor " + tensor_name +
240- " . Expected datatype " + std::to_string ( static_cast < int64_t >( expected_dtype)) +
241- " , got " + std::to_string ( static_cast < int64_t >( buffer.element_type () )));
278+ " . Expected datatype " + xla_dtype_to_string ( expected_dtype) +
279+ " , got " + xla_dtype_to_string ( buffer.element_type ()));
242280 }
243281}
244282
0 commit comments