Skip to content

Commit 8a385d5

Browse files
authored
fix: use device_map for HF model loading (#581) (#587)
1 parent ac6a4cf commit 8a385d5

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

mellea/backends/huggingface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,8 @@ def __init__(
270270
)
271271
# Get the model and tokenizer.
272272
self._model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
273-
self._hf_model_id
274-
).to(self._device) # type: ignore
273+
self._hf_model_id, device_map=str(self._device)
274+
)
275275
self._tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
276276
self._hf_model_id
277277
)

0 commit comments

Comments
 (0)