Commit 4ab846c by jiangchengchengNLP · verified · 1 Parent(s): d3fa7ee

Upload 5 files


Files for training: data_download.py, download2.py, deepspeed_pretrain.py, deepspeed_train_150k.py, deepspeed_train_665k.py

data_download.py ADDED
@@ -0,0 +1,198 @@
"""
download.py

Utility functions for downloading and extracting various datasets to (local) disk.
"""

import os
import shutil
from pathlib import Path
from typing import Dict, List, TypedDict
from zipfile import ZipFile

import requests
from PIL import Image
from rich.progress import BarColumn, DownloadColumn, MofNCompleteColumn, Progress, TextColumn, TransferSpeedColumn
from tqdm import tqdm

# from prismatic.overwatch import initialize_overwatch

# Initialize Overwatch =>> Wraps `logging.Logger`
# overwatch = initialize_overwatch(__name__)


# === Dataset Registry w/ Links ===
# fmt: off
DatasetComponent = TypedDict(
    "DatasetComponent",
    {"name": str, "extract": bool, "extract_type": str, "url": str, "do_rename": bool},
    total=False,
)

DATASET_REGISTRY: Dict[str, List[DatasetComponent]] = {
    # === LLaVa v1.5 Dataset(s) ===

    # Note =>> This is the full suite of datasets included in the LLaVa 1.5 "finetuning" stage; all the LLaVa v1.5
    #          models are finetuned on this split. We use this dataset for all experiments in our paper.
    "llava-v1.5-instruct": [
        {
            "name": "coco/train2017",   # Visual Instruct Tuning images are all sourced from COCO Train 2017
            "extract": True,
            "extract_type": "directory",
            "url": "http://images.cocodataset.org/zips/train2017.zip",
            "do_rename": True,
        },
        {
            "name": "gqa/images",
            "extract": True,
            "extract_type": "directory",
            "url": "https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip",
            "do_rename": True,
        },
        {
            "name": "ocr_vqa/images",
            "extract": True,
            "extract_type": "directory",
            "url": "https://hf-mirror.com/datasets/qnguyen3/ocr_vqa/resolve/main/ocr_vqa.zip",
            "do_rename": True,
        },
        {
            "name": "textvqa/train_images",
            "extract": True,
            "extract_type": "directory",
            "url": "https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip",
            "do_rename": True,
        },
        {
            "name": "vg/VG_100K",
            "extract": True,
            "extract_type": "directory",
            "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip",
            "do_rename": True,
        },
        {
            "name": "vg/VG_100K_2",
            "extract": True,
            "extract_type": "directory",
            "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip",
            "do_rename": True,
        },
    ]
}
# fmt: on


def convert_to_jpg(image_dir: Path) -> None:
    """Handling for OCR-VQA images specifically; iterates through the directory and converts all GIFs/PNGs to JPG."""
    print(f"Converting all images in `{image_dir}` to JPG")

    for image_fn in tqdm(list(image_dir.iterdir())):
        if image_fn.suffix in {".jpg", ".jpeg"} or (jpg_fn := image_dir / f"{image_fn.stem}.jpg").exists():
            continue

        if image_fn.suffix == ".gif":
            gif = Image.open(image_fn)
            gif.seek(0)
            gif.convert("RGB").save(jpg_fn)
        elif image_fn.suffix == ".png":
            Image.open(image_fn).convert("RGB").save(jpg_fn)
        else:
            raise ValueError(f"Unexpected image format `{image_fn.suffix}`")


def download_with_progress(url: str, download_dir: Path, chunk_size_bytes: int = 1024) -> Path:
    """Utility function for downloading files from the internet, with a handy Rich-based progress bar."""
    dest_path = download_dir / Path(url).name
    print(f"Downloading {dest_path} from `{url}`")
    if dest_path.exists():
        return dest_path

    # Otherwise --> fire an HTTP Request, with `stream = True`
    response = requests.get(url, stream=True)

    # Download w/ Transfer-Aware Progress
    #   => Reference: https://github.com/Textualize/rich/blob/master/examples/downloader.py
    with Progress(
        TextColumn("[bold]{task.description} - {task.fields[fname]}"),
        BarColumn(bar_width=None),
        "[progress.percentage]{task.percentage:>3.1f}%",
        "•",
        DownloadColumn(),
        "•",
        TransferSpeedColumn(),
        transient=True,
    ) as dl_progress:
        # Default a missing Content-Length header to 0; `int("None")` would raise a ValueError
        dl_tid = dl_progress.add_task(
            "Downloading", fname=dest_path.name, total=int(response.headers.get("content-length", 0))
        )
        with open(dest_path, "wb") as f:
            for data in response.iter_content(chunk_size=chunk_size_bytes):
                dl_progress.advance(dl_tid, f.write(data))

    return dest_path


def extract_with_progress(archive_path: Path, download_dir: Path, extract_type: str, cleanup: bool = False) -> Path:
    """Utility function for extracting compressed archives, with a handy Rich-based progress bar."""
    assert archive_path.suffix == ".zip", "Only `.zip` compressed archives are supported for now!"
    print(f"Extracting {archive_path.name} to `{download_dir}`")

    # Extract w/ Progress
    with Progress(
        TextColumn("[bold]{task.description} - {task.fields[aname]}"),
        BarColumn(bar_width=None),
        "[progress.percentage]{task.percentage:>3.1f}%",
        "•",
        MofNCompleteColumn(),
        transient=True,
    ) as ext_progress:
        with ZipFile(archive_path) as zf:
            ext_tid = ext_progress.add_task("Extracting", aname=archive_path.name, total=len(members := zf.infolist()))
            extract_path = Path(zf.extract(members[0], download_dir))
            if extract_type == "file":
                assert len(members) == 1, f"Archive `{archive_path}` with extract type `{extract_type}` has > 1 member!"
            elif extract_type == "directory":
                for member in members[1:]:
                    zf.extract(member, download_dir)
                    ext_progress.advance(ext_tid)
            else:
                raise ValueError(f"Extract type `{extract_type}` for archive `{archive_path}` is not defined!")

    # Cleanup (if specified)
    if cleanup:
        archive_path.unlink()

    return extract_path


def download_extract(dataset_id: str, root_dir: Path) -> None:
    """Download all files for a given dataset (querying registry above), extracting archives if necessary."""
    os.makedirs(download_dir := root_dir / "download" / dataset_id, exist_ok=True)

    # Download Files => Single-Threaded, with Progress Bar
    dl_tasks = [d for d in DATASET_REGISTRY[dataset_id] if not (download_dir / d["name"]).exists()]
    for dl_task in dl_tasks:
        dl_path = download_with_progress(dl_task["url"], download_dir)

        # Extract Files (if specified) --> Note (assumes ".zip" ONLY!)
        if dl_task["extract"]:
            dl_path = extract_with_progress(dl_path, download_dir, dl_task["extract_type"])
            dl_path = dl_path.parent if dl_path.is_file() else dl_path

        # Rename Path --> dl_task["name"]; make sure nested parents (e.g., `coco/`) exist first
        if dl_task["do_rename"]:
            (download_dir / dl_task["name"]).parent.mkdir(parents=True, exist_ok=True)
            shutil.move(dl_path, download_dir / dl_task["name"])


if __name__ == "__main__":
    # Set the root directory (a default download location)
    root_dir = Path("./data")
    os.makedirs(root_dir, exist_ok=True)

    # Download every registered dataset
    for dataset_id in DATASET_REGISTRY.keys():
        print(f"Starting download of dataset: {dataset_id}")
        download_extract(dataset_id, root_dir)

    print("All datasets downloaded!")
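
Note that `convert_to_jpg` is defined above but never invoked from the `__main__` block, even though its docstring says the OCR-VQA images need GIF/PNG conversion. If appended to data_download.py (so `Path` and `convert_to_jpg` are in scope), a post-download pass might look like the sketch below; the directory path is an assumption derived from the registry's `ocr_vqa/images` entry and the default `./data` root.

# Hypothetical follow-up step (not part of the original script): normalize OCR-VQA images to JPG.
ocr_vqa_dir = Path("./data/download/llava-v1.5-instruct/ocr_vqa/images")
if ocr_vqa_dir.exists():
    convert_to_jpg(ocr_vqa_dir)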
deepspeed_pretrain.py ADDED
@@ -0,0 +1,152 @@
import os
import json
import shutil
import argparse

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import deepspeed

os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"
from qwenva import tokenizer
from qwenva import processor
from qwenva import qwenva

images_file_path = "/root/autodl-tmp/images"

with open('/root/autodl-tmp/chat.json', 'r', encoding='utf-8') as f:
    chat_data = json.load(f)

image_token = tokenizer.encode('<image>')[0]
pad_token = tokenizer.pad_token_id


def process_data(sample, max_len=8012):
    conversations = sample['conversations']
    labels = []
    input_ids = []
    flag = 0
    messages = []
    image_index = -100  # default, so `image_index` is never unbound when no <image> token appears
    for index, item in enumerate(conversations):
        if item['from'] == 'human':
            old_input_ids = input_ids
            messages.append({'role': 'user', 'content': item['value']})
            input_ids = tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True
            )
            labels += [-100] * (len(input_ids) - len(old_input_ids))
            if index == flag:
                try:
                    image_index = input_ids.index(image_token)
                    labels[image_index] = image_token
                except ValueError:
                    print("image token not found")
                    flag = index + 1
                    continue
        elif item['from'] == 'gpt':
            old_input_ids = input_ids
            messages.append({'role': 'assistant', 'content': item['value']})
            input_ids = tokenizer.apply_chat_template(
                messages
            )
            flag = index + 1
            labels += input_ids[len(old_input_ids):]

    # Pad or truncate so every sample has the same length
    if len(input_ids) > max_len:
        input_ids = input_ids[:max_len]
        labels = labels[:max_len]
        attention_mask = [1] * len(input_ids)
    else:
        attention_mask = [1] * len(input_ids) + [0] * (max_len - len(input_ids))
        input_ids += [pad_token] * (max_len - len(input_ids))
        labels += [-100] * (max_len - len(labels))

    # Convert to tensors
    input_ids = torch.tensor(input_ids)
    attention_mask = torch.tensor(attention_mask)
    labels = torch.tensor(labels)
    image_index = torch.tensor(image_index)
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'image_idx': image_index
    }


class MyDataset(Dataset):
    def __init__(self, images_file_path, data, max_len=1024):
        self.images_file_path = images_file_path
        self.data = data
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        output_ = process_data(self.data[index], max_len=self.max_len)
        img_path = os.path.join(self.images_file_path, self.data[index]['image'])
        img = Image.open(img_path)
        input_pixel = processor(images=img, return_tensors="pt")
        output_['input_pixel'] = input_pixel['pixel_values'].squeeze()
        return output_


dataset = MyDataset(images_file_path, chat_data, max_len=360)
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
qwenva = qwenva.to(device)
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=qwenva,
    args=argparse.Namespace(),
    model_parameters=qwenva.parameters(),
    config_params="./deepspeed_config.json"
)
# checkpoint_path = "/root/autodl-tmp/best_model_2"
# model_engine.load_checkpoint(checkpoint_path)

# optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()
# eps = 1e-8
accumulation_steps = 2


# Training loop
def train(model_engine, train_dataloader, optimizer, loss_fn, device, epochs):
    model_engine.train()
    # model_engine.to(device)
    for epoch in range(epochs):
        # Use tqdm to display a progress bar
        with tqdm(total=len(train_dataloader), desc=f'Epoch {epoch + 1}/{epochs}', unit='batch') as pbar:
            optimizer.zero_grad()
            for batch_idx, batch in enumerate(train_dataloader):
                # Move the batch to the GPU
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                input_pixel = batch['input_pixel'].to(device)
                labels = batch['labels'].to(device)
                image_idx = batch['image_idx'].to(device)
                logits = model_engine(input_ids, attention_mask, input_pixel, image_idx)
                # Compute the loss (shift logits/labels by one for next-token prediction)
                # max_logits = logits.max(dim=-1, keepdim=True)[0]  # per-position max
                # stable_logits = logits - max_logits               # numerically stable logits
                loss = loss_fn(logits[:, :-1, :].reshape(-1, logits.shape[-1]), labels[:, 1:].reshape(-1).clone())
                # Backward pass
                model_engine.backward(loss)
                if (batch_idx + 1) % accumulation_steps == 0:
                    model_engine.step()
                pbar.update(1)
                pbar.set_postfix(loss=loss.item())  # show the current loss
                if (batch_idx + 1) % 24807 == 0:
                    # If the checkpoint folder exists, delete it and recreate it
                    if os.path.exists("/root/autodl-tmp/best_model_instruct"):
                        shutil.rmtree("/root/autodl-tmp/best_model_instruct")
                    os.makedirs("/root/autodl-tmp/best_model_instruct")
                    model_engine.save_checkpoint("/root/autodl-tmp/best_model_instruct")
                    print(f"model saved at batch {batch_idx + 1}")


train(model_engine, train_loader, optimizer, loss_fn, device, epochs=1)
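
All three training scripts hand `deepspeed.initialize` the path `./deepspeed_config.json`, which is not part of this upload. The sketch below shows what a minimal config could contain; every value is an illustrative assumption rather than the author's actual setting, except that `train_micro_batch_size_per_gpu` should match the DataLoader `batch_size` and `gradient_accumulation_steps` should match the script's `accumulation_steps`.

import json

# Hypothetical minimal DeepSpeed config (illustrative values, not the author's).
deepspeed_config = {
    "train_micro_batch_size_per_gpu": 8,  # must agree with DataLoader(batch_size=8)
    "gradient_accumulation_steps": 2,     # must agree with accumulation_steps = 2
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 2},
}
with open("./deepspeed_config.json", "w") as f:
    json.dump(deepspeed_config, f, indent=2)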
deepspeed_train_150k.py ADDED
@@ -0,0 +1,195 @@
import os
import json
import shutil
import argparse

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import deepspeed

# deepspeed.initialize(config="./deepspeed_config.json", log_level='DEBUG')
os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"
from qwenva import tokenizer
from qwenva import processor
from qwenva import qwenva

images_file_path = './data/download/llava-v1.5-instruct/coco/train2017'

with open('/root/autodl-tmp/LLaVA-Instruct-150K/llava_instruct_150k.json', 'r', encoding='utf-8') as f:
    chat_data = json.load(f)

image_token = tokenizer.encode('<image>')[0]
pad_token = tokenizer.pad_token_id


def process_data(sample, max_len=8012):
    conversations = sample['conversations']
    labels = []
    input_ids = []
    flag = 0
    messages = []
    image_index = -100  # default, so `image_index` is never unbound when no <image> token appears
    try:
        for index, item in enumerate(conversations):
            if item['from'] == 'human':
                old_input_ids = input_ids
                messages.append({'role': 'user', 'content': item['value']})
                input_ids = tokenizer.apply_chat_template(
                    messages,
                    add_generation_prompt=True
                )
                labels += [-100] * (len(input_ids) - len(old_input_ids))
                if index == flag:
                    try:
                        image_index = input_ids.index(image_token)
                        labels[image_index] = image_token
                    except ValueError:
                        print("image token not found")
                        flag = index + 1
                        continue
            elif item['from'] == 'gpt':
                old_input_ids = input_ids
                messages.append({'role': 'assistant', 'content': item['value']})
                input_ids = tokenizer.apply_chat_template(
                    messages
                )
                labels += input_ids[len(old_input_ids):]
    except Exception as e:
        print(f"error in process_data_1: {e}")
        exit()

    # Pad or truncate so every sample has the same length
    try:
        if len(input_ids) > max_len:
            input_ids = input_ids[:max_len]
            labels = labels[:max_len]
            attention_mask = [1] * len(input_ids)
        else:
            attention_mask = [1] * len(input_ids) + [0] * (max_len - len(input_ids))
            input_ids += [pad_token] * (max_len - len(input_ids))
            labels += [-100] * (max_len - len(labels))
    except Exception as e:
        print(f"error in process_data_2: {e}")
        exit()

    # Convert to tensors
    try:
        input_ids = torch.tensor(input_ids)
        attention_mask = torch.tensor(attention_mask)
        labels = torch.tensor(labels)
        image_index = torch.tensor(image_index)
    except Exception as e:
        print(f"error in tensor: {e}")
        exit()
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'image_idx': image_index
    }


class MyDataset(Dataset):
    def __init__(self, images_file_path, data, max_len=1024):
        self.images_file_path = images_file_path
        self.data = data
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        output_ = process_data(self.data[index], max_len=self.max_len)
        img_path = os.path.join(self.images_file_path, self.data[index]['image'])
        try:
            img = Image.open(img_path)
        except Exception:
            print(f"image {img_path} not found")
            output_['labels'] = torch.tensor([-100] * self.max_len)
            img = Image.new("RGB", (224, 224))  # fall back to a blank image so `img` stays defined and the batch keeps a valid shape
        input_pixel = processor(images=img, return_tensors="pt")
        output_['input_pixel'] = input_pixel['pixel_values'].squeeze()
        return output_


dataset = MyDataset(images_file_path, chat_data, max_len=2048)
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
qwenva = qwenva.to(device)
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=qwenva,
    args=argparse.Namespace(),
    model_parameters=qwenva.parameters(),
    config_params="./deepspeed_config.json"
)
# checkpoint_path = "./best_model_2"
# model_engine.load_checkpoint(checkpoint_path)
# Save the compiled model weights:
# torch.save(model_engine.module.state_dict(), "./compiled_model.pth")

# Re-enable gradients on the text embedding, lm_head, and transformer submodules
for name, param in model_engine.module._orig_mod.text_embedding.named_parameters():
    param.requires_grad = True
    # print("embedding gradients enabled:", name)
# for name, param in model_engine.module._orig_mod.align_layer.named_parameters():
#     param.requires_grad = True
#     print("align_layer gradients enabled:", name)

for name, param in model_engine.module._orig_mod.lm_head.named_parameters():
    param.requires_grad = True
    # print("lm_head gradients enabled:", name)

for name, param in model_engine.module._orig_mod.transformer.named_parameters():
    param.requires_grad = True
    # print("transformer gradients enabled:", name)
for name, param in model_engine.module._orig_mod.named_parameters():
    if param.requires_grad:
        print(f"Layer: {name}, Requires Grad: {param.requires_grad}")


# optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()
# eps = 1e-8
accumulation_steps = 1


# Training loop
def train(model_engine, train_dataloader, loss_fn, device, epochs):
    model_engine.train()
    # model_engine.to(device)
    for epoch in range(epochs):
        # Use tqdm to display a progress bar
        with tqdm(total=len(train_dataloader), desc=f'Epoch {epoch + 1}/{epochs}', unit='batch') as pbar:
            # optimizer.zero_grad()
            try:
                for batch_idx, batch in enumerate(train_dataloader):
                    # Move the batch to the GPU
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    input_pixel = batch['input_pixel'].to(device)
                    labels = batch['labels'].to(device)
                    image_idx = batch['image_idx'].to(device)
                    logits = model_engine(input_ids, attention_mask, input_pixel, image_idx)
                    # Compute the loss (subtract the per-position max for numerical stability)
                    max_logits = logits.max(dim=-1, keepdim=True)[0]
                    stable_logits = logits - max_logits
                    loss = loss_fn(stable_logits[:, :-1, :].reshape(-1, stable_logits.shape[-1]), labels[:, 1:].reshape(-1).clone())
                    model_engine.backward(loss)
                    if (batch_idx + 1) % accumulation_steps == 0:
                        model_engine.step()
                    pbar.update(1)
                    pbar.set_postfix(loss=loss.item())  # show the current loss
                    if (batch_idx + 1) % 6000 == 0:
                        # If the checkpoint folder exists, delete it and recreate it
                        if os.path.exists("./best_model_2"):
                            shutil.rmtree("./best_model_2")
                        os.makedirs("./best_model_2")
                        model_engine.save_checkpoint("./best_model_2")
                        torch.save(model_engine.module.state_dict(), "./compiled_model_2.pth")
                        print(f"model saved at batch {batch_idx + 1}")
            except Exception as e:
                print(f"error in train: {e}")


train(model_engine, train_loader, loss_fn, device, epochs=2)
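
The label construction in `process_data` supervises only the assistant turns: every user-prompt position is masked with -100, which `nn.CrossEntropyLoss` ignores by default (its `ignore_index` is -100). A tokenizer-free toy sketch of that convention, with invented token ids:

# Toy illustration (invented ids): 5 prompt tokens followed by a 3-token reply.
prompt_len = 5
input_ids = [101, 7, 42, 9, 102, 55, 56, 57]
labels = [-100] * prompt_len + input_ids[prompt_len:]  # mask the prompt, supervise the reply
assert labels == [-100, -100, -100, -100, -100, 55, 56, 57]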
deepspeed_train_665k.py ADDED
@@ -0,0 +1,189 @@
import os
import json
import shutil
import argparse

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import deepspeed

# deepspeed.initialize(config="./deepspeed_config.json", log_level='DEBUG')
os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"
from qwenva import tokenizer
from qwenva import processor
from qwenva import qwenva

images_file_path = './data/download/llava-v1.5-instruct'

with open('/root/autodl-tmp/LLaVA-Instruct-150K/qwenva_mix665k.json', 'r', encoding='utf-8') as f:
    chat_data = json.load(f)

image_token = tokenizer.encode('<image>')[0]
pad_token = tokenizer.pad_token_id


def process_data(sample, max_len=8012):
    conversations = sample['conversations']
    labels = []
    input_ids = []
    flag = 0
    messages = []
    image_index = -100  # default: text-only samples in the mix665k split carry no <image> token
    try:
        for index, item in enumerate(conversations):
            if item['from'] == 'human':
                old_input_ids = input_ids
                messages.append({'role': 'user', 'content': item['value']})
                input_ids = tokenizer.apply_chat_template(
                    messages,
                    add_generation_prompt=True
                )
                labels += [-100] * (len(input_ids) - len(old_input_ids))
                if index == flag:
                    if image_token in input_ids:
                        image_index = input_ids.index(image_token)
                        labels[image_index] = image_token
                    else:
                        image_index = -100
            elif item['from'] == 'gpt':
                old_input_ids = input_ids
                messages.append({'role': 'assistant', 'content': item['value']})
                input_ids = tokenizer.apply_chat_template(
                    messages
                )
                labels += input_ids[len(old_input_ids):]
    except Exception as e:
        print(f"error in process_data_1: {e}")
        exit()

    # Pad or truncate so every sample has the same length
    try:
        if len(input_ids) > max_len:
            input_ids = input_ids[:max_len]
            labels = labels[:max_len]
            attention_mask = [1] * len(input_ids)
        else:
            attention_mask = [1] * len(input_ids) + [0] * (max_len - len(input_ids))
            input_ids += [pad_token] * (max_len - len(input_ids))
            labels += [-100] * (max_len - len(labels))
    except Exception as e:
        print(f"error in process_data_2: {e}")
        exit()

    # Convert to tensors
    try:
        input_ids = torch.tensor(input_ids)
        attention_mask = torch.tensor(attention_mask)
        labels = torch.tensor(labels)
        image_index = torch.tensor(image_index)
    except Exception as e:
        print(f"error in tensor: {e}")
        exit()
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'image_idx': image_index
    }


class MyDataset(Dataset):
    def __init__(self, images_file_path, data, max_len=1024):
        self.images_file_path = images_file_path
        self.data = data
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        output_ = process_data(self.data[index], max_len=self.max_len)
        if output_['image_idx'] != -100:
            img_path = os.path.join(self.images_file_path, self.data[index]['image'])
            img = Image.open(img_path)
            input_pixel = processor(images=img, return_tensors="pt")
            output_['input_pixel'] = input_pixel['pixel_values'].squeeze()
        else:
            # Text-only sample: feed an all-zero image; pixel values stay floating-point
            # (the original cast them to the integer dtype of `input_ids`, which the vision tower would reject)
            output_['input_pixel'] = torch.zeros(3, 224, 224)
        return output_


dataset = MyDataset(images_file_path, chat_data, max_len=2048)
train_loader = DataLoader(dataset, batch_size=16, shuffle=True)

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=qwenva,
    args=argparse.Namespace(),
    model_parameters=qwenva.parameters(),
    config_params="./deepspeed_config.json"
)
# checkpoint_path = "./best_model_2"
# model_engine.load_checkpoint(checkpoint_path)
# Save the compiled model weights:
# torch.save(model_engine.module.state_dict(), "./compiled_model.pth")

# Re-enable gradients on the text embedding, lm_head, and transformer submodules
for name, param in model_engine.module._orig_mod.text_embedding.named_parameters():
    param.requires_grad = True
    # print("embedding gradients enabled:", name)
# for name, param in model_engine.module._orig_mod.align_layer.named_parameters():
#     param.requires_grad = True
#     print("align_layer gradients enabled:", name)

for name, param in model_engine.module._orig_mod.lm_head.named_parameters():
    param.requires_grad = True
    # print("lm_head gradients enabled:", name)

for name, param in model_engine.module._orig_mod.transformer.named_parameters():
    param.requires_grad = True
    # print("transformer gradients enabled:", name)
for name, param in model_engine.module._orig_mod.named_parameters():
    if param.requires_grad:
        print(f"Layer: {name}, Requires Grad: {param.requires_grad}")


# optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()
# eps = 1e-8
accumulation_steps = 1


# Training loop
def train(model_engine, train_dataloader, loss_fn, device, epochs):
    model_engine.train()
    # model_engine.to(device)
    for epoch in range(epochs):
        # Use tqdm to display a progress bar
        with tqdm(total=len(train_dataloader), desc=f'Epoch {epoch + 1}/{epochs}', unit='batch') as pbar:
            # optimizer.zero_grad()
            try:
                for batch_idx, batch in enumerate(train_dataloader):
                    # Move the batch to the GPU
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    input_pixel = batch['input_pixel'].to(device)
                    labels = batch['labels'].to(device)
                    image_idx = batch['image_idx'].to(device)
                    logits = model_engine(input_ids, attention_mask, input_pixel, image_idx)
                    # Compute the loss (shift logits/labels by one for next-token prediction)
                    # max_logits = logits.max(dim=-1, keepdim=True)[0]  # per-position max
                    # stable_logits = logits - max_logits               # numerically stable logits
                    loss = loss_fn(logits[:, :-1, :].reshape(-1, logits.shape[-1]), labels[:, 1:].reshape(-1).clone())
                    model_engine.backward(loss)
                    if (batch_idx + 1) % accumulation_steps == 0:
                        model_engine.step()
                    pbar.update(1)
                    pbar.set_postfix(loss=loss.item())  # show the current loss
                    if (batch_idx + 1) % 4100 == 0:
                        # If the checkpoint folder exists, delete it and recreate it
                        if os.path.exists("./best_model_2"):
                            shutil.rmtree("./best_model_2")
                        os.makedirs("./best_model_2")
                        model_engine.save_checkpoint("./best_model_2")
                        torch.save(model_engine.module.state_dict(), "./compiled_model_3.pth")
                        print(f"model saved at batch {batch_idx + 1}")
            except Exception as e:
                print(f"error in train: {e}")


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train(model_engine, train_loader, loss_fn, device, epochs=2)
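
In `train`, the loss call drops the last logit and the first label so that the logit at position t is scored against the token at position t+1. A self-contained sketch of that alignment with random tensors:

import torch
import torch.nn as nn

# Shift-by-one next-token loss, as used in train() above.
batch, seq, vocab = 2, 8, 32
logits = torch.randn(batch, seq, vocab)
labels = torch.randint(0, vocab, (batch, seq))
labels[:, :3] = -100                       # masked prompt positions, as produced by process_data
loss_fn = nn.CrossEntropyLoss()            # default ignore_index=-100 skips masked targets
loss = loss_fn(
    logits[:, :-1, :].reshape(-1, vocab),  # positions 0..seq-2 predict ...
    labels[:, 1:].reshape(-1),             # ... tokens 1..seq-1
)
print(loss.item())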
download2.py ADDED
@@ -0,0 +1,217 @@
import os
import shutil
from pathlib import Path
from typing import Dict, List, TypedDict
from zipfile import ZipFile

import requests
from PIL import Image
from rich.progress import BarColumn, DownloadColumn, MofNCompleteColumn, Progress, TextColumn, TransferSpeedColumn
from tqdm import tqdm

# Registry entries below are kept for reference only (commented out of DATASET_REGISTRY):
"""
{
    "name": "coco/train2017",   # Visual Instruct Tuning images are all sourced from COCO Train 2017
    "extract": True,
    "extract_type": "directory",
    "url": "http://images.cocodataset.org/zips/train2017.zip",
    "do_rename": True,
},
{
    "name": "gqa/images",
    "extract": True,
    "extract_type": "directory",
    "url": "https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip",
    "do_rename": True,
},
{
    "name": "ocr_vqa/images",
    "extract": True,
    "extract_type": "directory",
    "url": "https://hf-mirror.com/datasets/qnguyen3/ocr_vqa/resolve/main/ocr_vqa.zip",
    "do_rename": True,
},
{
    "name": "textvqa/train_images",
    "extract": True,
    "extract_type": "directory",
    "url": "https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip",
    "do_rename": True,
},
{
    "name": "vg/VG_100K_2",
    "extract": True,
    "extract_type": "directory",
    "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip",
    "do_rename": True,
},
"""

# === Dataset Registry w/ Links ===
# fmt: off
DatasetComponent = TypedDict(
    "DatasetComponent",
    {"name": str, "extract": bool, "extract_type": str, "url": str, "do_rename": bool},
    total=False,
)

DATASET_REGISTRY: Dict[str, List[DatasetComponent]] = {
    # === LLaVa v1.5 Dataset(s) ===
    "llava-v1.5-instruct": [
        {
            "name": "vg/VG_100K",
            "extract": True,
            "extract_type": "directory",
            "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip",
            "do_rename": True,
        }
    ]
}
# fmt: on


def convert_to_jpg(image_dir: Path) -> None:
    """Handling for OCR-VQA images specifically; iterates through the directory and converts all GIFs/PNGs to JPG."""
    print(f"Converting all images in `{image_dir}` to JPG")

    for image_fn in tqdm(list(image_dir.iterdir())):
        jpg_fn = image_dir / f"{image_fn.stem}.jpg"  # build the target JPG filename
        if image_fn.suffix in {".jpg", ".jpeg"} or jpg_fn.exists():
            continue

        if image_fn.suffix == ".gif":
            gif = Image.open(image_fn)
            gif.seek(0)
            gif.convert("RGB").save(jpg_fn)
        elif image_fn.suffix == ".png":
            Image.open(image_fn).convert("RGB").save(jpg_fn)
        else:
            raise ValueError(f"Unexpected image format `{image_fn.suffix}`")


# DatasetComponent and DATASET_REGISTRY are unchanged from data_download.py

def download_with_progress(url: str, download_dir: Path, chunk_size_bytes: int = 1024) -> Path:
    """Download a file with retries, resuming partial downloads via HTTP Range requests."""
    print(f"Downloading {url}")

    dest_path = download_dir / Path(url).name
    if dest_path.exists():
        return dest_path

    max_retries = 5
    for attempt in range(max_retries):
        try:
            # Resume from whatever a failed earlier attempt left on disk; the original
            # `resume_header` was never populated, so retries re-appended from byte 0
            downloaded = dest_path.stat().st_size if dest_path.exists() else 0
            resume_header = {"Range": f"bytes={downloaded}-"} if downloaded else {}
            response = requests.get(url, headers=resume_header, stream=True)
            if response.status_code not in (200, 206):
                raise Exception(f"Failed to download. Status code: {response.status_code}")

            # A 200 means the server ignored the Range header; start over rather than append
            mode = "ab" if response.status_code == 206 else "wb"
            if mode == "wb":
                downloaded = 0

            # Download progress bar
            with Progress(
                TextColumn("[bold]{task.description} - {task.fields[fname]}"),
                BarColumn(bar_width=None),
                "[progress.percentage]{task.percentage:>3.1f}%",
                "•",
                DownloadColumn(),
                "•",
                TransferSpeedColumn(),
                transient=True,
            ) as dl_progress:
                total = int(response.headers.get("content-length", 0)) + downloaded
                dl_tid = dl_progress.add_task(
                    "Downloading", fname=dest_path.name, total=total, completed=downloaded
                )
                with open(dest_path, mode) as f:
                    for data in response.iter_content(chunk_size=chunk_size_bytes):
                        # Advance by the bytes actually written, not the nominal chunk size
                        dl_progress.advance(dl_tid, f.write(data))

            return dest_path

        except Exception as e:
            print(f"Attempt {attempt + 1}/{max_retries} failed: {e}")
            if attempt < max_retries - 1:
                print("Retrying...")
            else:
                raise

# The remaining functions and the main entry point are unchanged


def extract_with_progress(archive_path: Path, download_dir: Path, extract_type: str, cleanup: bool = False) -> Path:
    """Utility function for extracting compressed archives, with a handy Rich-based progress bar."""
    assert archive_path.suffix == ".zip", "Only `.zip` compressed archives are supported for now!"
    print(f"Extracting {archive_path.name} to `{download_dir}`")

    with Progress(
        TextColumn("[bold]{task.description} - {task.fields[aname]}"),
        BarColumn(bar_width=None),
        "[progress.percentage]{task.percentage:>3.1f}%",
        "•",
        MofNCompleteColumn(),
        transient=True,
    ) as ext_progress:
        with ZipFile(archive_path) as zf:
            ext_tid = ext_progress.add_task("Extracting", aname=archive_path.name, total=len(members := zf.infolist()))
            extract_path = Path(zf.extract(members[0], download_dir))
            if extract_type == "file":
                assert len(members) == 1, f"Archive `{archive_path}` with extract type `{extract_type}` has > 1 member!"
            elif extract_type == "directory":
                for member in members[1:]:
                    zf.extract(member, download_dir)
                    ext_progress.advance(ext_tid)
            else:
                raise ValueError(f"Extract type `{extract_type}` for archive `{archive_path}` is not defined!")

    if cleanup:
        archive_path.unlink()

    return extract_path


def download_extract(dataset_id: str, root_dir: Path) -> None:
    """Download all files for a given dataset (querying registry above), extracting archives if necessary."""
    os.makedirs(download_dir := root_dir / "download" / dataset_id, exist_ok=True)

    # Download Files
    dl_tasks = [d for d in DATASET_REGISTRY[dataset_id] if not (download_dir / d["name"]).exists()]
    for dl_task in dl_tasks:
        dl_path = download_with_progress(dl_task["url"], download_dir)

        if dl_task["extract"]:
            dl_path = extract_with_progress(dl_path, download_dir, dl_task["extract_type"])
            dl_path = dl_path.parent if dl_path.is_file() else dl_path

        if dl_task["do_rename"]:
            (download_dir / dl_task["name"]).parent.mkdir(parents=True, exist_ok=True)
            shutil.move(dl_path, download_dir / dl_task["name"])


if __name__ == "__main__":
    # Set the root directory (a default download location)
    root_dir = Path("./data")
    os.makedirs(root_dir, exist_ok=True)

    # Download every registered dataset
    for dataset_id in DATASET_REGISTRY.keys():
        print(f"Starting download of dataset: {dataset_id}")
        download_extract(dataset_id, root_dir)

    print("All datasets downloaded!")
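
One caveat: a file that already exists on disk is returned without further checks, so a partial file left by a crashed earlier run is treated as complete. A quick way to sanity-check a download against the server-reported size (a sketch; it assumes the server sets Content-Length on HEAD requests):

import requests
from pathlib import Path

# Sketch: compare on-disk size with the server's Content-Length before trusting a file.
url = "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip"
dest = Path("./data/download/llava-v1.5-instruct") / Path(url).name
expected = int(requests.head(url, allow_redirects=True).headers.get("content-length", 0))
if expected and dest.exists() and dest.stat().st_size != expected:
    print(f"{dest.name}: have {dest.stat().st_size} bytes, expected {expected}")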