
Commit 839eead

fix: flush MPS cache in alora test GPU cleanup (#790) (#800)
* fix: flush MPS cache in alora test GPU cleanup (#790)

  Add torch.mps.empty_cache() after the CUDA cleanup blocks in both alora
  integration tests, matching the existing conftest pattern. Prevents MPS
  memory from accumulating between tests on Apple Silicon.

* refactor(test): extract flush_device_caches() helper for GPU cleanup

  Consolidate the duplicated gc.collect + CUDA/MPS cache-flush pattern into
  a single flush_device_caches() function in test/conftest.py.

  - Replaces 4 inline flush sites with a single call
  - Adds MPS support to sites that previously handled only CUDA
    (pytest_runtest_setup backend transitions, memory_cleaner fixture)
  - Fixes a bug where gc.collect() was conditional on CUDA availability in
    pytest_runtest_setup (it now runs unconditionally)
  - Adds torch.mps.synchronize() for parity with torch.cuda.synchronize()
  - Enriches cleanup_gpu_backend() VRAM logging with device-aware reporting
    for both CUDA (free/total/allocated/reserved/fragmentation) and MPS
    (allocated/max), including reclaimed bytes on both paths
  - Removes unused shutil/sys imports from test_alora_train_integration
1 parent e0ffd3d commit 839eead
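
In practice the helper collapses each inline cleanup block to a single call. A minimal sketch of the call pattern (hypothetical test names, not code from this commit):

    from test.conftest import flush_device_caches

    def test_something_gpu_heavy():      # hypothetical test, for illustration
        model = build_model()            # hypothetical setup helper
        try:
            run_training_step(model)     # hypothetical workload
        finally:
            del model
            flush_device_caches()        # gc.collect x2 + CUDA/MPS cache flush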

2 files changed: 70 additions & 58 deletions

test/cli/test_alora_train_integration.py
Lines changed: 5 additions & 17 deletions
@@ -5,13 +5,13 @@

 import json
 import os
-import shutil
-import sys
 import tempfile
 from pathlib import Path

 import pytest

+from test.conftest import flush_device_caches
+
 torch = pytest.importorskip("torch", reason="torch not installed — install mellea[hf]")
 from transformers import AutoTokenizer

@@ -292,8 +292,6 @@ def test_alora_training_integration():
     )

     # Cleanup GPU memory
-    import gc
-
     # 1. Remove accelerate dispatch hooks before moving to CPU.
     # device_map="auto" installs hooks that prevent full VRAM release otherwise.
     try:
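
The try block above leads into the hook-removal call itself, which sits in unchanged context outside this hunk. A typical pattern for stripping accelerate dispatch hooks (an assumption about the surrounding code, not lines from this diff):

    from accelerate.hooks import remove_hook_from_module

    # Recursively detach the hooks installed by device_map="auto" so the
    # model can be moved to CPU and its VRAM actually released.
    remove_hook_from_module(base_model, recurse=True)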
@@ -310,12 +308,8 @@ def test_alora_training_integration():
     base_model.cpu()
     del base_model

-    # 4. Force GC and flush CUDA cache synchronously.
-    gc.collect()
-    gc.collect()
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
+    # 4. Flush device caches.
+    flush_device_caches()


 def test_lora_training_integration():
@@ -391,10 +385,4 @@ def test_lora_training_integration():
     )

     # Cleanup GPU memory after training
-    import gc
-
-    gc.collect()
-    gc.collect()
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
+    flush_device_caches()

test/conftest.py
Lines changed: 65 additions & 41 deletions
@@ -260,6 +260,33 @@ def pytest_configure(config):
 # ============================================================================


+# ============================================================================
+# Device Cache Flush Helper
+# ============================================================================
+
+
+def flush_device_caches() -> None:
+    """Force garbage collection and flush GPU device caches (CUDA and MPS).
+
+    Safe to call unconditionally — skips gracefully when torch is absent
+    or no accelerator is available.
+    """
+    gc.collect()
+    gc.collect()
+
+    try:
+        import torch
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+        if torch.backends.mps.is_available():
+            torch.mps.synchronize()
+            torch.mps.empty_cache()
+    except ImportError:
+        pass
+
+
 # ============================================================================
 # vLLM Backend Cleanup Helper
 # ============================================================================
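
The helper assumes a torch build recent enough to expose torch.mps.synchronize() and torch.mps.empty_cache(). On older releases those attributes may be missing; a defensive variant (a sketch, not what this commit ships) would feature-test them:

    if getattr(torch, "mps", None) is not None and torch.backends.mps.is_available():
        if hasattr(torch.mps, "synchronize"):
            torch.mps.synchronize()  # wait for queued MPS work before flushing
        if hasattr(torch.mps, "empty_cache"):
            torch.mps.empty_cache()  # release cached MPS allocations to the OS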
@@ -275,22 +302,34 @@ def cleanup_gpu_backend(backend, backend_name="unknown"):
         backend: The backend instance to clean up.
         backend_name: Name for logging.
     """
-    import gc
-
     logger = FancyLogger.get_logger()
     logger.info(f"Cleaning up {backend_name} backend GPU memory...")

     try:
         import torch

+        # Snapshot memory before cleanup for reporting
+        free_before = 0
+        allocated_before = 0
         if torch.cuda.is_available():
-            free_before, total = torch.cuda.mem_get_info()
+            free_before, total_mem = torch.cuda.mem_get_info()
+            reserved = torch.cuda.memory_reserved()
+            allocated = torch.cuda.memory_allocated()
             logger.info(
-                f" GPU before cleanup: {free_before / 1024**3:.1f}GB free "
-                f"/ {total / 1024**3:.1f}GB total"
+                f" CUDA before cleanup: {free_before / 1024**3:.1f}GB free "
+                f"/ {total_mem / 1024**3:.1f}GB total "
+                f"(allocated {allocated / 1024**2:.0f}MB, "
+                f"reserved {reserved / 1024**2:.0f}MB, "
+                f"fragmentation {(reserved - allocated) / 1024**2:.0f}MB)"
+            )
+        elif torch.backends.mps.is_available():
+            allocated_before = torch.mps.current_allocated_memory()
+            max_mem = torch.mps.recommended_max_memory()
+            logger.info(
+                f" MPS before cleanup: "
+                f"allocated {allocated_before / 1024**2:.0f}MB "
+                f"/ {max_mem / 1024**3:.1f}GB max"
             )
-        else:
-            free_before = 0

         # 1. Clear the LRU cache (holds DynamicCache KV tensors on GPU)
         if hasattr(backend, "_cache") and hasattr(backend._cache, "cache"):
@@ -357,21 +396,27 @@ def cleanup_gpu_backend(backend, backend_name="unknown"):
             del backend._tokenizer

         # 7. Force garbage collection and flush device caches
-        gc.collect()
-        gc.collect()
+        flush_device_caches()

+        # Report memory after cleanup
         if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
-
-            free_after, total = torch.cuda.mem_get_info()
+            free_after, total_mem = torch.cuda.mem_get_info()
+            reserved = torch.cuda.memory_reserved()
+            allocated = torch.cuda.memory_allocated()
             logger.info(
-                f" GPU after cleanup: {free_after / 1024**3:.1f}GB free "
-                f"/ {total / 1024**3:.1f}GB total "
-                f"(reclaimed {(free_after - free_before) / 1024**3:.1f}GB)"
+                f" CUDA after cleanup: {free_after / 1024**3:.1f}GB free "
+                f"/ {total_mem / 1024**3:.1f}GB total "
+                f"(allocated {allocated / 1024**2:.0f}MB, "
+                f"reserved {reserved / 1024**2:.0f}MB, "
+                f"reclaimed {(free_after - free_before) / 1024**3:.1f}GB)"
+            )
+        elif torch.backends.mps.is_available():
+            allocated_after = torch.mps.current_allocated_memory()
+            logger.info(
+                f" MPS after cleanup: "
+                f"allocated {allocated_after / 1024**2:.0f}MB "
+                f"(reclaimed {(allocated_before - allocated_after) / 1024**2:.0f}MB)"
             )
-        if torch.backends.mps.is_available():
-            torch.mps.empty_cache()

     except ImportError:
         pass
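
The "fragmentation" figure in the new log lines is simply reserved minus allocated: memory the CUDA caching allocator holds but no live tensor is using, which empty_cache() can hand back to the driver. With illustrative numbers (not measured output):

    reserved = 6.0 * 1024**3    # 6.0 GB held by the caching allocator
    allocated = 4.5 * 1024**3   # 4.5 GB in live tensors
    fragmentation_mb = (reserved - allocated) / 1024**2
    # -> 1536 MB reclaimable via torch.cuda.empty_cache()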
@@ -478,17 +523,7 @@ def pytest_runtest_setup(item):
             "Running GPU cleanup."
         )

-        # General GPU flush for any transition
-        try:
-            import torch
-
-            if torch.cuda.is_available():
-                gc.collect()
-                gc.collect()
-                torch.cuda.empty_cache()
-                torch.cuda.synchronize()
-        except ImportError:
-            pass
+        flush_device_caches()

         # Warm up Ollama models when entering Ollama group
         if current_group == "ollama" and prev_group != "ollama":
@@ -566,18 +601,7 @@ def pytest_runtest_teardown(item, nextitem):
 def memory_cleaner():
     """Lightweight memory cleanup — safety net for per-test GPU leaks."""
     yield
-
-    gc.collect()
-    gc.collect()
-
-    try:
-        import torch
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            torch.cuda.synchronize()
-    except ImportError:
-        pass
+    flush_device_caches()


 def evict_ollama_models() -> None:
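
The hunk shows only the fixture body; its decorator sits in unchanged context above. Based on the "safety net" docstring, the usual wiring looks like this (an assumption, not lines from this diff):

    import pytest

    @pytest.fixture(autouse=True)
    def memory_cleaner():
        yield                    # run the test first
        flush_device_caches()    # then sweep GPU caches regardless of outcome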
