@@ -31,6 +31,71 @@ std::string vector2string(const std::vector<T>& data)
3131 return ss.str ();
3232}
3333
34+ namespace hybrid_prefix_budget {
35+
36+ size_t GetShapeBytes (const std::vector<ssize_t >& shape, DataType dtype)
37+ {
38+ if (shape.empty ()) {
39+ return 0 ;
40+ }
41+
42+ size_t numel = 1 ;
43+ for (const auto dim : shape) {
44+ TM_CHECK_GT (dim, 0 );
45+ numel *= static_cast <size_t >(dim);
46+ }
47+ return static_cast <size_t >(byte_size (dtype, static_cast <std::ptrdiff_t >(numel)));
48+ }
49+
50+ size_t GetHybridCheckpointSlotBytes (const std::vector<ssize_t >& conv_state_shape,
51+ DataType conv_state_dtype,
52+ const std::vector<ssize_t >& recurrent_state_shape,
53+ DataType recurrent_state_dtype)
54+ {
55+ return GetShapeBytes (conv_state_shape, conv_state_dtype)
56+ + GetShapeBytes (recurrent_state_shape, recurrent_state_dtype);
57+ }
58+
/// Total bytes required for `kv_block_count` KV blocks together with their checkpoint slots.
///
/// One checkpoint slot is charged for every complete group of
/// `checkpoint_interval_blocks` KV blocks. A non-positive interval charges no slots.
/// The arithmetic is done in 128 bits so that each 64-bit x 64-bit product is exact.
///
/// @param kv_block_count             Number of KV blocks.
/// @param kv_block_bytes             Size of one KV block in bytes.
/// @param checkpoint_slot_bytes      Size of one checkpoint slot in bytes.
/// @param checkpoint_interval_blocks KV blocks per checkpoint slot; <= 0 disables slots.
/// @return                           kv_block_count * kv_block_bytes + slots * checkpoint_slot_bytes.
unsigned __int128 GetHybridCacheBytesForBlocks(size_t kv_block_count,
                                               size_t kv_block_bytes,
                                               size_t checkpoint_slot_bytes,
                                               int    checkpoint_interval_blocks)
{
    // The interval is known to be positive before it is converted to an unsigned divisor.
    const size_t checkpoint_slots =
        checkpoint_interval_blocks > 0 ? kv_block_count / static_cast<size_t>(checkpoint_interval_blocks) : 0;
    return static_cast<unsigned __int128>(kv_block_count) * kv_block_bytes
           + static_cast<unsigned __int128>(checkpoint_slots) * checkpoint_slot_bytes;
}

/// Largest KV block count whose combined cost (blocks + checkpoint slots) fits in the budget.
///
/// Returns 0 when the budget or the block size is zero. When checkpoints cost nothing
/// (zero slot size or non-positive interval) the answer is simply
/// `budget_bytes / kv_block_bytes`. Otherwise the count is found by binary search over
/// GetHybridCacheBytesForBlocks(), which is non-decreasing in the block count.
///
/// @param budget_bytes               Bytes available for KV blocks and checkpoint slots.
/// @param kv_block_bytes             Size of one KV block in bytes.
/// @param checkpoint_slot_bytes      Size of one checkpoint slot in bytes.
/// @param checkpoint_interval_blocks KV blocks per checkpoint slot; <= 0 disables slots.
/// @return                           Maximum block count n with cost(n) <= budget_bytes.
size_t SolveHybridKvBlockCount(size_t budget_bytes,
                               size_t kv_block_bytes,
                               size_t checkpoint_slot_bytes,
                               int    checkpoint_interval_blocks)
{
    if (!budget_bytes || !kv_block_bytes) {
        return 0;
    }
    if (!checkpoint_slot_bytes || checkpoint_interval_blocks <= 0) {
        return budget_bytes / kv_block_bytes;
    }

    // Invariant: cost(lo) <= budget (cost(0) == 0), and no count above hi can fit because
    // the KV blocks alone would already exceed the budget.
    size_t lo = 0;
    size_t hi = budget_bytes / kv_block_bytes;
    while (lo < hi) {
        // Upper midpoint, written so it cannot wrap: `lo + (hi - lo + 1) / 2` evaluates to
        // `lo` when hi - lo == SIZE_MAX, which would make the loop spin forever.
        const size_t mid = hi - (hi - lo) / 2;
        const auto   required_bytes =
            GetHybridCacheBytesForBlocks(mid, kv_block_bytes, checkpoint_slot_bytes, checkpoint_interval_blocks);
        if (required_bytes <= budget_bytes) {
            lo = mid;
        }
        else {
            hi = mid - 1;
        }
    }
    return lo;
}
96+
97+ } // namespace hybrid_prefix_budget
98+
3499SequenceManager::SequenceManager (const ModelParam& model_param,
35100 DataType runtime_dtype,
36101 int cache_block_seq_len,
@@ -116,7 +181,15 @@ SequenceManager::SequenceManager(const ModelParam& model_param,
116181 block::Layout layout{block_config};
117182 // dump(layout);
118183
119- size_t block_size = layout.block_size (cache_layer_num);
184+ size_t block_size = layout.block_size (cache_layer_num);
185+ const bool has_linear_prefix_checkpoints = enable_prefix_caching && num_linear_layers > 0 ;
186+ const size_t linear_prefix_slot_bytes =
187+ has_linear_prefix_checkpoints ?
188+ hybrid_prefix_budget::GetHybridCheckpointSlotBytes (linear_prefix_conv_state_shape,
189+ linear_prefix_conv_state_dtype,
190+ linear_prefix_recurrent_state_shape,
191+ linear_prefix_recurrent_state_dtype) :
192+ 0 ;
120193
121194 if (num_linear_layers > 0 && block_count < 1 .) {
122195 const size_t target_bytes = static_cast <size_t >(free_before * block_count);
@@ -128,13 +201,21 @@ SequenceManager::SequenceManager(const ModelParam& model_param,
128201 << " Please decrease max_batch_size to reduce total linear state size or increase cache_max_entry_count." ;
129202 }
130203 const size_t cache_bytes = target_bytes - live_linear_bytes;
131- block_count = static_cast <double >(cache_bytes) / static_cast <double >(block_size);
204+ block_count =
205+ has_linear_prefix_checkpoints ? static_cast <double >(hybrid_prefix_budget::SolveHybridKvBlockCount (
206+ cache_bytes, block_size, linear_prefix_slot_bytes, linear_prefix_cache_interval_blocks_)) :
207+ static_cast <double >(cache_bytes) / static_cast <double >(block_size);
132208 }
133209 else if (num_linear_layers > 0 && block_count >= 1 .) {
134- const size_t requested_blocks = static_cast <size_t >(block_count);
135- const size_t requested_cache_bytes = requested_blocks * block_size;
136- const size_t available_after_live = get_free_size ();
137- TM_CHECK_LE (requested_cache_bytes, available_after_live) << " Insufficient memory for KV cache blocks." ;
210+ const size_t requested_blocks = static_cast <size_t >(block_count);
211+ const auto requested_cache_bytes =
212+ has_linear_prefix_checkpoints ? hybrid_prefix_budget::GetHybridCacheBytesForBlocks (
213+ requested_blocks, block_size, linear_prefix_slot_bytes, linear_prefix_cache_interval_blocks_) :
214+ static_cast <unsigned __int128>(requested_blocks) * block_size;
215+ const size_t available_after_live = get_free_size ();
216+ TM_CHECK (requested_cache_bytes <= static_cast <unsigned __int128>(available_after_live))
217+ << " Insufficient memory for "
218+ << (has_linear_prefix_checkpoints ? " hybrid prefix cache blocks and checkpoints." : " KV cache blocks." );
138219 }
139220
140221 block_manager_ = std::make_shared<BlockManager>(block_size, block_count, chunk_size, allocator, get_free_size);
0 commit comments