### Load model from Hugging Face

import os
import json

import tqdm
import joblib
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset

import wandb
import peft
import loralib as lora
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from accelerate import Accelerator, DeepSpeedPlugin

"""
extra requirements:
pip install icetk
"""

checkpoint = "/model/chatglm-6b"
datafile = 'datasets/merge.json'
out_dir = 'outs/chatglm-6b'

use_wandb = True
mixed_precision = 'bf16'
accumulate_step = 8
log_interval = 100

Per_GPU_BATCH_SIZE = 2
MAX_LENGTH = 256  # has a huge impact on VRAM (roughly: length 968 -> batch size 1, length 256 -> batch size 4)

config = LoraConfig(
    peft_type="LORA",
    r=32,
    lora_alpha=32,
    target_modules=["q", "k", "v"],
    lora_dropout=0.1,
)

LR = 2e-5
NUM_EPOCHS = 3
warm_up_ratio = 0.1

device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if ddp:
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}

if use_wandb:
    wandb.init(
        project="LoRA",
        name=f"{checkpoint}-{datafile}",
        config=None,
    )
else:
    wandb.init(mode='disabled')

os.makedirs(out_dir, exist_ok=True)

tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
    trust_remote_code=True,
    device_map=device_map,
)
# BUG: must remove special token '[MASK]'
# del tokenizer.vocab['MASK']


### Dataset

EOS_ID = 150005  # ChatGLM end-of-sequence token id

PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}

with open(datafile, 'r') as f:
    content = json.load(f)

pairs = []
for line in content:
    if line['input'] == '':
        prompt = PROMPT_DICT['prompt_no_input'].format_map(line)
    else:
        prompt = PROMPT_DICT['prompt_input'].format_map(line)
    completion = line['output'] + '</s>'
    if len(prompt) + len(completion) < MAX_LENGTH:
        pairs.append({'prompt': prompt, 'completion': completion})


class AlpacaDataset(Dataset):

    def __init__(self, pairs, tokenizer) -> None:
        super().__init__()
        self.pairs = pairs
        self.tokenizer = tokenizer

    def __getitem__(self, index):
        # strip the textual '</s>' marker and append the real EOS token id instead
        if self.pairs[index]['completion'][-4:] == '</s>':
            prompt = self.tokenizer.encode(self.pairs[index]['prompt'])
            completion = self.tokenizer.encode(self.pairs[index]['completion'][:-4], add_special_tokens=False)
            completion += [EOS_ID]
        else:
            prompt = self.tokenizer.encode(self.pairs[index]['prompt'])
            completion = self.tokenizer.encode(self.pairs[index]['completion'], add_special_tokens=False)

        # if the gMASK token (id 150001) did not survive tokenization, escape literal
        # [MASK] / [gMASK] strings in the text and re-encode
        if 150001 not in prompt:
            prompt = self.pairs[index]['prompt'].replace('[MASK]', '//MASK//').replace('[gMASK]', '//gMASK//')
            completion = self.pairs[index]['completion'].replace('[MASK]', '//MASK//').replace('[gMASK]', '//gMASK//')
            prompt = self.tokenizer.encode(prompt)
            completion = self.tokenizer.encode(completion, add_special_tokens=False)
            if 150001 not in prompt:
                import pdb; pdb.set_trace()

        return {'prompt': prompt, 'completion': completion}

    def __len__(self):
        return len(self.pairs)


def collate_fn(batch):
    input_ids = []
    labels = []
    position_ids = []

    device = 'cuda:0'
    _max_length = max([len(obj['prompt']) + len(obj['completion']) for obj in batch])
    attention_mask = torch.ones((len(batch), _max_length, _max_length), device=device)
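    # ChatGLM uses a prefix-LM attention pattern: every position may attend to the whole
    # context (the prompt up to the token with id 150004, located below), while completion
    # tokens attend causally. The lower-triangular mask plus the full context columns
    # encode this; the final (attention_mask < 0.5) flip yields a boolean mask where
    # True means "do not attend".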
    attention_mask.tril_()
    for i, obj in enumerate(batch):
        context_length = obj['prompt'].index(150004)  # context ends at token id 150004
        attention_mask[i, :, :context_length] = 1
        to_pad = _max_length - len(obj['prompt']) - len(obj['completion'])

        input_ids.append(obj['prompt'] + obj['completion'] + [tokenizer.pad_token_id] * to_pad)
        position_ids.append(torch.stack([
            torch.arange(0, _max_length, device=device),
            torch.concat([
                torch.zeros(context_length - 1, device=device),
                torch.arange(0, _max_length - context_length + 1, device=device),
            ]),
        ]).long())
        labels.append(torch.tensor(
            [-100] * len(obj['prompt']) + obj['completion'] + [-100] * to_pad,
            device=device).long())

    attention_mask.unsqueeze_(1)
    attention_mask = (attention_mask < 0.5).bool()

    return {
        'input_ids': torch.tensor(input_ids).long(),
        'attention_mask': attention_mask,
        'labels': torch.stack(labels),
        'position_ids': torch.stack(position_ids),
    }


train_dataset = AlpacaDataset(pairs, tokenizer=tokenizer)
train_dataloader = DataLoader(dataset=train_dataset, collate_fn=collate_fn,
                              shuffle=True, batch_size=Per_GPU_BATCH_SIZE)

# sanity check: iterate once over the dataloader to surface data problems early
for step, batch in enumerate(t := tqdm.tqdm(train_dataloader)):
    pass

model = AutoModel.from_pretrained(
    checkpoint,
    trust_remote_code=True,
)

deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=accumulate_step)
accelerator = Accelerator(mixed_precision=mixed_precision,
                          gradient_accumulation_steps=accumulate_step,
                          deepspeed_plugin=deepspeed_plugin)
device = accelerator.device


### Insert LoRA into the model

class QKV_layer(torch.nn.Module):

    def __init__(self, in_features, out_features):
        super(QKV_layer, self).__init__()
        self.linear_q = torch.nn.Linear(in_features, out_features // 3)
        self.linear_k = torch.nn.Linear(in_features, out_features // 3)
        self.linear_v = torch.nn.Linear(in_features, out_features // 3)

    def update(self, target_layer):
        # copy the Q / K / V slices out of the fused query_key_value weight and bias
        self.linear_q.weight.data = target_layer.weight[:target_layer.out_features // 3, :].data
        self.linear_q.bias.data = target_layer.bias[:target_layer.out_features // 3].data

        self.linear_k.weight.data = target_layer.weight[target_layer.out_features // 3:target_layer.out_features // 3 * 2, :].data
        self.linear_k.bias.data = target_layer.bias[target_layer.out_features // 3:target_layer.out_features // 3 * 2].data

        self.linear_v.weight.data = target_layer.weight[target_layer.out_features // 3 * 2:, :].data
        self.linear_v.bias.data = target_layer.bias[target_layer.out_features // 3 * 2:].data

    def forward(self, x):
        q = self.linear_q(x)
        k = self.linear_k(x)
        v = self.linear_v(x)
        return torch.concat([q, k, v], dim=-1)


for key, module in model.named_modules():
    if key.endswith('attention'):
        if isinstance(module.query_key_value, peft.tuners.lora.LoraModel):
            module.query_key_value = peft.tuners.lora.LoraModel(config, module.query_key_value.model)
        else:
            # Split the fused query_key_value layer into three linear layers for LoRA;
            # a merged linear layer would also work.
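            # The slicing in QKV_layer.update() assumes the fused weight stores Q, K and V
            # as consecutive thirds along the output dimension; the LoraConfig's
            # target_modules=["q", "k", "v"] are then matched against the new
            # linear_q / linear_k / linear_v sub-layers.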
            qkv_layer = QKV_layer(module.query_key_value.in_features, module.query_key_value.out_features)
            qkv_layer.update(module.query_key_value)
            module.query_key_value = qkv_layer
            module.query_key_value = peft.tuners.lora.LoraModel(config, module.query_key_value)

lora.mark_only_lora_as_trainable(model)

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
non_trainable_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print('trainable_params:{} ({:.2f}%), non_trainable_params:{}'.format(
    trainable_params,
    trainable_params / (trainable_params + non_trainable_params) * 100,
    non_trainable_params,
))


### Training

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=int(len(train_dataloader) / accumulate_step * warm_up_ratio),
    num_training_steps=(int(len(train_dataloader) / accumulate_step) * NUM_EPOCHS),
)

model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
model.to(device).train()

for epoch in range(NUM_EPOCHS):
    total_loss = 0
    for step, batch in enumerate(t := tqdm.tqdm(train_dataloader)):
        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss_detach = outputs.loss.detach().cpu().float()
            # t.set_description(f"loss: {loss_detach}")
            t.set_postfix(loss=loss_detach.item())
            total_loss += loss_detach
            loss = outputs.loss

            if accelerator.is_main_process:
                if step % log_interval == 0:
                    wandb.log({
                        'train/loss': loss_detach.item(),
                    })

            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        # save only the LoRA weights for this epoch
        peft_model_id = f"finetune_{epoch}"
        accelerator.save(lora.lora_state_dict(accelerator.unwrap_model(model)),
                         f'{out_dir}/{peft_model_id}.pt')
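
### Restoring a LoRA checkpoint (sketch)
# A minimal sketch of how the per-epoch checkpoints written above could be loaded back
# later; it is not part of the training loop. It assumes the base model has been rebuilt
# and patched with the same QKV_layer / LoraModel insertion as above, so the parameter
# names in the checkpoint line up. The epoch index in the example path is a placeholder.
def load_lora_checkpoint(model, path):
    # lora.lora_state_dict() saved only the LoRA tensors, so strict=False leaves the
    # frozen base weights untouched
    lora_state = torch.load(path, map_location='cpu')
    model.load_state_dict(lora_state, strict=False)
    return model.eval()

# example:
# load_lora_checkpoint(accelerator.unwrap_model(model), f'{out_dir}/finetune_0.pt')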