Skip to content

Commit dc94364

Browse files
authored
fix: non-hybrid granite model id (#546)
* fix: granite for non-hybrid model id
* feat: granite4 nano model family (1b, 350m)
* fix: skip alora training test on CICD
1 parent 8316495 commit dc94364

2 files changed

Lines changed: 27 additions & 7 deletions

File tree

mellea/backends/model_ids.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,18 @@ class ModelIdentifier:
4545
watsonx_name="ibm/granite-4-h-small",
4646
)
4747

48+
IBM_GRANITE_4_HYBRID_1B = ModelIdentifier(
49+
hf_model_name="ibm-granite/granite-4.0-h-1b",
50+
ollama_name="granite4:1b-h",
51+
watsonx_name=None,
52+
)
53+
54+
IBM_GRANITE_4_HYBRID_350m = ModelIdentifier(
55+
hf_model_name="ibm-granite/granite-4.0-h-350m",
56+
ollama_name="granite4:350m-h",
57+
watsonx_name=None,
58+
)
59+
4860

4961
# Deprecated Granite 3 models - kept for backward compatibility
5062
# These maintain their original model references (not upgraded to Granite 4)
@@ -65,9 +77,9 @@ class ModelIdentifier:
6577
# - Ollama/HF: Uses MICRO (fits in CI memory constraints)
6678
# - Watsonx: Uses SMALL (required for watsonx support)
6779
IBM_GRANITE_4_MICRO_3B = ModelIdentifier(
68-
hf_model_name="ibm-granite/granite-4.0-h-micro",
69-
ollama_name="granite4:micro-h",
70-
watsonx_name="ibm/granite-4-h-small",
80+
hf_model_name="ibm-granite/granite-4.0-micro",
81+
ollama_name="granite4:micro",
82+
watsonx_name="ibm/granite-4-small",
7183
)
7284

7385
# Granite 3.3 Vision Model (2B)

test/cli/test_alora_train_integration.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,24 @@
1414
import torch
1515
from transformers import AutoTokenizer
1616

17+
pytestmark = [
18+
pytest.mark.huggingface,
19+
pytest.mark.llm,
20+
pytest.mark.requires_gpu,
21+
pytest.mark.requires_heavy_ram,
22+
# Skip entire module in CI since 17/18 tests are qualitative
23+
pytest.mark.skipif(
24+
int(os.environ.get("CICD", 0)) == 1,
25+
reason="Skipping alora training tests in CI - need gpus",
26+
),
27+
]
28+
1729
# Check if MPS is available but PyTorch version is too old
1830
_mps_needs_cpu_fallback = torch.backends.mps.is_available() and tuple(
1931
int(x) for x in torch.__version__.split(".")[:2]
2032
) < (2, 8)
2133

2234

23-
@pytest.mark.huggingface
24-
@pytest.mark.llm
2535
def test_alora_training_integration():
2636
"""Integration test: Train a tiny aLoRA adapter and verify it works.
2737
@@ -278,8 +288,6 @@ def test_alora_training_integration():
278288
)
279289

280290

281-
@pytest.mark.huggingface
282-
@pytest.mark.llm
283291
def test_lora_training_integration():
284292
"""Integration test: Train a tiny standard LoRA adapter and verify it works.
285293

0 commit comments

Comments (0)