|
import torch
import torch.nn as nn
import wandb
from datasets import load_dataset, DatasetDict
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, AutoConfig, get_linear_schedule_with_warmup
|
args_max_epoch = 1
args_batch_size = 64
args_learning_rate = 3e-5
args_num_warmup_steps = 100
args_gradient_accumulation_steps_default = 2
adapter_hidden_dim = 4096
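# Effective batch size per optimizer update is
# args_batch_size * args_gradient_accumulation_steps_default = 64 * 2 = 128.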
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
def main():
    wandb.init(project="MappingAdapater_training_v6", name="training_run")

    model = MappingStructure(checkpointE="sentence-transformers/stsb-roberta-large",
                             checkpointD="mistralai/Mistral-7B-Instruct-v0.1",
                             hidden_dim=adapter_hidden_dim,
                             torch_dtype=torch.float16,
                             flash_attn=True,
                             ).to(device)
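    # Only the MappingAdapter bridging the two backbones is trained; the
    # encoder and decoder weights are frozen in the loop below.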
|
|
|
    for n, p in model.named_parameters():
        p.requires_grad = 'mapping' in n
|
    dataset = load_dataset("sade-adrien/redpajama_v2_sample_10M")['train']
    train_dataset, val_dataset = split_dataset(dataset, train_size=.989333)
    datasets = DatasetDict({
        'train': train_dataset,
        'val': val_dataset,
    })
|
    train_dataloader = DataLoader(datasets['train'], batch_size=args_batch_size, shuffle=True)
    val_dataloader = DataLoader(datasets['val'], batch_size=args_batch_size, shuffle=False)

    optimizer = AdamW((p for p in model.parameters() if p.requires_grad), lr=args_learning_rate)
    # The scheduler advances once per optimizer step, and optimizer steps occur
    # only every args_gradient_accumulation_steps_default batches.
    num_training_steps = args_max_epoch * len(train_dataloader) // args_gradient_accumulation_steps_default
    scheduler = get_linear_schedule_with_warmup(optimizer, args_num_warmup_steps, num_training_steps)
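    # Rough scale, assuming the ~10M-row sample and the split above: about
    # 9.89M training rows / 64 per batch ≈ 154,584 batches per epoch, i.e.
    # ≈ 77,292 scheduled optimizer steps.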
|
|
|
    global_step = 0
    model.train()
    for epoch in range(args_max_epoch):
        # Re-seed the shuffle per epoch via a generator; worker_init_fn never
        # runs when num_workers=0, so seeding there had no effect.
        train_dataloader = DataLoader(datasets['train'], batch_size=args_batch_size, shuffle=True,
                                      generator=torch.Generator().manual_seed(epoch))
|
        for batch in tqdm(train_dataloader):
            input_prompt = batch['raw_content']
            outputs = model(input_prompt=input_prompt, compute_loss=True)
            loss = outputs['loss']

            # Scale down so gradients accumulated over several batches average
            # to the per-batch loss.
            loss = loss / args_gradient_accumulation_steps_default
            loss.backward()

            if (global_step + 1) % args_gradient_accumulation_steps_default == 0:
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
|
            if (global_step + 1) % 2000 == 0:
                # Only the adapter is trained, so only its weights need to be
                # checkpointed alongside the optimizer/scheduler state.
                torch.save({
                    'epoch': epoch,
                    'mapping_state_dict': model.mapping.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'global_step': global_step,
                }, f'models/mapping_adapter_checkpoint_{global_step + 1}steps.pth')

            global_step += 1
|
            val_loss = None
            # Validate every 8000 batches (global_step was just incremented, so
            # this lines up with the checkpoint cadence above).
            if global_step % 8000 == 0:
                model.eval()
                val_loss = 0.0
                with torch.no_grad():
                    for val_batch in tqdm(val_dataloader):
                        val_inputs = val_batch['raw_content']
                        val_outputs = model(input_prompt=val_inputs, compute_loss=True)
                        val_loss += val_outputs['loss']
                val_loss /= len(val_dataloader)
                model.train()
|
            wandb.log({
                'step': global_step,
                'learning_rate': scheduler.get_last_lr()[0],
                'train_loss': loss.item() * args_gradient_accumulation_steps_default,
                'val_loss': val_loss.item() if val_loss is not None else None,
            })
|
|
def split_dataset(dataset, train_size=.9):
    # Deterministic head/tail split: the first `train_size` fraction of rows is
    # the training set, the remainder the validation set.
    index = int(len(dataset) * train_size)
    return dataset.select(range(index)), dataset.select(range(index, len(dataset)))
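# For example, assuming the sampled dataset holds exactly 10,000,000 rows (as
# its name suggests), train_size=.989333 gives 9,893,330 training rows and
# 106,670 validation rows.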
|
|
|
class MappingAdapter(nn.Module):
    # Two-layer MLP that projects decoder hidden states into the encoder's
    # embedding space.
    def __init__(self, input_dim, output_dim, hidden_dim):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)
        self.activation = nn.LeakyReLU(.01)

    def forward(self, x):
        x = self.layer1(x)
        x = self.activation(x)
        x = self.layer2(x)
        return x
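    # For the checkpoints used in main(), and assuming their standard configs
    # (Mistral-7B hidden_size 4096, stsb-roberta-large hidden_size 1024), this
    # becomes: Linear(4096 -> 4096) -> LeakyReLU(0.01) -> Linear(4096 -> 1024).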
|
|
|
class MappingStructure(nn.Module):
    def __init__(self, checkpointE, checkpointD, hidden_dim=2048, torch_dtype=torch.float32, flash_attn=False):
        super().__init__()

        self.configE = AutoConfig.from_pretrained(checkpointE)
        self.Encoder = AutoModel.from_pretrained(checkpointE,
                                                 low_cpu_mem_usage=True,
                                                 torch_dtype=torch_dtype,
                                                 config=self.configE,
                                                 )

        self.configD = AutoConfig.from_pretrained(checkpointD)
        if flash_attn:
            # Private flag honored by the transformers release this was written
            # against; newer versions expose the same switch as
            # attn_implementation='flash_attention_2' in from_pretrained.
            self.configD.update({'_flash_attn_2_enabled': True})
        self.Decoder = AutoModel.from_pretrained(checkpointD,
                                                 low_cpu_mem_usage=True,
                                                 torch_dtype=torch_dtype,
                                                 config=self.configD,
                                                 )

        # Map decoder hidden states (configD.hidden_size) into the encoder's
        # embedding space (configE.hidden_size); cast to the backbones' dtype.
        self.mapping = MappingAdapter(self.configD.hidden_size, self.configE.hidden_size, hidden_dim=hidden_dim).to(torch_dtype)

        self._init_tokenizers(checkpointE, checkpointD)
|
    def _init_tokenizers(self, checkpointE, checkpointD):
        self.tokenizerE = AutoTokenizer.from_pretrained(checkpointE, use_fast=False, revision='main', config=self.configE, padding_side='left')
        self.tokenizerD = AutoTokenizer.from_pretrained(checkpointD, use_fast=False, revision='main', config=self.configD, padding_side='left')
        # Mistral's tokenizer ships without a pad token; reuse <unk> for padding.
        self.tokenizerD.pad_token_id = self.tokenizerD.unk_token_id
|
    def cosine_sim(self, u, v):
        assert u.shape == v.shape, "u and v must have the same shape"
        # Clamp the norms so all-zero rows cannot cause a division by zero.
        u_normalized = u / torch.norm(u, dim=1, keepdim=True).clamp_min(1e-8)
        v_normalized = v / torch.norm(v, dim=1, keepdim=True).clamp_min(1e-8)
        return torch.sum(u_normalized * v_normalized, dim=1)
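    # Up to the epsilon guard, this is equivalent to
    # torch.nn.functional.cosine_similarity(u, v, dim=1).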
|
|
|
|
|
    def mean_pooling(self, hidden_state, attention_mask):
        # Masked mean over the sequence dimension: sum token embeddings where
        # the mask is 1, divided by the (clamped) number of real tokens.
        token_embeddings = hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
|
    def build_batch(self, input_prompt):
        # Sample one truncation length per batch, below the encoder's maximum
        # position embeddings, so training covers a range of sequence lengths.
        size = torch.randint(1, self.configE.max_position_embeddings - 2, (1,)).item()
        targets = []

        for prompt in input_prompt:
            tokenized_input = self.tokenizerE(prompt)
            tokenized_input = {'input_ids': tokenized_input['input_ids'][:size],
                               'attention_mask': tokenized_input['attention_mask'][:size],
                               }
            targets.append(tokenized_input)
        targets = self.tokenizerE.pad(targets, padding=True, return_tensors='pt')

        return targets
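    # The result is a left-padded BatchEncoding ('input_ids', 'attention_mask')
    # whose sequence length is min(size, longest tokenized prompt in the batch).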
|
|
|
|
|
    def forward(self, input_prompt, compute_loss=False):
        loss = None

        if not compute_loss:
            inputs = self.tokenizerD(input_prompt, return_tensors='pt', padding=True).to(device)

            hidden_state_D = self.Decoder(**inputs).last_hidden_state
            hidden_state_D_mapped = self.mapping(hidden_state_D)

        else:
            targets = self.build_batch(input_prompt).to(device)

            # Decode the truncated encoder batch back to text so the decoder
            # tokenizes exactly the same (sliced) prompts.
            input_prompt_sliced = self.tokenizerE.batch_decode(targets['input_ids'], skip_special_tokens=True)
            inputs = self.tokenizerD(input_prompt_sliced, return_tensors='pt', padding=True).to(device)

            hidden_state_D = self.Decoder(**inputs).last_hidden_state
            hidden_state_D_mapped = self.mapping(hidden_state_D)

            hidden_state_E = self.Encoder(**targets).last_hidden_state

            proj_E = self.mean_pooling(hidden_state_E, targets['attention_mask'])
            proj_D = self.mean_pooling(hidden_state_D_mapped, inputs['attention_mask'])

            # Cosine-distance objective: 0 when the pooled decoder projection
            # aligns perfectly with the pooled encoder embedding.
            loss = 1 - torch.mean(self.cosine_sim(proj_E, proj_D))

            # These intermediates exist only in this branch; deleting them here
            # (rather than unconditionally) avoids a NameError when
            # compute_loss=False.
            del targets, input_prompt_sliced, hidden_state_E, proj_E, proj_D

        del inputs
        torch.cuda.empty_cache()

        return {'loss': loss,
                'last_hidden_state': hidden_state_D,
                'last_hidden_state_mapped': hidden_state_D_mapped,
                }
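    # Usage sketch (hypothetical strings; shapes follow the code above):
    #   out = model(input_prompt=["some text", "more text"], compute_loss=True)
    #   out['loss']                     -> scalar cosine-distance loss
    #   out['last_hidden_state_mapped'] -> (batch, seq, configE.hidden_size)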
|
|
|
|
|
if __name__ == '__main__':
    main()