 from torch.nn import CrossEntropyLoss
 from torch.optim import Adam
 import transformers
+from trl import SFTTrainer
 from tqdm import tqdm

 from utils import run_benchmark, make_spider_plot
@@ -54,6 +55,40 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
     output = tokenizer.decode(x[num_start:])
     return output

+def generate_pt(model, tokenizer, text, num_steps=50, until=None, temp=1.):
+    device = model.device
+    print(text, end='', flush=True)
+    x = tokenizer.encode(text)
+    num_start = len(x)
+
+    decoded = tokenizer.decode(x)
+
+    for step in range(num_steps):
+        with torch.no_grad():
+            input_tensor = torch.reshape(torch.LongTensor(x), [1, -1]).to(device)
+            logits = model(input_tensor).logits
+            # temperature-scaled softmax over the final position's logits
+            probs = F.softmax(logits / temp, dim=-1)[0, -1, :]
+            probs = probs.detach().cpu().numpy()
+
+        # sample the next token and append it to the running sequence
+        new_token = np.random.choice(len(probs), p=probs)
+        x.append(new_token)
+
+        # re-decode the full sequence and print only the newly added text,
+        # so characters spanning multiple tokens render correctly while streaming
+        new_decoded = tokenizer.decode(x)
+        new_part = new_decoded[len(decoded):]
+        decoded = new_decoded
+
+        print(new_part, end='', flush=True)
+        text += new_part
+
+        # stop early once the generated text ends with the stop string
+        if until is not None and len(text) >= len(until) and text[-len(until):] == until:
+            break
+
+    output = tokenizer.decode(x[num_start:])
+    print("\n", flush=True)
+    return output
+
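+# Example usage (sketch; assumes `model` and `tokenizer` are already loaded,
+# and the question text is illustrative):
+# generate_pt(model, tokenizer,
+#             "### Human: What is 2+2?### Assistant:",
+#             num_steps=100, until="###")
+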
 # Test autoregressive generation
 # while True:
 #     print("\n\n\n\n\n")
@@ -87,13 +122,30 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
 # benchmark_data = {"1300M-Model": category_accs_1300m}
 # make_spider_plot(benchmark_data)

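+# Recursively sum trainable parameters over all submodules of `layer_type`.
+# The final printed column checks the expected rank-8 LoRA count: a rank-8
+# adapter on a linear layer trains in_features*8 + 8*out_features weights,
+# so it reads True for LoRA layers and False for plain nn.Linear layers.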
+def print_lora_params(module, layer_type):
+    summ = 0
+    for name, child in module.named_children():
+        if isinstance(child, layer_type):
+            num_params = sum(p.numel() for p in child.parameters() if p.requires_grad)
+
+            print(name, num_params, child.in_features, child.out_features,
+                  (child.in_features * 8 + child.out_features * 8 == num_params))
+
+            summ += num_params
+        else:
+            summ += print_lora_params(child, layer_type)
+
+    return summ
+
 # Part 2

 # inspect current model
 # print(model)
-layer = model.lm_head
-print(layer.weight.shape)
-print(sum(p.numel() for p in layer.parameters() if p.requires_grad))
+
+# summ = print_lora_params(model, nn.Linear)
+
+# print("with function", summ)
+
+# print("without function", sum(p.numel() for p in model.parameters() if p.requires_grad))

 # freeze all parameter gradients
 for param in model.parameters():
@@ -149,8 +201,14 @@ def replace_linear_with_lora(module):

 replace_linear_with_lora(model)

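+# LoRA recap (for reference): each replaced layer now computes h = Wx + BAx,
+# where the pretrained weight W stays frozen and only the low-rank factors
+# A (r x in_features) and B (out_features x r), with r = 8 here, get gradients.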
-layer = model.lm_head
-print(sum(p.numel() for p in layer.parameters() if p.requires_grad))
+
+
+# summ = print_lora_params(model, LoRALinear)
+
+# print("with function", summ)
+
+# print("without function", sum(p.numel() for p in model.parameters() if p.requires_grad))
+

 # inspect new model
 # print(model)
@@ -169,47 +227,73 @@ def replace_linear_with_lora(module):

 model = model.to("cuda")

-
-for epoch in range(num_epochs):
-    total_loss = 0
-    num_batches = 0
-
-    for batch in tqdm(ft_dataset):
-        prompt = batch["text"]
+### Train the model
+# Define some training args
+args = transformers.TrainingArguments("/home/dnori/introtodeeplearning/xtra_labs/llm_finetune/outputs",
+    per_device_train_batch_size=1,
+    logging_first_step=True,
+    logging_steps=20,
+    save_steps=100,
+)
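+# Note: on_log fires every `logging_steps` (20 here) optimizer steps, so the
+# PrinterCallback defined below samples a completion at that cadence.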
+
+# Define a callback to check the progress on a sample question
+class PrinterCallback(transformers.TrainerCallback):
+    def on_log(self, args, state, control, model, logs=None, **kwargs):
+        start_text = "### Human: When the weather is sunny, what color is the sky?### Assistant:"
+        generate_pt(model, tokenizer, start_text, num_steps=200, until="###")
+
+# Actually train the model
+trainer = SFTTrainer(
+    model,
+    args=args,
+    train_dataset=ft_dataset,
+    dataset_text_field="text",
+    max_seq_length=context_length,
+    callbacks=[PrinterCallback()]
+)
+trainer.train()
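+# Optionally persist the fine-tuned model after training (sketch; the
+# "final" subdirectory name is illustrative):
+# trainer.save_model("/home/dnori/introtodeeplearning/xtra_labs/llm_finetune/outputs/final")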
+
+
+# for epoch in range(num_epochs):
+#     total_loss = 0
+#     num_batches = 0
+
+#     for batch in tqdm(ft_dataset):
+#         prompt = batch["text"]

-        # encode with tokenizer
-        x = tokenizer.encode(prompt)
-        x_tensor = torch.tensor(x).view(1, -1).to("cuda")
-        max_len = min(context_length, x_tensor.shape[1]-1)
-        selected_len = random.randint(1, max_len)
+#         # encode with tokenizer
+#         x = tokenizer.encode(prompt)
+#         x_tensor = torch.tensor(x).view(1, -1).to("cuda")
+#         max_len = min(context_length, x_tensor.shape[1]-1)
+#         selected_len = random.randint(1, max_len)

-        input_tensor = x_tensor[:, :selected_len]
-        target_tensor = x_tensor[0, 1:selected_len+1]
+#         input_tensor = x_tensor[:, :selected_len]
+#         target_tensor = x_tensor[0, 1:selected_len+1]

-        # zero gradients
-        optimizer.zero_grad()
+#         # zero gradients
+#         optimizer.zero_grad()

-        # run through model
-        logits = model(input_tensor).logits[0]
+#         # run through model
+#         logits = model(input_tensor).logits[0]

-        # apply loss
-        loss = loss_fn(logits, target_tensor)
+#         # apply loss
+#         loss = loss_fn(logits, target_tensor)

-        # backpropagation
-        loss.backward()
+#         # backpropagation
+#         loss.backward()

-        # optimizer step
-        optimizer.step()
+#         # optimizer step
+#         optimizer.step()

-        total_loss += loss.item()
-        num_batches += 1
+#         total_loss += loss.item()
+#         num_batches += 1

-    # Print average loss for the epoch
-    average_loss = total_loss / num_batches
-    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss}")
+#     # Print average loss for the epoch
+#     average_loss = total_loss / num_batches
+#     print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss}")

-# evaluate finetuned model on benchmark
-category_accs_1300m_ft, avg_acc_1300m_ft = run_benchmark(model, tokenizer, benchmark_dataset)
+# # evaluate finetuned model on benchmark
+# category_accs_1300m_ft, avg_acc_1300m_ft = run_benchmark(model, tokenizer, benchmark_dataset)

 # add to spider plot
 # benchmark_data = {"350M-Model": category_accs_350m, "1300M-Model": category_accs_1300m, "1300M-Model-Finetuned": category_accs_1300m_ft, "2700M-Model": category_accs_2700m}