Commit 243a161

Adding some extra VRAM cleanup to make end-to-end tests smoother (#765)
Parent: a591d5c

3 files changed: 84 additions & 10 deletions


test/cli/test_alora_train_integration.py

Lines changed: 27 additions & 2 deletions
```diff
@@ -288,14 +288,30 @@ def test_alora_training_integration():
     )
 
     # Cleanup GPU memory
-    base_model.cpu()
+    import gc
+
+    # 1. Remove accelerate dispatch hooks before moving to CPU.
+    # device_map="auto" installs hooks that prevent full VRAM release otherwise.
+    try:
+        from accelerate.hooks import remove_hook_from_module
+
+        remove_hook_from_module(base_model, recurse=True)
+    except (ImportError, Exception):
+        pass
+
+    # 2. Delete the PeftModel wrapper first — it holds internal refs to base_model.
     del model_with_adapter
+
+    # 3. Now move base_model to CPU and delete it.
+    base_model.cpu()
     del base_model
-    import gc
 
+    # 4. Force GC and flush CUDA cache synchronously.
+    gc.collect()
     gc.collect()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
+        torch.cuda.synchronize()
 
 
 def test_lora_training_integration():
@@ -369,3 +385,12 @@ def test_lora_training_integration():
     print(
         f"✅ Config format verified: {config.get('peft_type')} without alora_invocation_tokens"
     )
+
+    # Cleanup GPU memory after training
+    import gc
+
+    gc.collect()
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
```
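The same teardown sequence could be shared across GPU tests instead of repeated inline; below is a minimal sketch of that idea as a pytest fixture. The fixture name `gpu_models` and the registry pattern are illustrative assumptions, not part of this commit.

```python
import gc

import pytest
import torch


@pytest.fixture
def gpu_models():
    """Yield a registry list; release VRAM for everything in it on teardown."""
    registry = []
    yield registry  # tests append models here, wrappers before base models

    # Strip accelerate dispatch hooks (installed by device_map="auto");
    # they otherwise keep submodules pinned to their assigned GPUs.
    try:
        from accelerate.hooks import remove_hook_from_module

        for m in registry:
            remove_hook_from_module(m, recurse=True)
    except ImportError:
        pass

    # Move weights off the GPU, then drop the registry's references.
    for m in registry:
        m.cpu()
    registry.clear()

    # Two GC passes catch reference cycles; then flush and sync CUDA.
    gc.collect()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
```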

test/conftest.py

Lines changed: 41 additions & 0 deletions
```diff
@@ -767,6 +767,47 @@ def pytest_runtest_setup(item):
     except ImportError:
         pass
 
+    # Warm up Ollama models when entering Ollama group
+    if current_group == "ollama" and prev_group != "ollama":
+        logger = FancyLogger.get_logger()
+        host_str = os.environ.get("OLLAMA_HOST", "127.0.0.1")
+        port = os.environ.get("OLLAMA_PORT", "11434")
+        logger.info(
+            "Warming up ollama models before ollama group (keep_alive=-1)..."
+        )
+        for model in ["granite4:micro", "granite4:micro-h", "granite3.2-vision"]:
+            try:
+                requests.post(
+                    f"http://{host_str}:{port}/api/generate",
+                    json={
+                        "model": model,
+                        "prompt": "hi",
+                        "stream": False,
+                        "keep_alive": -1,
+                    },
+                    timeout=120,
+                )
+                logger.info("  Warmed up and pinned: %s", model)
+            except Exception as e:
+                logger.warning("  Warmup failed for %s: %s", model, e)
+
+    # Evict Ollama models when leaving Ollama group
+    if prev_group == "ollama" and current_group != "ollama":
+        logger = FancyLogger.get_logger()
+        host_str = os.environ.get("OLLAMA_HOST", "127.0.0.1")
+        port = os.environ.get("OLLAMA_PORT", "11434")
+        logger.info("Evicting ollama models from VRAM after ollama group...")
+        for model in ["granite4:micro", "granite4:micro-h", "granite3.2-vision"]:
+            try:
+                requests.post(
+                    f"http://{host_str}:{port}/api/generate",
+                    json={"model": model, "keep_alive": 0},
+                    timeout=10,
+                )
+                logger.info("  Evicted: %s", model)
+            except Exception as e:
+                logger.warning("  Eviction failed for %s: %s", model, e)
+
     pytest_runtest_setup._last_backend_group = current_group
 
     # Check for override flags from CLI
```
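Both branches rely on Ollama's `keep_alive` field: `-1` keeps the model loaded indefinitely, while `0` unloads it as soon as the request completes, and a `/api/generate` call with no prompt simply loads or unloads the model without generating. A standalone sketch of the two calls (host and port are placeholders; the model name is taken from the test list above):

```python
import requests

OLLAMA_URL = "http://127.0.0.1:11434/api/generate"  # placeholder host/port

# Pin: keep_alive=-1 tells Ollama to keep the model loaded indefinitely.
requests.post(
    OLLAMA_URL,
    json={
        "model": "granite4:micro",
        "prompt": "hi",
        "stream": False,
        "keep_alive": -1,
    },
    timeout=120,
)

# Evict: keep_alive=0 unloads the model right after this request; no prompt
# is needed, since a prompt-less generate call only (un)loads the model.
requests.post(
    OLLAMA_URL,
    json={"model": "granite4:micro", "keep_alive": 0},
    timeout=10,
)
```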

test/scripts/run_tests_with_ollama.sh

Lines changed: 16 additions & 8 deletions
```diff
@@ -46,6 +46,10 @@ fi
 mkdir -p "$LOGDIR"
 
 cleanup() {
+  if [[ "${OLLAMA_EXTERNAL:-0}" == "1" ]]; then
+    log "Ollama managed externally (OLLAMA_EXTERNAL=1) — skipping shutdown"
+    return
+  fi
   log "Shutting down ollama server..."
   if [[ -n "${OLLAMA_PID:-}" ]] && kill -0 "$OLLAMA_PID" 2>/dev/null; then
     kill "$OLLAMA_PID" 2>/dev/null
@@ -138,14 +142,18 @@ done
 log "All models ready."
 
 # --- Warm up models (first load into memory is slow) ---
-log "Warming up models..."
-for model in "${OLLAMA_MODELS[@]}"; do
-  log "  Warming $model ..."
-  curl -sf "http://127.0.0.1:${OLLAMA_PORT}/api/generate" \
-    -d "{\"model\": \"$model\", \"prompt\": \"hi\", \"stream\": false}" \
-    -o /dev/null --max-time 120 || log "  Warning: warmup for $model timed out (will load on first test)"
-done
-log "Warmup complete."
+if [[ "${OLLAMA_SKIP_WARMUP:-0}" == "1" ]]; then
+  log "Skipping model warmup (OLLAMA_SKIP_WARMUP=1)"
+else
+  log "Warming up models..."
+  for model in "${OLLAMA_MODELS[@]}"; do
+    log "  Warming $model ..."
+    curl -sf "http://127.0.0.1:${OLLAMA_PORT}/api/generate" \
+      -d "{\"model\": \"$model\", \"prompt\": \"hi\", \"stream\": false}" \
+      -o /dev/null --max-time 120 || log "  Warning: warmup for $model timed out (will load on first test)"
+  done
+  log "Warmup complete."
+fi
 
 # --- Run tests ---
 log "Starting pytest..."
```
