import os
import json
import shutil
import argparse

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import deepspeed

# Point Hugging Face downloads at the mirror before any model/processor loads.
os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"

from qwenva import tokenizer, processor, qwenva

images_file_path = './data/download/llava-v1.5-instruct'

with open('/root/autodl-tmp/LLaVA-Instruct-150K/qwenva_mix665k.json', 'r', encoding='utf-8') as f:
    chat_data = json.load(f)

# Special token ids used when building training samples.
image_token = tokenizer.encode('<image>')[0]
pad_token = tokenizer.pad_token_id
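
# Each chat_data record is expected to look roughly like this (schema inferred
# from the code below; the file name is illustrative, not taken from the data):
# {"image": "coco/train2017/000000000000.jpg",
#  "conversations": [{"from": "human", "value": "<image>\nWhat is shown here?"},
#                    {"from": "gpt",   "value": "..."}]}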

def process_data(sample, max_len=8012):
    """Tokenize one conversation into fixed-length input_ids/labels tensors."""
    conversations = sample['conversations']
    labels = []
    input_ids = []
    flag = 0
    messages = []
    try:
        for index, item in enumerate(conversations):
            if item['from'] == 'human':
                old_input_ids = input_ids
                messages.append({'role': 'user', 'content': item['value']})
                input_ids = tokenizer.apply_chat_template(
                    messages,
                    add_generation_prompt=True
                )
                # User turns are masked out of the loss.
                labels += [-100] * (len(input_ids) - len(old_input_ids))
                # Only the first turn may carry the <image> placeholder; record
                # its position, or -100 for text-only samples.
                if index == flag:
                    if image_token in input_ids:
                        image_index = input_ids.index(image_token)
                        labels[image_index] = image_token
                    else:
                        image_index = -100
            elif item['from'] == 'gpt':
                old_input_ids = input_ids
                messages.append({'role': 'assistant', 'content': item['value']})
                input_ids = tokenizer.apply_chat_template(
                    messages
                )
                # Assistant tokens are the supervised targets.
                labels += input_ids[len(old_input_ids):]
    except Exception as e:
        print(f"error in process_data_1: {e}")
        exit()

    try:
        # Truncate long sequences; right-pad short ones to max_len.
        if len(input_ids) > max_len:
            input_ids = input_ids[:max_len]
            labels = labels[:max_len]
            attention_mask = [1] * len(input_ids)
        else:
            attention_mask = [1] * len(input_ids) + [0] * (max_len - len(input_ids))
            input_ids += [pad_token] * (max_len - len(input_ids))
            labels += [-100] * (max_len - len(labels))
    except Exception as e:
        print(f"error in process_data_2: {e}")
        exit()

    try:
        input_ids = torch.tensor(input_ids)
        attention_mask = torch.tensor(attention_mask)
        labels = torch.tensor(labels)
        image_index = torch.tensor(image_index)
    except Exception as e:
        print(f"error in tensor: {e}")
        exit()
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'image_idx': image_index
    }
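
# Quick sanity check (illustrative; uncomment to inspect one processed sample):
# out = process_data(chat_data[0], max_len=2048)
# print(out['input_ids'].shape, out['labels'].shape, out['image_idx'])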


class MyDataset(Dataset):
    """Pairs each processed conversation with its (optional) image pixels."""
    def __init__(self, images_file_path, data, max_len=1024):
        self.images_file_path = images_file_path
        self.data = data
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        output_ = process_data(self.data[index], max_len=self.max_len)
        if output_['image_idx'] != -100:
            img_path = os.path.join(self.images_file_path, self.data[index]['image'])
            # Convert to RGB so grayscale/RGBA images still yield 3 channels.
            img = Image.open(img_path).convert('RGB')
            input_pixel = processor(images=img, return_tensors="pt")
            output_['input_pixel'] = input_pixel['pixel_values'].squeeze()
        else:
            # Text-only sample: substitute an all-zero float image so the batch
            # collates with a consistent shape and dtype.
            output_['input_pixel'] = torch.zeros(3, 224, 224)
        return output_


dataset = MyDataset(images_file_path, chat_data, max_len=2048)
train_loader = DataLoader(dataset, batch_size=16, shuffle=True)
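
# One item under these settings (shapes are illustrative and assume the
# processor emits 224x224 crops):
# item = dataset[0]
# item['input_ids'].shape   -> torch.Size([2048])
# item['input_pixel'].shape -> torch.Size([3, 224, 224])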

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=qwenva,
    args=argparse.Namespace(),
    model_parameters=qwenva.parameters(),
    config_params="./deepspeed_config.json"
)
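
# A minimal ./deepspeed_config.json sketch (assumed contents; the real file
# ships alongside this script and may differ, so values here are illustrative):
# {
#   "train_micro_batch_size_per_gpu": 16,
#   "gradient_accumulation_steps": 1,
#   "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
#   "fp16": {"enabled": true}
# }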

# The model appears to have been wrapped by torch.compile (hence ._orig_mod).
# Unfreeze the language-side modules for this finetuning stage; parameters not
# listed here keep whatever requires_grad they were given inside qwenva
# (presumably the vision encoder stays frozen).
for name, param in model_engine.module._orig_mod.text_embedding.named_parameters():
    param.requires_grad = True

for name, param in model_engine.module._orig_mod.lm_head.named_parameters():
    param.requires_grad = True

for name, param in model_engine.module._orig_mod.transformer.named_parameters():
    param.requires_grad = True

# Log exactly which parameters will be trained.
for name, param in model_engine.module._orig_mod.named_parameters():
    if param.requires_grad:
        print(f"Layer: {name}, Requires Grad: {param.requires_grad}")


# CrossEntropyLoss ignores label -100 by default, which matches the masking
# applied in process_data.
loss_fn = nn.CrossEntropyLoss()

accumulation_steps = 1


def train(model_engine, train_dataloader, loss_fn, device, epochs):
    model_engine.train()
    for epoch in range(epochs):
        with tqdm(total=len(train_dataloader), desc=f'Epoch {epoch + 1}/{epochs}', unit='batch') as pbar:
            try:
                for batch_idx, batch in enumerate(train_dataloader):
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    input_pixel = batch['input_pixel'].to(device)
                    labels = batch['labels'].to(device)
                    image_idx = batch['image_idx'].to(device)
                    logits = model_engine(input_ids, attention_mask, input_pixel, image_idx)
                    # Shift by one position: the logit at step t predicts token t+1.
                    loss = loss_fn(
                        logits[:, :-1, :].reshape(-1, logits.shape[-1]),
                        labels[:, 1:].reshape(-1)
                    )
                    model_engine.backward(loss)
                    if (batch_idx + 1) % accumulation_steps == 0:
                        model_engine.step()
                    pbar.update(1)
                    pbar.set_postfix(loss=loss.item())
                    # Periodically overwrite the checkpoint directory.
                    if (batch_idx + 1) % 4100 == 0:
                        if os.path.exists("./best_model_2"):
                            shutil.rmtree("./best_model_2")
                        os.makedirs("./best_model_2")
                        model_engine.save_checkpoint("./best_model_2")
                        torch.save(model_engine.module.state_dict(), "./compiled_model_3.pth")
                        print(f"model saved at batch {batch_idx + 1}")
            except Exception as e:
                print(f"error in train: {e}")


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train(model_engine, train_loader, loss_fn, device, epochs=2)
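
# Launch with the DeepSpeed runner, e.g. (script name is a placeholder):
# deepspeed train_qwenva.py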