
Commit 2f74853

fix: huggingface memory leak (#544)
* adding logic to cleanup on evict in cache
* implement cleanup logic for KV cache eviction to free GPU memory
* adding uuid for cache key
* reverting changes to precommit
* adding scores to cache and removing it from the constructor
* adding more robust type checking to run hf tests
* suppressing code cov output to stdout to make tests more readable
* small fix
* removing cov-report from subprocesses
* setting lru cache to 0 for now; till we figure out block_attention and dynamic LRU size
* removing return_scores from docs

---------

Co-authored-by: Nathan Fulton <nathan@ibm.com>
1 parent 5ac4b2f commit 2f74853

4 files changed

Lines changed: 114 additions & 21 deletions


mellea/backends/cache.py

Lines changed: 16 additions & 6 deletions
@@ -2,6 +2,7 @@

 import abc
 from collections import OrderedDict
+from collections.abc import Callable
 from typing import Any


@@ -11,12 +12,12 @@ class Cache(abc.ABC):
     # Whenever PEP 695 generics are supported by mypy, we should use them here.

     @abc.abstractmethod
-    def put(self, key: str, value: Any):
+    def put(self, key: str | int, value: Any):
         """Inserts into the cache. May result in eviction of other cached values."""
         ...

     @abc.abstractmethod
-    def get(self, key: str) -> Any | None:
+    def get(self, key: str | int) -> Any | None:
         """Retrieves a value from the cache. Returns `None` if the `id` has no cached value. May impact which cache values are evicted."""
         ...

@@ -29,19 +30,25 @@ def current_size(self) -> int:
 class SimpleLRUCache(Cache):
     """A simple [LRU](https://en.wikipedia.org/wiki/Cache_replacement_policies#Least_Recently_Used_(LRU)) cache."""

-    def __init__(self, capacity: int):
+    def __init__(self, capacity: int, on_evict: Callable[[Any], None] | None = None):
         """Initializes the LRU cache with a certain capacity.

         The `SimpleLRUCache` either contains a value or it doesn't. There is no cache hierarchy. Take care when choosing `capacity`. In practice usually a small value will be fine, but ideally you should try to choose a capacity based upon your available device memory and the context size of your model.
+
+        Args:
+            capacity: Maximum number of items to store in the cache.
+            on_evict: Optional callback function called when an item is evicted from the cache.
+                This can be used to free resources (e.g., GPU memory) when items are removed.
         """
         self.capacity = capacity
         self.cache: OrderedDict = OrderedDict()
+        self.on_evict = on_evict

     def current_size(self):
         """Just return the size of the key set. This isn't necessarily safe."""
         return len(self.cache.keys())

-    def get(self, key: str) -> Any | None:
+    def get(self, key: str | int) -> Any | None:
         """Gets a value from the cache."""
         if key not in self.cache:
             return None
@@ -51,13 +58,16 @@ def get(self, key: str) -> Any | None:
         self.cache[key] = value
         return value

-    def put(self, key: str, value: Any):
+    def put(self, key: str | int, value: Any):
         """Put a value into the cache."""
         if key in self.cache:
             # If the key exists, move it to the end (most recent)
             self.cache.pop(key)
         elif len(self.cache) >= self.capacity:
             # If the cache is full, remove the least recently used item
-            self.cache.popitem(last=False)
+            _evicted_key, evicted_value = self.cache.popitem(last=False)
+            # Call eviction callback if provided (e.g., to free GPU memory)
+            if self.on_evict is not None:
+                self.on_evict(evicted_value)
         # Add the new key-value pair to the end (most recent)
         self.cache[key] = value

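A minimal, self-contained sketch of how the new on_evict hook fires when SimpleLRUCache overflows; the free_gpu_value callback and the string payloads are illustrative stand-ins, not part of this commit:

from typing import Any

from mellea.backends.cache import SimpleLRUCache


def free_gpu_value(value: Any) -> None:
    # In the backend this role is played by _cleanup_kv_cache; here we just observe the eviction.
    print(f"evicted: {value!r}")


cache = SimpleLRUCache(capacity=2, on_evict=free_gpu_value)
cache.put("a", "kv-a")
cache.put("b", "kv-b")
cache.put("c", "kv-c")  # over capacity: "kv-a" is evicted and handed to free_gpu_value
assert cache.get("a") is None
assert cache.get("b") == "kv-b"
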
mellea/backends/huggingface.py

Lines changed: 92 additions & 13 deletions
@@ -82,10 +82,55 @@
 class HFAloraCacheInfo:
     """A dataclass for holding some KV cache and associated information."""

-    kv_cache: DynamicCache
+    kv_cache: DynamicCache | None
     merged_token_ids: Any
     merged_attention: Any
     q_end: int = -1
+    scores: Any = None
+
+
+def _cleanup_kv_cache(cache_info: HFAloraCacheInfo) -> None:
+    """Free GPU memory when KV cache is evicted from LRU.
+
+    This function is called by SimpleLRUCache when an entry is evicted.
+    It explicitly deletes tensor references and calls torch.cuda.empty_cache()
+    to return pooled CUDA memory to the device.
+
+    Args:
+        cache_info: The HFAloraCacheInfo being evicted from cache.
+    """
+    import gc
+
+    if cache_info is None:
+        return
+
+    kv = cache_info.kv_cache
+    if kv is not None:
+        # Delete individual tensors from each layer
+        if hasattr(kv, "key_cache"):
+            for tensor in kv.key_cache:
+                del tensor
+            kv.key_cache.clear()
+        if hasattr(kv, "value_cache"):
+            for tensor in kv.value_cache:
+                del tensor
+            kv.value_cache.clear()
+        del cache_info.kv_cache
+
+    # Delete other tensors
+    if cache_info.merged_attention is not None:
+        del cache_info.merged_attention
+
+    # Delete score tensors if present
+    if cache_info.scores is not None:
+        for tensor in cache_info.scores:
+            del tensor
+        del cache_info.scores
+
+    # Force Python garbage collection and return CUDA memory to device
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()


 # modified from VLLM v0.9.2 code base
@@ -244,7 +289,11 @@ def __init__(
         ), "vocab size mismatch between llguidance and huggingface tokenizers ... wtf?"

         self._use_caches = use_caches
-        self._cache = cache if cache is not None else SimpleLRUCache(3)
+        self._cache = (
+            cache
+            if cache is not None
+            else SimpleLRUCache(0, on_evict=_cleanup_kv_cache)
+        )

         # Adapters can be made known to the backend (added) and loaded.
         self._added_adapters: dict[str, LocalHFAdapter] = {}
@@ -877,7 +926,7 @@ async def _generate_from_context_standard(
             # Passed as args/kwargs to generate.
             input_ids,
             return_dict_in_generate=True,
-            output_scores=True,
+            use_cache=self._use_caches,  # Only create KV cache if caching is enabled
             **self._make_backend_specific_and_remove(generate_options),
             **streaming_kwargs,  # type: ignore
             **format_kwargs,  # type: ignore
@@ -941,7 +990,7 @@ async def processing(
             # and already decoded.
             if isinstance(chunk, str):
                 mot._underlying_value += chunk
-            else:
+            elif isinstance(chunk, GenerateDecoderOnlyOutput):
                 # Otherwise, it's a non-streaming request. Decode it here.
                 mot._meta["hf_output"] = chunk
                 mot._underlying_value += self._tokenizer.decode(
@@ -968,19 +1017,31 @@ async def post_processing(
         # The ModelOutputThunk must be computed by this point.
         assert mot.value is not None

-        # Add an entry to the cache for ALora reuse.
-        if self._use_caches and mot._meta.get("hf_output", None) is not None:
-            output_complete = mot._meta["hf_output"].sequences[0]
-            cache: DynamicCache = mot._meta["hf_output"].past_key_values  # type: ignore
+        # Store KV cache in LRU separately (not in mot._meta) to enable proper cleanup on eviction.
+        # This prevents GPU memory from being held by ModelOutputThunk references.
+        hf_output = mot._meta.get("hf_output", None)
+        if (
+            self._use_caches
+            and isinstance(hf_output, GenerateDecoderOnlyOutput)
+            and (hf_output.past_key_values is not None or hf_output.scores is not None)
+        ):
+            output_complete = hf_output.sequences[0]
+            kv_cache: DynamicCache | None = hf_output.past_key_values  # type: ignore

             cache_info = HFAloraCacheInfo(
-                kv_cache=cache,
+                kv_cache=kv_cache,
                 merged_token_ids=output_complete,
                 merged_attention=torch.ones_like(output_complete).to(self._device),
                 q_end=len(input_ids[0]),  # type: ignore
+                scores=hf_output.scores,
             )

-            self.cache_put(mot.value, cache_info)
+            cache_key = id(mot.value)
+            self.cache_put(cache_key, cache_info)
+
+            # Clear KV cache and scores from HF output - they're now owned by the LRU cache
+            hf_output.past_key_values = None
+            hf_output.scores = None

         # Only scan for tools if we are not doing structured output and tool calls were provided to the model.
         if _format is None and tool_calls:
@@ -1002,14 +1063,32 @@ async def post_processing(
         # HuggingFace local models don't typically provide token counts
         # but we can record response metadata if available
         hf_output = mot._meta.get("hf_output")
-        if hf_output is not None:
+        if isinstance(hf_output, GenerateDecoderOnlyOutput):
             record_response_metadata(span, hf_output)

         # Close the span now that async operation is complete
         end_backend_span(span)
         # Clean up span reference
         del mot._meta["_telemetry_span"]

+        # When caching is disabled, clear hf_output from meta to free GPU memory.
+        # The sequences tensor is on GPU and accumulates if not cleared.
+        if not self._use_caches and isinstance(
+            mot._meta.get("hf_output"), GenerateDecoderOnlyOutput
+        ):
+            import gc
+
+            hf_out = mot._meta["hf_output"]
+            if hasattr(hf_out, "sequences") and hf_out.sequences is not None:
+                del hf_out.sequences
+            if hasattr(hf_out, "scores") and hf_out.scores is not None:
+                del hf_out.scores
+            del mot._meta["hf_output"]
+
+            # Force Python GC and return CUDA memory to device
+            gc.collect()
+            torch.cuda.empty_cache()
+
         # Generate the log for this ModelOutputThunk.
         generate_log = GenerateLog()
         generate_log.prompt = conversation
@@ -1159,13 +1238,13 @@ async def generate_from_raw(
         return results

     # region cache management
-    def cache_get(self, id: str) -> HFAloraCacheInfo | None:
+    def cache_get(self, id: str | int) -> HFAloraCacheInfo | None:
         """Retrieve from cache."""
         v = self._cache.get(id)
         assert v is None or type(v) is HFAloraCacheInfo
         return v

-    def cache_put(self, id: str, v: HFAloraCacheInfo):
+    def cache_put(self, id: str | int, v: HFAloraCacheInfo):
         """Put into cache."""
         self._cache.put(id, v)

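Two notes on the backend changes above. The LRU key is now id(mot.value) rather than the output string itself, which is why Cache.put/Cache.get accept str | int; a lookup therefore requires the same object, not merely an equal string. A small sketch (the dict payload stands in for HFAloraCacheInfo):

from mellea.backends.cache import SimpleLRUCache

cache = SimpleLRUCache(capacity=2)

output_text = "generated response"  # stands in for mot.value
cache.put(id(output_text), {"q_end": 7})

assert cache.get(id(output_text)) is not None  # same object -> hit
lookalike = "".join(["generated", " ", "response"])
assert cache.get(id(lookalike)) is None  # equal string, different object -> miss

Also note the default cache is now SimpleLRUCache(0, on_evict=_cleanup_kv_cache); per the commit message this is a stop-gap until block_attention and a dynamic LRU size are worked out, and callers can still supply their own cache instance through the backend's cache argument.
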
test/backends/test_huggingface.py

Lines changed: 5 additions & 1 deletion
@@ -510,7 +510,11 @@ async def test_error_during_generate_with_lock(backend) -> None:
     b: LocalHFBackend = copy(backend)
     model = copy(b._model)
     b._model = model
-    b._model.set_adapter([])
+    try:
+        b._model.set_adapter([])
+    except ValueError as e:
+        if "No adapter loaded" not in str(e):
+            raise
     b._added_adapters = {}
     b._loaded_adapters = {}
     b.add_adapter(

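The try/except above tolerates PEFT/transformers versions that raise when clearing adapters on a model with none loaded. The same guard, extracted as a standalone helper for clarity; reset_adapters is illustrative and not part of the test suite:

def reset_adapters(model) -> None:
    """Deactivate all adapters, treating the 'No adapter loaded' case as a no-op."""
    try:
        model.set_adapter([])
    except ValueError as e:
        # Re-raise anything other than the benign "nothing to deactivate" error.
        if "No adapter loaded" not in str(e):
            raise
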
test/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -244,7 +244,7 @@ def _run_heavy_modules_isolated(session, heavy_modules: list[str]) -> int:
     print("-" * 70)

     # Build pytest command with same options as parent session
-    cmd = [sys.executable, "-m", "pytest", module_path, "-v"]
+    cmd = [sys.executable, "-m", "pytest", module_path, "-v", "--no-cov"]

     # Add markers from original command if present
     config = session.config
