
Commit 4576910

unamedkr and claude committed
turbo_kv_4bo: Variant G — 4-bit codebook + 8 per-block FP16 outliers
New TQ_TYPE_TURBO_KV_4BO type. Each block stores the 8 channels with the
largest |rotated[i]| as exact FP16 values (with their indices) on top of the
existing Variant F 4-bit Lloyd-Max codebook. At dequant time these 8 positions
are overwritten with the stored exact values, eliminating the worst
quantization errors per block. This is a simpler, local form of the
per-channel outlier handling described in the Google TurboQuant paper.

Llama 3.2 3B PPL on bench/data/ppl_1k.txt (FP32 = 13.56):
  turbo_kv_4b   14.28 (+5.3%)   ← 72B
  turbo_kv_4bo  13.86 (+2.2%)   ← 96B  ← gap cut by 58%
  turbo_kv_5b   13.60 (+0.34%)  ← 88B

SmolLM2 135M PPL (FP32 = 18.62):
  turbo_kv_4b   19.70 (+5.8%)
  turbo_kv_4bo  19.29 (+3.6%)  ← gap cut by 38%
  turbo_kv_5b   18.94 (+1.7%)

The technique works (it validates Issue #15's per-channel outlier hypothesis),
but at 96 bytes the variant is currently bigger than 5b (88B) without matching
its quality. The next iteration will combine outliers with a 3-bit base
codebook (turbo_kv_3bo, ~80 bytes) to test whether outliers plus a smaller
base can beat 5b at a smaller block size.

Block layout (96 bytes):
  norm(2) + residual_norm(2) + inv_std(2) + _pad(2)
  mse_indices[64]   // 4-bit packed (Variant F base)
  out_indices[8]    // 1 byte per outlier
  out_values[8]     // FP16 per outlier

Quantize finds the top-K outliers by |rotated| and stores them verbatim. The
codebook scaling uses BODY-only max-abs (excluding the outliers) so the
codebook doesn't waste resolution on the tails the outliers already capture
exactly.

35/35 tests pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
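For reference, the byte counts quoted above follow directly from the block layout. A minimal standalone sketch (not part of this commit) that recomputes them, together with the implied bits per element, assuming TQ_BK = 128 channels per block (which is what the 64-byte 4-bit index array implies):

#include <stdio.h>

/* Hypothetical helper, not in the repo: recompute per-block bytes and
 * bits-per-element for the variants compared above, assuming TQ_BK = 128. */
int main(void) {
    const int TQ_BK = 128;
    const int HDR   = 8;                        /* norm + residual_norm + inv_std + _pad */
    const int K     = 8;                        /* TQ_KV_4BO_OUTLIERS */

    int b4  = HDR + TQ_BK / 2;                  /* turbo_kv_4b  */
    int b5  = HDR + TQ_BK * 5 / 8;              /* turbo_kv_5b  */
    int b4o = HDR + TQ_BK / 2 + K + K * 2;      /* turbo_kv_4bo */
    int b3o = HDR + TQ_BK * 3 / 8 + K + K * 2;  /* planned turbo_kv_3bo */

    printf("4b  %3dB  %.2f bpe\n", b4,  b4  * 8.0 / TQ_BK);  /* 72B  4.50 */
    printf("5b  %3dB  %.2f bpe\n", b5,  b5  * 8.0 / TQ_BK);  /* 88B  5.50 */
    printf("4bo %3dB  %.2f bpe\n", b4o, b4o * 8.0 / TQ_BK);  /* 96B  6.00 */
    printf("3bo %3dB  %.2f bpe\n", b3o, b3o * 8.0 / TQ_BK);  /* 80B  5.00 */
    return 0;
}

At 6.0 bits per element the outlier variant sits half a bit above 5b, which is what the planned 3-bit base (about 5.0 bits per element) is meant to claw back.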
1 parent 475872c commit 4576910

5 files changed

Lines changed: 235 additions & 2 deletions


include/turboquant/tq_types.h

Lines changed: 26 additions & 1 deletion
@@ -55,7 +55,8 @@ typedef enum {
     TQ_TYPE_TURBO_KV_2B = 11,   /* TurboQuant KV: 2-bit (1-bit codebook + 1-bit QJL) */
     TQ_TYPE_UNIFORM_3B = 12,    /* Min-Max uniform 3-bit with sub-block scales */
     TQ_TYPE_TURBO_KV_5B = 13,   /* TurboQuant KV: RHT + 5-bit Lloyd-Max codebook */
-    TQ_TYPE_COUNT = 14
+    TQ_TYPE_TURBO_KV_4BO = 14,  /* TurboQuant KV: 4-bit codebook + 8 FP16 outliers */
+    TQ_TYPE_COUNT = 15
 } tq_type;
 
 /* ============================================================
@@ -221,6 +222,29 @@ typedef struct {
     uint8_t mse_indices[TQ_BK * 3 / 8];  /* 3-bit packed codebook indices (48B) */
 } block_tq_turbo_kv_3b;
 
+/* TurboQuant KV cache block: 4-bit + per-block outliers (Variant G)
+ *
+ * Same Variant F base (RHT + 4-bit Lloyd-Max codebook), plus a per-block
+ * outlier list: the K=8 largest |rotated[i]| values are stored verbatim
+ * as FP16 with their channel index, and OVERWRITE the codebook
+ * reconstruction at dequantize time. This addresses the heavy-tail
+ * problem the Google TurboQuant paper handles via per-channel bit
+ * allocation, but in a simpler local form.
+ *
+ * Layout: 8 hdr + 64 mse_4bit + 8 out_idx + 16 out_val_fp16 = 96 bytes
+ */
+#define TQ_KV_4BO_OUTLIERS 8
+
+typedef struct {
+    uint16_t norm;           /* L2 norm of original (fp16) */
+    uint16_t residual_norm;  /* unused */
+    uint16_t inv_std_fp16;   /* per-block inv_std */
+    uint16_t _pad;           /* alignment */
+    uint8_t  mse_indices[TQ_BK / 2];           /* 4-bit packed indices (64B) */
+    uint8_t  out_indices[TQ_KV_4BO_OUTLIERS];  /* outlier channel indices (8B) */
+    uint16_t out_values[TQ_KV_4BO_OUTLIERS];   /* outlier values FP16 (16B) */
+} block_tq_turbo_kv_4bo;
+
 /* TurboQuant KV cache block: 5-bit variant (Variant F architecture)
  *
  * 5-bit (32-level) Lloyd-Max-Gaussian codebook on RHT-rotated values.
@@ -295,6 +319,7 @@ TQ_CHECK_SIZE(block_tq_mixed_4b8, 4 + TQ_MIXED_OUTLIERS + TQ_MIXED_OUTLIERS * 2
 TQ_CHECK_SIZE(block_tq_turbo_kv_3b, 8 + TQ_BK * 3 / 8);
 TQ_CHECK_SIZE(block_tq_turbo_kv_4b, 8 + TQ_BK / 2);
 TQ_CHECK_SIZE(block_tq_turbo_kv_5b, 8 + TQ_BK * 5 / 8);
+TQ_CHECK_SIZE(block_tq_turbo_kv_4bo, 8 + TQ_BK / 2 + TQ_KV_4BO_OUTLIERS + TQ_KV_4BO_OUTLIERS * 2);
 TQ_CHECK_SIZE(block_tq_turbo_kv_1b, 8 + TQ_BK / 8);
 TQ_CHECK_SIZE(block_tq_turbo_kv_2b, 8 + TQ_BK / 8 + TQ_BK / 8);
 
integrations/llamacpp/tq_kv_cache.cpp

Lines changed: 15 additions & 1 deletion
@@ -46,7 +46,8 @@ enum {
     GGML_TYPE_TQ_TURBO_KV_2B = GGML_TYPE_TQ_BASE + 11,
     GGML_TYPE_TQ_UNIFORM_3B = GGML_TYPE_TQ_BASE + 12,
     GGML_TYPE_TQ_TURBO_KV_5B = GGML_TYPE_TQ_BASE + 13,
-    GGML_TYPE_TQ_COUNT = 14,
+    GGML_TYPE_TQ_TURBO_KV_4BO = GGML_TYPE_TQ_BASE + 14,
+    GGML_TYPE_TQ_COUNT = 15,
 };
 
 /* ============================================================
@@ -69,6 +70,7 @@ static int tq_to_ggml_type(tq_type type) {
         case TQ_TYPE_TURBO_KV_2B: return GGML_TYPE_TQ_TURBO_KV_2B;
         case TQ_TYPE_UNIFORM_3B: return GGML_TYPE_TQ_UNIFORM_3B;
         case TQ_TYPE_TURBO_KV_5B: return GGML_TYPE_TQ_TURBO_KV_5B;
+        case TQ_TYPE_TURBO_KV_4BO: return GGML_TYPE_TQ_TURBO_KV_4BO;
         default: return -1;
     }
 }
@@ -89,6 +91,7 @@ static tq_type ggml_to_tq_type(int ggml_id) {
         case GGML_TYPE_TQ_TURBO_KV_2B: return TQ_TYPE_TURBO_KV_2B;
         case GGML_TYPE_TQ_UNIFORM_3B: return TQ_TYPE_UNIFORM_3B;
         case GGML_TYPE_TQ_TURBO_KV_5B: return TQ_TYPE_TURBO_KV_5B;
+        case GGML_TYPE_TQ_TURBO_KV_4BO: return TQ_TYPE_TURBO_KV_4BO;
         default: return TQ_TYPE_COUNT;
     }
 }
@@ -155,6 +158,7 @@ TQ_GGML_WRAPPERS(turbo_kv_1b, TQ_TYPE_TURBO_KV_1B)
 TQ_GGML_WRAPPERS(turbo_kv_2b, TQ_TYPE_TURBO_KV_2B)
 TQ_GGML_WRAPPERS(uniform_3b, TQ_TYPE_UNIFORM_3B)
 TQ_GGML_WRAPPERS(turbo_kv_5b, TQ_TYPE_TURBO_KV_5B)
+TQ_GGML_WRAPPERS(turbo_kv_4bo, TQ_TYPE_TURBO_KV_4BO)
 
 /* ============================================================
  * vec_dot wrappers (quantized key . FP32 query -> scalar)
@@ -209,6 +213,7 @@ TQ_GGML_VEC_DOT(turbo_kv_1b, TQ_TYPE_TURBO_KV_1B)
 TQ_GGML_VEC_DOT(turbo_kv_2b, TQ_TYPE_TURBO_KV_2B)
 TQ_GGML_VEC_DOT(uniform_3b, TQ_TYPE_UNIFORM_3B)
 TQ_GGML_VEC_DOT(turbo_kv_5b, TQ_TYPE_TURBO_KV_5B)
+TQ_GGML_VEC_DOT(turbo_kv_4bo, TQ_TYPE_TURBO_KV_4BO)
 
 /* ============================================================
  * GGML type trait table
@@ -340,6 +345,14 @@ static const tq_ggml_type_trait TQ_GGML_TRAITS[GGML_TYPE_TQ_COUNT] = {
         tq_ggml_to_float_turbo_kv_5b,
         tq_ggml_vec_dot_turbo_kv_5b,
     },
+    {
+        "tq_turbo_kv_4bo", GGML_TYPE_TQ_TURBO_KV_4BO, TQ_TYPE_TURBO_KV_4BO,
+        sizeof(block_tq_turbo_kv_4bo), TQ_BK,
+        (float)sizeof(block_tq_turbo_kv_4bo) * 8.0f / TQ_BK,
+        tq_ggml_from_float_turbo_kv_4bo,
+        tq_ggml_to_float_turbo_kv_4bo,
+        tq_ggml_vec_dot_turbo_kv_4bo,
+    },
 };
 
 #define TQ_GGML_NUM_TYPES (sizeof(TQ_GGML_TRAITS) / sizeof(TQ_GGML_TRAITS[0]))
@@ -432,6 +445,7 @@ tq_type tq_parse_kv_cache_type(const char* arg) {
         { "turbokv3", TQ_TYPE_TURBO_KV_3B },
         { "turbo_kv_4b", TQ_TYPE_TURBO_KV_4B },
         { "turbo_kv_5b", TQ_TYPE_TURBO_KV_5B },
+        { "turbo_kv_4bo", TQ_TYPE_TURBO_KV_4BO },
         { "tq-turbo-kv-4b", TQ_TYPE_TURBO_KV_4B },
         { "turbokv4", TQ_TYPE_TURBO_KV_4B },
         { "turbo_kv_1b", TQ_TYPE_TURBO_KV_1B },

src/core/tq_traits.c

Lines changed: 17 additions & 0 deletions
@@ -53,6 +53,11 @@ extern void tq_turbo_kv_5b_dequantize_ref(const void* src, float* dst, int n);
 extern void tq_turbo_kv_5b_attention_ref(const float* query, const void* kv,
                                          float* scores, int seq_len, int head_dim);
 
+extern void tq_turbo_kv_4bo_quantize_ref(const float* src, void* dst, int n);
+extern void tq_turbo_kv_4bo_dequantize_ref(const void* src, float* dst, int n);
+extern void tq_turbo_kv_4bo_attention_ref(const float* query, const void* kv,
+                                          float* scores, int seq_len, int head_dim);
+
 extern void tq_turbo_kv_1b_quantize_ref(const float* src, void* dst, int n);
 extern void tq_turbo_kv_1b_dequantize_ref(const void* src, float* dst, int n);
 extern void tq_turbo_kv_1b_attention_ref(const float* query, const void* kv,
@@ -175,6 +180,16 @@ tq_type_traits_t TQ_TRAITS[TQ_TYPE_COUNT] = {
         .attention = tq_turbo_kv_5b_attention_ref,
         .residual_type = TQ_TYPE_COUNT,
     },
+    [TQ_TYPE_TURBO_KV_4BO] = {
+        .name = "turbo_kv_4bo",
+        .block_size = TQ_BK,
+        .type_size = sizeof(block_tq_turbo_kv_4bo),
+        .bpe = (float)sizeof(block_tq_turbo_kv_4bo) * 8.0f / TQ_BK,
+        .quantize = tq_turbo_kv_4bo_quantize_ref,
+        .dequantize = tq_turbo_kv_4bo_dequantize_ref,
+        .attention = tq_turbo_kv_4bo_attention_ref,
+        .residual_type = TQ_TYPE_COUNT,
+    },
     [TQ_TYPE_TURBO_KV_1B] = {
         .name = "turbo_kv_1b",
         .block_size = TQ_BK,
@@ -276,6 +291,8 @@ tq_format_spec_t tq_get_format_spec(tq_type type) {
             spec.algorithm = TQ_ALG_TURBO; spec.key_bits = 4; break;
         case TQ_TYPE_TURBO_KV_5B:
             spec.algorithm = TQ_ALG_TURBO; spec.key_bits = 5; break;
+        case TQ_TYPE_TURBO_KV_4BO:
+            spec.algorithm = TQ_ALG_TURBO; spec.key_bits = 4; break;
         case TQ_TYPE_TURBO_KV_1B:
             spec.algorithm = TQ_ALG_TURBO; spec.key_bits = 1; break;
         case TQ_TYPE_TURBO_KV_2B:

src/core/tq_turbo_kv.c

Lines changed: 176 additions & 0 deletions
@@ -1052,3 +1052,179 @@ void tq_turbo_kv_5b_attention_ref(const float* query, const void* kv_cache,
         scores[seq] = norm * mse_dot;
     }
 }
+
+/* ============================================================
+ * TurboQuant KV 4-bit + outliers (Variant G):
+ * normalize -> RHT -> 4-bit (16-level) Lloyd-Max codebook
+ * + top-K outliers stored verbatim as FP16 with channel index
+ *
+ * Same Variant F base + per-block outlier list. The K largest |rotated|
+ * channels are stored exactly and overwrite the codebook reconstruction
+ * at dequant time. Closes more PPL gap than 4b-only without going as
+ * heavy as 5b on memory.
+ * ============================================================ */
+
+void tq_turbo_kv_4bo_quantize_ref(const float* src, void* dst, int n) {
+    block_tq_turbo_kv_4bo* block = (block_tq_turbo_kv_4bo*)dst;
+    int dim = n;
+    if (dim > TQ_BK) dim = TQ_BK;
+
+    /* Step 1: L2 norm */
+    float norm_sq = 0.0f;
+    for (int i = 0; i < dim; i++) norm_sq += src[i] * src[i];
+    float norm = sqrtf(norm_sq);
+    block->norm = tkv_fp32_to_fp16(norm);
+    block->residual_norm = 0;
+    block->_pad = 0;
+
+    /* Step 2: Normalize + RHT */
+    float rotated[TQ_BK];
+    float inv_norm = (norm > 1e-10f) ? (1.0f / norm) : 0.0f;
+    for (int i = 0; i < dim; i++) rotated[i] = src[i] * inv_norm;
+    for (int i = dim; i < TQ_BK; i++) rotated[i] = 0.0f;
+    tq_rht_transform(rotated, dim, TKV_DEFAULT_SEED);
+
+    /* Step 3: Find top-K outliers by |rotated| (selection sort, K is small) */
+    int K = TQ_KV_4BO_OUTLIERS;
+    int out_idx[TQ_KV_4BO_OUTLIERS];
+    float out_abs[TQ_KV_4BO_OUTLIERS];
+    for (int k = 0; k < K; k++) { out_idx[k] = -1; out_abs[k] = -1.0f; }
+
+    for (int i = 0; i < dim; i++) {
+        float a = fabsf(rotated[i]);
+        /* Find smallest in current top-K and replace if larger */
+        int min_pos = 0;
+        for (int k = 1; k < K; k++) {
+            if (out_abs[k] < out_abs[min_pos]) min_pos = k;
+        }
+        if (a > out_abs[min_pos]) {
+            out_abs[min_pos] = a;
+            out_idx[min_pos] = i;
+        }
+    }
+
+    /* Store outlier indices and FP16 values */
+    for (int k = 0; k < K; k++) {
+        int idx = out_idx[k];
+        if (idx < 0) {
+            block->out_indices[k] = 0;
+            block->out_values[k] = 0;
+        } else {
+            block->out_indices[k] = (uint8_t)idx;
+            block->out_values[k] = tkv_fp32_to_fp16(rotated[idx]);
+        }
+    }
+
+    /* Step 4: max-abs scaling on the NON-outlier values for the codebook.
+     * Outliers are stored exact, so the codebook only needs to fit the body.
+     * Mask outliers out for max-abs computation. */
+    char is_outlier[TQ_BK];
+    memset(is_outlier, 0, sizeof(is_outlier));
+    for (int k = 0; k < K; k++) {
+        if (out_idx[k] >= 0) is_outlier[out_idx[k]] = 1;
+    }
+
+    float body_max_abs = 0.0f;
+    for (int i = 0; i < dim; i++) {
+        if (is_outlier[i]) continue;
+        float a = fabsf(rotated[i]);
+        if (a > body_max_abs) body_max_abs = a;
+    }
+    if (body_max_abs < 1e-10f) body_max_abs = 1.0f;
+    const float CENT_4BIT_MAX = 2.7326f;
+    float inv_std = CENT_4BIT_MAX / body_max_abs;
+    block->inv_std_fp16 = tkv_fp32_to_fp16(inv_std);
+
+    /* Step 5: Quantize all 128 with 4-bit codebook (outlier values get
+     * overwritten at dequant time, so their codebook indices don't matter
+     * for accuracy — but we still write something so the bytes are defined). */
+    uint8_t indices[TQ_BK];
+    tq_codebook_quantize(rotated, indices, dim, 4, inv_std);
+    memset(block->mse_indices, 0, TQ_BK / 2);
+    for (int i = 0; i < dim; i++) {
+        int byte_idx = i / 2;
+        int bit_pos = (i & 1) * 4;
+        block->mse_indices[byte_idx] |= (uint8_t)((indices[i] & 0x0F) << bit_pos);
+    }
+}
+
+static void dequant_mse_rotated_4bo(const block_tq_turbo_kv_4bo* block,
+                                    float* rotated, int dim) {
+    /* 4-bit codebook lookup */
+    float inv_std = tkv_fp16_to_fp32(block->inv_std_fp16);
+    if (inv_std < 1e-10f) inv_std = sqrtf((float)dim);
+    uint8_t indices[TQ_BK];
+    for (int i = 0; i < dim; i++) {
+        uint8_t b = block->mse_indices[i / 2];
+        indices[i] = (i & 1) ? (b >> 4) : (b & 0x0F);
+    }
+    tq_codebook_dequantize(indices, rotated, dim, 4, inv_std);
+
+    /* Overwrite outlier positions with stored exact FP16 values */
+    int K = TQ_KV_4BO_OUTLIERS;
+    for (int k = 0; k < K; k++) {
+        int idx = block->out_indices[k];
+        if (idx < dim) {
+            rotated[idx] = tkv_fp16_to_fp32(block->out_values[k]);
+        }
+    }
+}
+
+void tq_turbo_kv_4bo_dequantize_ref(const void* src, float* dst, int n) {
+    const block_tq_turbo_kv_4bo* block = (const block_tq_turbo_kv_4bo*)src;
+    int dim = n;
+    if (dim > TQ_BK) dim = TQ_BK;
+
+    float norm = tkv_fp16_to_fp32(block->norm);
+    float rotated[TQ_BK];
+    dequant_mse_rotated_4bo(block, rotated, dim);
+    tq_rht_inverse(rotated, dim, TKV_DEFAULT_SEED);
+    for (int i = 0; i < dim; i++) dst[i] = rotated[i] * norm;
+}
+
+void tq_turbo_kv_4bo_attention_ref(const float* query, const void* kv_cache,
+                                   float* scores, int seq_len, int head_dim) {
+    const block_tq_turbo_kv_4bo* blocks_4bo = (const block_tq_turbo_kv_4bo*)kv_cache;
+    int dim = head_dim;
+    if (dim > TQ_BK) dim = TQ_BK;
+
+    /* Pre-rotate query once */
+    float q_rot[TQ_BK];
+    memcpy(q_rot, query, (size_t)dim * sizeof(float));
+    for (int i = dim; i < TQ_BK; i++) q_rot[i] = 0.0f;
+    tq_rht_transform(q_rot, dim, TKV_DEFAULT_SEED);
+
+    for (int seq = 0; seq < seq_len; seq++) {
+        const block_tq_turbo_kv_4bo* block = &blocks_4bo[seq];
+        float norm = tkv_fp16_to_fp32(block->norm);
+
+        float rotated[TQ_BK];
+        dequant_mse_rotated_4bo(block, rotated, dim);
+
+        float mse_dot = 0.0f;
+#ifdef __ARM_NEON
+        {
+            float32x4_t acc0 = vdupq_n_f32(0.0f);
+            float32x4_t acc1 = vdupq_n_f32(0.0f);
+            float32x4_t acc2 = vdupq_n_f32(0.0f);
+            float32x4_t acc3 = vdupq_n_f32(0.0f);
+            int d = 0;
+            for (; d + 15 < dim; d += 16) {
+                acc0 = vfmaq_f32(acc0, vld1q_f32(&q_rot[d]), vld1q_f32(&rotated[d]));
+                acc1 = vfmaq_f32(acc1, vld1q_f32(&q_rot[d + 4]), vld1q_f32(&rotated[d + 4]));
+                acc2 = vfmaq_f32(acc2, vld1q_f32(&q_rot[d + 8]), vld1q_f32(&rotated[d + 8]));
+                acc3 = vfmaq_f32(acc3, vld1q_f32(&q_rot[d + 12]), vld1q_f32(&rotated[d + 12]));
+            }
+            acc0 = vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3));
+            for (; d + 3 < dim; d += 4) {
+                acc0 = vfmaq_f32(acc0, vld1q_f32(&q_rot[d]), vld1q_f32(&rotated[d]));
+            }
+            mse_dot = vaddvq_f32(acc0);
+            for (; d < dim; d++) mse_dot += q_rot[d] * rotated[d];
+        }
+#else
+        for (int d = 0; d < dim; d++) mse_dot += q_rot[d] * rotated[d];
+#endif
+        scores[seq] = norm * mse_dot;
+    }
+}
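For orientation, a round-trip through the new reference functions could be exercised with a small driver like the sketch below. This is illustrative only, not part of the commit: it assumes TQ_BK = 128, that block_tq_turbo_kv_4bo is visible via the turboquant/tq_types.h header, and it re-declares the _ref functions the same way src/core/tq_traits.c does.

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "turboquant/tq_types.h"

/* Same extern declarations as tq_traits.c uses for the new type. */
extern void tq_turbo_kv_4bo_quantize_ref(const float* src, void* dst, int n);
extern void tq_turbo_kv_4bo_dequantize_ref(const void* src, float* dst, int n);

int main(void) {
    float src[TQ_BK], out[TQ_BK];
    block_tq_turbo_kv_4bo block;

    /* Small random values with a few larger entries mixed in. */
    srand(42);
    for (int i = 0; i < TQ_BK; i++)
        src[i] = ((float)rand() / (float)RAND_MAX - 0.5f) * 0.2f;
    src[3] = 4.0f; src[71] = -3.5f; src[120] = 2.8f;

    tq_turbo_kv_4bo_quantize_ref(src, &block, TQ_BK);
    tq_turbo_kv_4bo_dequantize_ref(&block, out, TQ_BK);

    float max_err = 0.0f;
    for (int i = 0; i < TQ_BK; i++) {
        float e = fabsf(out[i] - src[i]);
        if (e > max_err) max_err = e;
    }
    printf("block = %zu bytes, max abs reconstruction error = %.4f\n",
           sizeof(block_tq_turbo_kv_4bo), max_err);
    return 0;
}

The attention path (tq_turbo_kv_4bo_attention_ref) follows the same pattern as the other TurboQuant variants: rotate the FP32 query once, reconstruct each key block in the rotated domain, and scale the dot product by the stored norm.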

tools/quant.c

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ static tq_type parse_kv_type(const char* s) {
     if (strcmp(s, "turbo_kv_3b") == 0) return TQ_TYPE_TURBO_KV_3B;
     if (strcmp(s, "turbo_kv_4b") == 0) return TQ_TYPE_TURBO_KV_4B;
     if (strcmp(s, "turbo_kv_5b") == 0) return TQ_TYPE_TURBO_KV_5B;
+    if (strcmp(s, "turbo_kv_4bo") == 0) return TQ_TYPE_TURBO_KV_4BO;
     if (strcmp(s, "turbo_kv_1b") == 0) return TQ_TYPE_TURBO_KV_1B;
     if (strcmp(s, "qjl_1b") == 0) return TQ_TYPE_QJL_1B;
     if (strcmp(s, "mixed_4b8") == 0) return TQ_TYPE_MIXED_4B8;
