@@ -1141,17 +1141,20 @@ def issue_write_back_storage_task(self, task: WriteStorageTask, is_sync=True):
11411141 if self .kvcache_storage_backend is None :
11421142 return
11431143
1144- if len (task .keys ) != len (task .gpu_block_ids ):
1145- err_msg = (
1146- f"write_back_storage error: hash_keys({ len (task .keys )} ) != gpu_block_ids({ len (task .gpu_block_ids )} )"
1147- )
1148- logger .error (err_msg )
1149- raise ValueError (err_msg )
1150-
1151- self .task_write_back_event [task .task_id ] = Event ()
1152- self .cache_task_queue .put_transfer_task ((CacheStatus .GPU2STORAGE , task ))
1153- if is_sync :
1154- self .wait_write_storage_task (task .task_id )
1144+ assert is_sync , "Only support is_sync=True for now."
1145+ self ._acquire_kvcache_lock ()
1146+ try :
1147+ if len (task .keys ) != len (task .gpu_block_ids ):
1148+ err_msg = f"write_back_storage error: hash_keys({ len (task .keys )} ) != gpu_block_ids({ len (task .gpu_block_ids )} )"
1149+ logger .error (err_msg )
1150+ raise ValueError (err_msg )
1151+
1152+ self .task_write_back_event [task .task_id ] = Event ()
1153+ self .cache_task_queue .put_transfer_task ((CacheStatus .GPU2STORAGE , task ))
1154+ if is_sync :
1155+ self .wait_write_storage_task (task .task_id )
1156+ finally :
1157+ self ._release_kvcache_lock ()
11551158
11561159 def wait_write_storage_task (self , req_id ):
11571160 """
@@ -1168,12 +1171,18 @@ def issue_prefetch_storage_task(self, task: ReadStorageTask, is_sync=True):
11681171 if self .kvcache_storage_backend is None :
11691172 return []
11701173
1171- storage_block_ids = []
1172- self .task_prefetch_event [task .task_id ] = Event ()
1173- # issue task to cache_transfer_manager
1174- self .cache_task_queue .put_transfer_task ((CacheStatus .STORAGE2GPU , task ))
1175- if is_sync :
1176- storage_block_ids = self .wait_prefetch_storage_task (task .task_id )
1174+ assert is_sync , "Only support is_sync=True for now."
1175+ self ._acquire_kvcache_lock ()
1176+
1177+ try :
1178+ storage_block_ids = []
1179+ self .task_prefetch_event [task .task_id ] = Event ()
1180+ # issue task to cache_transfer_manager
1181+ self .cache_task_queue .put_transfer_task ((CacheStatus .STORAGE2GPU , task ))
1182+ if is_sync :
1183+ storage_block_ids = self .wait_prefetch_storage_task (task .task_id )
1184+ finally :
1185+ self ._release_kvcache_lock ()
11771186 return storage_block_ids
11781187
11791188 def wait_prefetch_storage_task (self , req_id ):
0 commit comments