File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -133,6 +133,13 @@ def train_model(
133133 model_base = AutoModelForCausalLM .from_pretrained (
134134 base_model , device_map = device_map , use_cache = False
135135 )
136+
137+ # `fp16=True` enables CUDA-specific mixed precision via GradScaler, which doesn't function properly on CPU or MPS devices.
138+ # Check all the model's parameters to ensure it's okay to use.
139+ use_fp16 = all (
140+ param .device .type != "cpu" and param .device .type != "mps"
141+ for param in model_base .parameters ()
142+ )
136143 except NotImplementedError as e :
137144 if "meta tensor" in str (e ):
138145 raise RuntimeError (
@@ -176,7 +183,7 @@ def train_model(
176183 max_seq_length = max_length ,
177184 per_device_train_batch_size = batch_size ,
178185 gradient_accumulation_steps = grad_accum ,
179- fp16 = True ,
186+ fp16 = use_fp16 ,
180187 )
181188
182189 trainer = SafeSaveTrainer (
@@ -210,7 +217,7 @@ def train_model(
210217 max_seq_length = max_length ,
211218 per_device_train_batch_size = batch_size ,
212219 gradient_accumulation_steps = grad_accum ,
213- fp16 = True ,
220+ fp16 = use_fp16 ,
214221 )
215222
216223 trainer = SafeSaveTrainer (
Original file line number Diff line number Diff line change @@ -60,6 +60,9 @@ def get(self, key: str | int) -> Any | None:
6060
6161 def put (self , key : str | int , value : Any ):
6262 """Put a value into the cache."""
63+ if self .capacity == 0 :
64+ return
65+
6366 if key in self .cache :
6467 # If the key exists, move it to the end (most recent)
6568 self .cache .pop (key )
Original file line number Diff line number Diff line change @@ -79,7 +79,7 @@ class ModelIdentifier:
7979IBM_GRANITE_4_MICRO_3B = ModelIdentifier (
8080 hf_model_name = "ibm-granite/granite-4.0-micro" ,
8181 ollama_name = "granite4:micro" ,
82- watsonx_name = "ibm/granite-4-small" ,
82+ watsonx_name = "ibm/granite-4-h-small" , # Keeping hybrid version here for backwards compatibility.
8383)
8484
8585# Granite 3.3 Vision Model (2B)
Original file line number Diff line number Diff line change @@ -28,6 +28,9 @@ def test_alora_config_creation():
2828 mock_tokenizer_class .from_pretrained .return_value = mock_tokenizer
2929
3030 mock_model = Mock ()
31+ mock_param = Mock ()
32+ mock_param .device .type = "cuda"
33+ mock_model .parameters .return_value = [mock_param ]
3134 mock_model_class .from_pretrained .return_value = mock_model
3235
3336 mock_peft_model = Mock ()
@@ -102,6 +105,9 @@ def test_lora_config_creation():
102105 mock_tokenizer_class .from_pretrained .return_value = mock_tokenizer
103106
104107 mock_model = Mock ()
108+ mock_param = Mock ()
109+ mock_param .device .type = "cuda"
110+ mock_model .parameters .return_value = [mock_param ]
105111 mock_model_class .from_pretrained .return_value = mock_model
106112
107113 mock_peft_model = Mock ()
@@ -175,7 +181,11 @@ def test_invocation_prompt_tokenization():
175181 mock_tokenizer_class .from_pretrained .return_value = mock_tokenizer
176182
177183 # Setup other mocks
178- mock_model_class .from_pretrained .return_value = Mock ()
184+ mock_model = Mock ()
185+ mock_param = Mock ()
186+ mock_param .device .type = "cuda"
187+ mock_model .parameters .return_value = [mock_param ]
188+ mock_model_class .from_pretrained .return_value = mock_model
179189 mock_get_peft_model .return_value = Mock ()
180190
181191 mock_ds = MagicMock ()
You can’t perform that action at this time.
0 commit comments