11"""
22Drafting lab flow in script format using PyTorch
33"""
4-
4+ from datasets import load_dataset
5+ import math
56import numpy as np
67import pandas as pd
78import tensorflow as tf
9+ import torch
10+ import torch .nn as nn
11+ import torch .nn .functional as F
812import transformers
13+ from trl import SFTTrainer
914
1015from utils import run_benchmark , make_spider_plot
1116
@@ -63,14 +68,14 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):
# Benchmark smaller model
model_name_350m = "facebook/opt-350m"
model_350m = transformers.AutoModelForCausalLM.from_pretrained(model_name_350m, device_map="auto")
- tokenizer_350m = transformers.AutoTokenizer.from_pretrained(model_350m)
+ tokenizer_350m = transformers.AutoTokenizer.from_pretrained(model_name_350m)

category_accs_350m, avg_acc_350m = run_benchmark(model_350m, tokenizer_350m, dataset)

# Benchmark larger model
model_name_2700m = "facebook/opt-2.7b"
model_2700m = transformers.AutoModelForCausalLM.from_pretrained(model_name_2700m, device_map="auto")
- tokenizer_2700m = transformers.AutoTokenizer.from_pretrained(model_2700m)
+ tokenizer_2700m = transformers.AutoTokenizer.from_pretrained(model_name_2700m)

category_accs_2700m, avg_acc_2700m = run_benchmark(model_2700m, tokenizer_2700m, dataset)

@@ -81,16 +86,98 @@ def generate(start_text, model, tokenizer, num_steps=20, temp=1.):

# Part 2

- # new LoRA linear layer class
-
- # new attention layer class
+ # inspect current model
+ print(model)

- # replace attention modules with new module
+ # new LoRA linear layer class
+ class LoRALinear(nn.Linear):
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int,
+         r: int = 8,
+         lora_alpha: int = 1,
+         **kwargs
+     ):
+         self.r = r
+         self.in_features = in_features
+         self.out_features = out_features
+         self.lora_alpha = lora_alpha
+
+         nn.Linear.__init__(self, in_features, out_features, **kwargs)
+
+         # from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+         if r > 0:
+             self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
+             self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
+             self.scaling = self.lora_alpha / self.r
+             # Freezing the pre-trained weight matrix
+             self.weight.requires_grad = False
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         # from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+         nn.Linear.reset_parameters(self)
+         if hasattr(self, 'lora_A'):
+             # initialize A the same way as the default for nn.Linear and B to zero
+             # this is different than what is described in the paper but should not affect performance
+             nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+             nn.init.zeros_(self.lora_B)
+
+     def forward(self, x: torch.Tensor):
+         if self.r > 0:
+             result = F.linear(x, self.weight, bias=self.bias)
+             result += (x @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
+             return result
+         else:
+             return F.linear(x, self.weight, bias=self.bias)
+
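# Illustrative sanity check (not part of the lab script): because lora_B starts at zero,
# a freshly constructed LoRALinear should match a plain linear map through its own
# (frozen) base weight.
_probe = LoRALinear(16, 32, r=8)
_x = torch.randn(4, 16)
assert torch.allclose(_probe(_x), F.linear(_x, _probe.weight, _probe.bias))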
+ # replace linear layers in model recursively
+ def replace_linear_with_lora(module):
+     for name, child in module.named_children():
+         if isinstance(child, nn.Linear):
+             # carry over the pretrained weight (and bias) so only the LoRA factors start untrained
+             lora_layer = LoRALinear(child.in_features, child.out_features, bias=child.bias is not None)
+             lora_layer.weight.data.copy_(child.weight.data)
+             if child.bias is not None:
+                 lora_layer.bias.data.copy_(child.bias.data)
+             setattr(module, name, lora_layer)
+         else:
+             replace_linear_with_lora(child)
+
+ replace_linear_with_lora(model)
+
+ # inspect new model
+ print(model)
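# Rough check (illustrative, not from the handout): the swap freezes every base linear
# weight, so the trainable-parameter count should shrink to the small LoRA factors plus
# the remaining bias, embedding, and layer-norm parameters.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"trainable parameters: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.1f}%)")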

# load chat dataset
+ dataset_name = "timdettmers/openassistant-guanaco"
+ ft_dataset = load_dataset(dataset_name, split="train")
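# Quick peek at the data (illustrative, not in the lab script): each row's "text" field
# holds a full "### Human: ... ### Assistant: ..." exchange, which is the field the
# SFTTrainer below is pointed at.
print(ft_dataset[0]["text"][:300])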

# train model
+ log_dir = "/scratch/checkpoints/test-sft/opt1.3b_768/"
+ batch_size = 4
+ context_length = 768
+ args = transformers.TrainingArguments(log_dir,
+     per_device_train_batch_size=batch_size,
+     logging_first_step=True,
+     logging_steps=20,
+     save_steps=100,
+ )
+
+ class PrinterCallback(transformers.TrainerCallback):
+     def on_log(self, args, state, control, model, logs=None, **kwargs):
+         start_text = "### Human: When the weather is sunny, what color is the sky?### Assistant:"
+         generate(start_text, model, tokenizer, num_steps=200, until="###")
+
+ trainer = SFTTrainer(
+     model,
+     args=args,
+     train_dataset=ft_dataset,
+     dataset_text_field="text",
+     max_seq_length=context_length,
+     callbacks=[PrinterCallback()]
+ )
+ trainer.train()
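# Optional sketch (an assumption, not part of the lab script): after training, the
# low-rank update can be folded into each frozen base weight, as described in the LoRA
# paper, so inference needs no extra matmuls.
def merge_lora_weights(module):
    for child in module.modules():
        if isinstance(child, LoRALinear) and child.r > 0:
            child.weight.data += (child.lora_B.data @ child.lora_A.data) * child.scaling
            child.r = 0  # forward() now takes the plain nn.Linear path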

# evaluate finetuned model on benchmark
+ category_accs_1300m_ft, avg_acc_1300m_ft = run_benchmark(model, tokenizer, dataset)

- # add to spider plot
+ # add to spider plot
+ benchmark_data = {"350M-Model": category_accs_350m, "1300M-Model": category_accs_1300m, "1300M-Model-Finetuned": category_accs_1300m_ft, "2700M-Model": category_accs_2700m}
+ make_spider_plot(benchmark_data)