Skip to content

Commit ff00e89

Browse files
authored
fix: issues found in comprehensive tests: cache capacity, watsonx
model name, and alora training on macs (#560)
1 parent 0cf5d37 commit ff00e89

4 files changed

Lines changed: 24 additions & 4 deletions

File tree

cli/alora/train.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,13 @@ def train_model(
133133
model_base = AutoModelForCausalLM.from_pretrained(
134134
base_model, device_map=device_map, use_cache=False
135135
)
136+
137+
# `fp16=True` enables CUDA-specific mixed precision via GradScaler, which doesn't function properly on cpu or mps.
138+
# Check all the model's parameters to ensure it's okay to use.
139+
use_fp16 = all(
140+
param.device.type != "cpu" and param.device.type != "mps"
141+
for param in model_base.parameters()
142+
)
136143
except NotImplementedError as e:
137144
if "meta tensor" in str(e):
138145
raise RuntimeError(
@@ -176,7 +183,7 @@ def train_model(
176183
max_seq_length=max_length,
177184
per_device_train_batch_size=batch_size,
178185
gradient_accumulation_steps=grad_accum,
179-
fp16=True,
186+
fp16=use_fp16,
180187
)
181188

182189
trainer = SafeSaveTrainer(
@@ -210,7 +217,7 @@ def train_model(
210217
max_seq_length=max_length,
211218
per_device_train_batch_size=batch_size,
212219
gradient_accumulation_steps=grad_accum,
213-
fp16=True,
220+
fp16=use_fp16,
214221
)
215222

216223
trainer = SafeSaveTrainer(

mellea/backends/cache.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ def get(self, key: str | int) -> Any | None:
6060

6161
def put(self, key: str | int, value: Any):
6262
"""Put a value into the cache."""
63+
if self.capacity == 0:
64+
return
65+
6366
if key in self.cache:
6467
# If the key exists, move it to the end (most recent)
6568
self.cache.pop(key)

mellea/backends/model_ids.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class ModelIdentifier:
7979
IBM_GRANITE_4_MICRO_3B = ModelIdentifier(
8080
hf_model_name="ibm-granite/granite-4.0-micro",
8181
ollama_name="granite4:micro",
82-
watsonx_name="ibm/granite-4-small",
82+
watsonx_name="ibm/granite-4-h-small", # Keeping hybrid version here for backwards compatibility.
8383
)
8484

8585
# Granite 3.3 Vision Model (2B)

test/cli/test_alora_train.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ def test_alora_config_creation():
2828
mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer
2929

3030
mock_model = Mock()
31+
mock_param = Mock()
32+
mock_param.device.type = "cuda"
33+
mock_model.parameters.return_value = [mock_param]
3134
mock_model_class.from_pretrained.return_value = mock_model
3235

3336
mock_peft_model = Mock()
@@ -102,6 +105,9 @@ def test_lora_config_creation():
102105
mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer
103106

104107
mock_model = Mock()
108+
mock_param = Mock()
109+
mock_param.device.type = "cuda"
110+
mock_model.parameters.return_value = [mock_param]
105111
mock_model_class.from_pretrained.return_value = mock_model
106112

107113
mock_peft_model = Mock()
@@ -175,7 +181,11 @@ def test_invocation_prompt_tokenization():
175181
mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer
176182

177183
# Setup other mocks
178-
mock_model_class.from_pretrained.return_value = Mock()
184+
mock_model = Mock()
185+
mock_param = Mock()
186+
mock_param.device.type = "cuda"
187+
mock_model.parameters.return_value = [mock_param]
188+
mock_model_class.from_pretrained.return_value = mock_model
179189
mock_get_peft_model.return_value = Mock()
180190

181191
mock_ds = MagicMock()

0 commit comments

Comments (0)