Skip to content

Commit ad6573b

Browse files
committed
Update linear prefix cache interval to 64 in configuration and documentation; add related tests. This change enhances memory management for hybrid models by increasing the checkpoint interval, which reduces checkpoint memory usage but may require more recompute after a prefix hit.
1 parent ec5e8aa commit ad6573b

5 files changed

Lines changed: 99 additions & 15 deletions

File tree

lmdeploy/cli/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -561,9 +561,10 @@ def linear_prefix_cache_interval_blocks(parser):
561561

562562
return parser.add_argument('--linear-prefix-cache-interval-blocks',
563563
type=int,
564-
default=2,
565-
help='Checkpoint interval, in KV cache blocks, for hybrid '
566-
'linear-attention prefix state reuse')
564+
default=64,
565+
help='Hybrid linear-attention prefix checkpoint interval in '
566+
'KV cache blocks. Larger values reduce GDN checkpoint memory '
567+
'usage but increase recompute after a prefix hit')
567568

568569
@staticmethod
569570
def num_tokens_per_iter(parser):

lmdeploy/messages.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,11 @@ class TurbomindEngineConfig:
229229
a k/v block, default to 64
230230
enable_prefix_caching: enable cache prompts for block reuse,
231231
default to False
232-
linear_prefix_cache_interval_blocks: checkpoint interval, in KV
233-
cache blocks, for linear-attention prefix states. Applies only to
234-
hybrid models with prefix caching enabled. Default to 2
232+
linear_prefix_cache_interval_blocks: hybrid linear-attention prefix
233+
checkpoint interval, in KV cache blocks. Larger values reduce GDN
234+
checkpoint memory but may require more recompute after a prefix
235+
hit. Applies only to hybrid models with prefix caching enabled.
236+
Default to 64
235237
quant_policy: default to 0. When k/v is quantized into 4 or 8
236238
bit, set it to 4 or 8, respectively
237239
rope_scaling_factor: scaling factor used for dynamic ntk,
@@ -281,7 +283,7 @@ class TurbomindEngineConfig:
281283
cache_chunk_size: int = -1
282284
cache_block_seq_len: int = 64
283285
enable_prefix_caching: bool = False
284-
linear_prefix_cache_interval_blocks: int = 2
286+
linear_prefix_cache_interval_blocks: int = 64
285287
quant_policy: int = 0
286288
rope_scaling_factor: float = 0.0
287289
use_logn_attn: bool = False

src/turbomind/models/llama/SequenceManager.cc

Lines changed: 87 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,71 @@ std::string vector2string(const std::vector<T>& data)
3131
return ss.str();
3232
}
3333

