@@ -61,40 +61,41 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
6161# TEXT: some background on LLM benchmarking
6262# Load benchmark dataset and evaluate model
6363dataset = pd .read_csv ("benchmark.csv" )
64- category_accs_1300m , avg_acc_1300m = run_benchmark (model , tokenizer , dataset )
64+ # category_accs_1300m, avg_acc_1300m = run_benchmark(model, tokenizer, dataset)
6565
6666# TEXT: ask them to make a prediction on how accuracy will be affected by different model sizes
6767
6868# Benchmark smaller model
69- model_name_350m = "facebook/opt-350m"
70- model_350m = transformers .AutoModelForCausalLM .from_pretrained (model_name_350m , device_map = "auto" )
71- tokenizer_350m = transformers .AutoTokenizer .from_pretrained (model_name_350m )
69+ # model_name_350m = "facebook/opt-350m"
70+ # model_350m = transformers.AutoModelForCausalLM.from_pretrained(model_name_350m, device_map="auto")
71+ # tokenizer_350m = transformers.AutoTokenizer.from_pretrained(model_name_350m)
7272
73- category_accs_350m , avg_acc_350m = run_benchmark (model_350m , tokenizer_350m , dataset )
73+ # category_accs_350m, avg_acc_350m = run_benchmark(model_350m, tokenizer_350m, dataset)
7474
7575# Benchmark larger model
76- model_name_2700m = "facebook/opt-2.7b"
77- model_2700m = transformers .AutoModelForCausalLM .from_pretrained (model_name_2700m , device_map = "auto" )
78- tokenizer_2700m = transformers .AutoTokenizer .from_pretrained (model_name_2700m )
76+ # model_name_2700m = "facebook/opt-2.7b"
77+ # model_2700m = transformers.AutoModelForCausalLM.from_pretrained(model_name_2700m, device_map="auto")
78+ # tokenizer_2700m = transformers.AutoTokenizer.from_pretrained(model_name_2700m)
7979
80- category_accs_2700m , avg_acc_2700m = run_benchmark (model_2700m , tokenizer_2700m , dataset )
80+ # category_accs_2700m, avg_acc_2700m = run_benchmark(model_2700m, tokenizer_2700m, dataset)
8181
8282# Spider plot
8383
84- benchmark_data = {"350M-Model" : category_accs_350m , "1300M-Model" : category_accs_1300m , "2700M-Model" : category_accs_2700m }
85- make_spider_plot (benchmark_data )
84+ # benchmark_data = {"350M-Model": category_accs_350m, "1300M-Model": category_accs_1300m, "2700M-Model": category_accs_2700m}
85+ # make_spider_plot(benchmark_data)
8686
8787# Part 2
8888
8989# inspect current model
90- print (model )
90+ # print(model)
9191
9292# new LoRA linear layer class
9393class LoRALinear (nn .Linear ):
9494 def __init__ (
9595 self ,
9696 in_features : int ,
9797 out_features : int ,
98+ pretrained_weight : torch .Tensor ,
9899 r : int = 8 ,
99100 lora_alpha : int = 1 ,
100101 ** kwargs
@@ -105,6 +106,7 @@ def __init__(
105106 self .lora_alpha = lora_alpha
106107
107108 nn .Linear .__init__ (self , in_features , out_features , ** kwargs )
109+ self .weight .data = pretrained_weight
108110
109111 # from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
110112 if r > 0 :
@@ -113,16 +115,6 @@ def __init__(
113115 self .scaling = self .lora_alpha / self .r
114116 # Freezing the pre-trained weight matrix
115117 self .weight .requires_grad = False
116- self .reset_parameters ()
117-
118- def reset_parameters (self ):
119- # from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
120- nn .Linear .reset_parameters (self )
121- if hasattr (self , 'lora_A' ):
122- # initialize B the same way as the default for nn.Linear and A to zero
123- # this is different than what is described in the paper but should not affect performance
124- nn .init .kaiming_uniform_ (self .lora_A , a = math .sqrt (5 ))
125- nn .init .zeros_ (self .lora_B )
126118
127119 def forward (self , x : torch .Tensor ):
128120 if self .r > 0 :
@@ -136,48 +128,36 @@ def forward(self, x: torch.Tensor):
136128def replace_linear_with_lora (module ):
137129 for name , child in module .named_children ():
138130 if isinstance (child , nn .Linear ):
139- setattr (module , name , LoRALinear (child .in_features , child .out_features ))
131+ setattr (module , name , LoRALinear (child .in_features , child .out_features , child .weight ))
132+ break
140133 else :
141134 replace_linear_with_lora (child )
142135
143136replace_linear_with_lora (model )
144137
145138# inspect new model
146- print (model )
139+ # print(model)
147140
148141# load chat dataset
149142dataset_name = "timdettmers/openassistant-guanaco"
150143ft_dataset = load_dataset (dataset_name , split = "train" )
151144
152- # train model
153- log_dir = "/scratch/checkpoints/test-sft/opt1.3b_768/"
145+ # train model (barebones loop)
154146batch_size = 4
155147context_length = 768
156- args = transformers .TrainingArguments (log_dir ,
157- per_device_train_batch_size = batch_size ,
158- logging_first_step = True ,
159- logging_steps = 20 ,
160- save_steps = 100 ,
161- )
162-
163- class PrinterCallback (transformers .TrainerCallback ):
164- def on_log (self , args , state , control , model , logs = None , ** kwargs ):
165- start_text = "### Human: When the weather is sunny, what color is the sky?### Assistant:"
166- generate (start_text , model , tokenizer , num_steps = 200 , until = "###" )
167-
168- trainer = SFTTrainer (
169- model ,
170- args = args ,
171- train_dataset = ft_dataset ,
172- dataset_text_field = "text" ,
173- max_seq_length = context_length ,
174- callbacks = [PrinterCallback ()]
175- )
176- trainer .train ()
148+
149+ model = model .to ("cuda" )
150+ for batch in ft_dataset :
151+ prompt = batch ["text" ]
152+ encoding = tokenizer (prompt )
153+ input_ids = torch .IntTensor (encoding ["input_ids" ]).to ("cuda" ).unsqueeze (0 )
154+ attention_mask = torch .Tensor (encoding ["attention_mask" ]).to ("cuda" ).unsqueeze (0 )
155+ outputs = model (input_ids , attention_mask )
156+
177157
178158# evaluate finetuned model on benchmark
179159category_accs_1300m_ft , avg_acc_1300m_ft = run_benchmark (model , tokenizer , dataset )
180160
181161# add to spider plot
182- benchmark_data = {"350M-Model" : category_accs_350m , "1300M-Model" : category_accs_1300m , "1300M-Model-Finetuned" : category_accs_1300m_ft , "2700M-Model" : category_accs_2700m }
183- make_spider_plot (benchmark_data )
162+ # benchmark_data = {"350M-Model": category_accs_350m, "1300M-Model": category_accs_1300m, "1300M-Model-Finetuned": category_accs_1300m_ft, "2700M-Model": category_accs_2700m}
163+ # make_spider_plot(benchmark_data)
0 commit comments