-
Notifications
You must be signed in to change notification settings - Fork 35
Expand file tree
/
Copy pathtest_model_to_hf.py
More file actions
40 lines (32 loc) · 865 Bytes
/
test_model_to_hf.py
File metadata and controls
40 lines (32 loc) · 865 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
from itertools import chain
import torch
from datasets import load_dataset
from transformers import (
AutoTokenizer,
GemmaConfig,
GemmaForCausalLM,
Trainer,
TrainingArguments,
set_seed,
default_data_collator,
)
# Export google/gemma-2b (weights, config, tokenizer) to a local directory,
# stamping a custom `segment_size` field into the saved config.

# Single source of truth for the hub checkpoint and the local export path —
# previously each was repeated three times below.
MODEL_ID = "google/gemma-2b"
SAVE_DIR = "./models/gemma-2b"

# Fix all RNG seeds (python/numpy/torch) for reproducibility.
set_seed(42)
print("Torch Version:", torch.__version__)

config = GemmaConfig.from_pretrained(
    MODEL_ID,
    attn_implementation="eager",
)
# config.max_position_embeddings = 128
# config.use_cache = False
# NOTE(review): `segment_size` is not a standard GemmaConfig field —
# presumably consumed by a project-local model variant that reloads this
# saved config; confirm against the training code. One segment spans the
# full context window.
config.segment_size = config.max_position_embeddings  # Add config
print(config)

# torch_dtype="auto" keeps the dtype the checkpoint was stored in.
pretrained_model = GemmaForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Save model, (augmented) config, and tokenizer side by side so SAVE_DIR
# can later be loaded directly with from_pretrained().
pretrained_model.save_pretrained(SAVE_DIR)
config.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)