We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent ac6a4cf commit 8a385d5Copy full SHA for 8a385d5
1 file changed
mellea/backends/huggingface.py
@@ -270,8 +270,8 @@ def __init__(
270
)
271
# Get the model and tokenizer.
272
self._model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
273
- self._hf_model_id
274
- ).to(self._device) # type: ignore
+ self._hf_model_id, device_map=str(self._device)
+ )
275
self._tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
276
self._hf_model_id
277
0 commit comments