34+
namespace hybrid_prefix_budget {
35+
36+
size_t GetShapeBytes(const std::vector<ssize_t>& shape, DataType dtype)
37+
{
38+
if (shape.empty()) {
39+
return 0;
40+
}
41+
42+
size_t numel = 1;
43+
for (const auto dim : shape) {
44+
TM_CHECK_GT(dim, 0);
45+
numel *= static_cast<size_t>(dim);
46+
}
47+
return static_cast<size_t>(byte_size(dtype, static_cast<std::ptrdiff_t>(numel)));
48+
}
49+
50+
size_t GetHybridCheckpointSlotBytes(const std::vector<ssize_t>& conv_state_shape,
51+
DataType conv_state_dtype,
52+
const std::vector<ssize_t>& recurrent_state_shape,
53+
DataType recurrent_state_dtype)
54+
{
55+
return GetShapeBytes(conv_state_shape, conv_state_dtype)
56+
+ GetShapeBytes(recurrent_state_shape, recurrent_state_dtype);
57+
}
58+
59+
unsigned __int128 GetHybridCacheBytesForBlocks(size_t kv_block_count,
60+
size_t kv_block_bytes,
61+
size_t checkpoint_slot_bytes,
62+
int checkpoint_interval_blocks)
63+
{
64+
const size_t checkpoint_slots = checkpoint_interval_blocks > 0 ? kv_block_count / checkpoint_interval_blocks : 0;
65+
return static_cast<unsigned __int128>(kv_block_count) * kv_block_bytes
66+
+ static_cast<unsigned __int128>(checkpoint_slots) * checkpoint_slot_bytes;
67+
}
68+
69+
size_t SolveHybridKvBlockCount(size_t budget_bytes,
70+
size_t kv_block_bytes,
71+
size_t checkpoint_slot_bytes,
72+
int checkpoint_interval_blocks)
73+
{
74+
if (!budget_bytes || !kv_block_bytes) {
75+
return 0;
76+
}
77+
if (!checkpoint_slot_bytes || checkpoint_interval_blocks <= 0) {
78+
return budget_bytes / kv_block_bytes;
79+
}
80+
81+
size_t lo = 0;
82+
size_t hi = budget_bytes / kv_block_bytes;
83+
while (lo < hi) {
84+
const size_t mid = lo + (hi - lo + 1) / 2;
85+
const auto required_bytes =
86+
GetHybridCacheBytesForBlocks(mid, kv_block_bytes, checkpoint_slot_bytes, checkpoint_interval_blocks);
87+
if (required_bytes <= budget_bytes) {
88+
lo = mid;
89+
}
90+
else {
91+
hi = mid - 1;
92+
}
93+
}
94+
return lo;
95+
}
96+
97+
} // namespace hybrid_prefix_budget
98+
3499
SequenceManager::SequenceManager(const ModelParam& model_param,
35100
DataType runtime_dtype,
36101
int cache_block_seq_len,
@@ -116,7 +181,15 @@ SequenceManager::SequenceManager(const ModelParam& model_param,
116181
block::Layout layout{block_config};
117182
// dump(layout);
118183

119-
size_t block_size = layout.block_size(cache_layer_num);
184+
size_t block_size = layout.block_size(cache_layer_num);
185+
const bool has_linear_prefix_checkpoints = enable_prefix_caching && num_linear_layers > 0;
186+
const size_t linear_prefix_slot_bytes =
187+
has_linear_prefix_checkpoints ?
188+
hybrid_prefix_budget::GetHybridCheckpointSlotBytes(linear_prefix_conv_state_shape,
189+
linear_prefix_conv_state_dtype,
190+
linear_prefix_recurrent_state_shape,
191+
linear_prefix_recurrent_state_dtype) :
192+
0;
120193

121194
if (num_linear_layers > 0 && block_count < 1.) {
122195
const size_t target_bytes = static_cast<size_t>(free_before * block_count);
@@ -128,13 +201,21 @@ SequenceManager::SequenceManager(const ModelParam& model_param,
128201
<< "Please decrease max_batch_size to reduce total linear state size or increase cache_max_entry_count.";
129202
}
130203
const size_t cache_bytes = target_bytes - live_linear_bytes;
131-
block_count = static_cast<double>(cache_bytes) / static_cast<double>(block_size);
204+
block_count =
205+
has_linear_prefix_checkpoints ? static_cast<double>(hybrid_prefix_budget::SolveHybridKvBlockCount(
206+
cache_bytes, block_size, linear_prefix_slot_bytes, linear_prefix_cache_interval_blocks_)) :
207+
static_cast<double>(cache_bytes) / static_cast<double>(block_size);
132208
}
133209
else if (num_linear_layers > 0 && block_count >= 1.) {
134-
const size_t requested_blocks = static_cast<size_t>(block_count);
135-
const size_t requested_cache_bytes = requested_blocks * block_size;
136-
const size_t available_after_live = get_free_size();
137-
TM_CHECK_LE(requested_cache_bytes, available_after_live) << "Insufficient memory for KV cache blocks.";
210+
const size_t requested_blocks = static_cast<size_t>(block_count);
211+
const auto requested_cache_bytes =
212+
has_linear_prefix_checkpoints ? hybrid_prefix_budget::GetHybridCacheBytesForBlocks(
213+
requested_blocks, block_size, linear_prefix_slot_bytes, linear_prefix_cache_interval_blocks_) :
214+
static_cast<unsigned __int128>(requested_blocks) * block_size;
215+
const size_t available_after_live = get_free_size();
216+
TM_CHECK(requested_cache_bytes <= static_cast<unsigned __int128>(available_after_live))
217+
<< "Insufficient memory for "
218+
<< (has_linear_prefix_checkpoints ? "hybrid prefix cache blocks and checkpoints." : "KV cache blocks.");
138219
}
139220

140221
block_manager_ = std::make_shared<BlockManager>(block_size, block_count, chunk_size, allocator, get_free_size);

src/turbomind/turbomind.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ TurboMind::Impl::Impl(string model_dir, string config, FFICtxFactory ffi_ctx_fac
443443
engine_param_.cache_max_block_count = engine["cache_max_entry_count"].as<float>(0);
444444
engine_param_.cache_chunk_size = engine["cache_chunk_size"].as<int>(0);
445445
engine_param_.enable_prefix_caching = engine["enable_prefix_caching"].as<bool>(false);
446-
engine_param_.linear_prefix_cache_interval_blocks = engine["linear_prefix_cache_interval_blocks"].as<int>(2);
446+
engine_param_.linear_prefix_cache_interval_blocks = engine["linear_prefix_cache_interval_blocks"].as<int>(64);
447447
engine_param_.enable_metrics = engine["enable_metrics"].as<bool>(false);
448448
TM_CHECK_GE(engine_param_.linear_prefix_cache_interval_blocks, 1);
449449

tests/test_lmdeploy/test_turbomind/test_engine_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
def test_linear_prefix_cache_interval_blocks_default():
88
config = TurbomindEngineConfig(enable_prefix_caching=True)
9-
assert config.linear_prefix_cache_interval_blocks == 2
9+
assert config.linear_prefix_cache_interval_blocks == 64
1010

1111

1212
def test_linear_prefix_cache_interval_blocks_validation():

0 commit comments

Comments (0)