jiangchengchengNLP
committed on
Upload 5 files
files for train
- data_download.py +198 -0
- deepspeed_pretrain.py +152 -0
- deepspeed_train_150k.py +195 -0
- deepspeed_train_665k.py +189 -0
- download2.py +217 -0
data_download.py
ADDED
@@ -0,0 +1,198 @@
"""
download.py

Utility functions for downloading and extracting various datasets to (local) disk.
"""
import os
import shutil
from pathlib import Path
from typing import Dict, List, TypedDict
from zipfile import ZipFile

import requests
from PIL import Image
from rich.progress import BarColumn, DownloadColumn, MofNCompleteColumn, Progress, TextColumn, TransferSpeedColumn
from tqdm import tqdm

# from prismatic.overwatch import initialize_overwatch

# Initialize Overwatch =>> Wraps `logging.Logger`
# overwatch = initialize_overwatch(__name__)


# === Dataset Registry w/ Links ===
# fmt: off
DatasetComponent = TypedDict(
    "DatasetComponent",
    {"name": str, "extract": bool, "extract_type": str, "url": str, "do_rename": bool},
    total=False
)

DATASET_REGISTRY: Dict[str, List[DatasetComponent]] = {
    # === LLaVa v1.5 Dataset(s) ===

    # Note =>> This is the full suite of datasets included in the LLaVa 1.5 "finetuning" stage; all the LLaVa v1.5
    #          models are finetuned on this split. We use this dataset for all experiments in our paper.
    "llava-v1.5-instruct": [
        {
            "name": "coco/train2017",  # Visual Instruct Tuning images are all sourced from COCO Train 2017
            "extract": True,
            "extract_type": "directory",
            "url": "http://images.cocodataset.org/zips/train2017.zip",
            "do_rename": True,
        },
        {
            "name": "gqa/images",
            "extract": True,
            "extract_type": "directory",
            "url": "https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip",
            "do_rename": True,
        },
        {
            "name": "ocr_vqa/images",
            "extract": True,
            "extract_type": "directory",
            "url": "https://hf-mirror.com/datasets/qnguyen3/ocr_vqa/resolve/main/ocr_vqa.zip",
            "do_rename": True,
        },
        {
            "name": "textvqa/train_images",
            "extract": True,
            "extract_type": "directory",
            "url": "https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip",
            "do_rename": True,
        },
        {
            "name": "vg/VG_100K",
            "extract": True,
            "extract_type": "directory",
            "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip",
            "do_rename": True,
        },
        {
            "name": "vg/VG_100K_2",
            "extract": True,
            "extract_type": "directory",
            "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip",
            "do_rename": True,
        },
    ]
}
# fmt: on


def convert_to_jpg(image_dir: Path) -> None:
    """Handling for OCR-VQA Images specifically; iterates through directory, converts all GIFs/PNGs."""
    print(f"Converting all Images in `{image_dir}` to JPG")

    for image_fn in tqdm(list(image_dir.iterdir())):
        if image_fn.suffix in {".jpg", ".jpeg"} or (jpg_fn := image_dir / f"{image_fn.stem}.jpg").exists():
            continue

        if image_fn.suffix == ".gif":
            gif = Image.open(image_fn)
            gif.seek(0)
            gif.convert("RGB").save(jpg_fn)
        elif image_fn.suffix == ".png":
            Image.open(image_fn).convert("RGB").save(jpg_fn)
        else:
            raise ValueError(f"Unexpected image format `{image_fn.suffix}`")


def download_with_progress(url: str, download_dir: Path, chunk_size_bytes: int = 1024) -> Path:
    """Utility function for downloading files from the internet, with a handy Rich-based progress bar."""
    # Note: the original `overwatch` logging call took `ctx_level=1`; plain `print` does not, so it is dropped here.
    print(f"Downloading {(dest_path := download_dir / Path(url).name)} from `{url}`")
    if dest_path.exists():
        return dest_path

    # Otherwise --> fire an HTTP Request, with `stream = True`
    response = requests.get(url, stream=True)

    # Download w/ Transfer-Aware Progress
    #   => Reference: https://github.com/Textualize/rich/blob/master/examples/downloader.py
    with Progress(
        TextColumn("[bold]{task.description} - {task.fields[fname]}"),
        BarColumn(bar_width=None),
        "[progress.percentage]{task.percentage:>3.1f}%",
        "•",
        DownloadColumn(),
        "•",
        TransferSpeedColumn(),
        transient=True,
    ) as dl_progress:
        # Guard against a missing Content-Length header (the original `int("None")` fallback would raise)
        content_length = response.headers.get("content-length")
        dl_tid = dl_progress.add_task(
            "Downloading", fname=dest_path.name, total=int(content_length) if content_length else None
        )
        with open(dest_path, "wb") as f:
            for data in response.iter_content(chunk_size=chunk_size_bytes):
                dl_progress.advance(dl_tid, f.write(data))

    return dest_path


def extract_with_progress(archive_path: Path, download_dir: Path, extract_type: str, cleanup: bool = False) -> Path:
    """Utility function for extracting compressed archives, with a handy Rich-based progress bar."""
    assert archive_path.suffix == ".zip", "Only `.zip` compressed archives are supported for now!"
    print(f"Extracting {archive_path.name} to `{download_dir}`")

    # Extract w/ Progress
    with Progress(
        TextColumn("[bold]{task.description} - {task.fields[aname]}"),
        BarColumn(bar_width=None),
        "[progress.percentage]{task.percentage:>3.1f}%",
        "•",
        MofNCompleteColumn(),
        transient=True,
    ) as ext_progress:
        with ZipFile(archive_path) as zf:
            ext_tid = ext_progress.add_task("Extracting", aname=archive_path.name, total=len(members := zf.infolist()))
            extract_path = Path(zf.extract(members[0], download_dir))
            if extract_type == "file":
                assert len(members) == 1, f"Archive `{archive_path}` with extract type `{extract_type}` has > 1 member!"
            elif extract_type == "directory":
                for member in members[1:]:
                    zf.extract(member, download_dir)
                    ext_progress.advance(ext_tid)
            else:
                raise ValueError(f"Extract type `{extract_type}` for archive `{archive_path}` is not defined!")

    # Cleanup (if specified)
    if cleanup:
        archive_path.unlink()

    return extract_path


def download_extract(dataset_id: str, root_dir: Path) -> None:
    """Download all files for a given dataset (querying registry above), extracting archives if necessary."""
    os.makedirs(download_dir := root_dir / "download" / dataset_id, exist_ok=True)

    # Download Files => Single-Threaded, with Progress Bar
    dl_tasks = [d for d in DATASET_REGISTRY[dataset_id] if not (download_dir / d["name"]).exists()]
    for dl_task in dl_tasks:
        dl_path = download_with_progress(dl_task["url"], download_dir)

        # Extract Files (if specified) --> Note (assumes ".zip" ONLY!)
        if dl_task["extract"]:
            dl_path = extract_with_progress(dl_path, download_dir, dl_task["extract_type"])
            dl_path = dl_path.parent if dl_path.is_file() else dl_path

        # Rename Path --> dl_task["name"]
        if dl_task["do_rename"]:
            shutil.move(dl_path, download_dir / dl_task["name"])


if __name__ == "__main__":
    # Set the root directory (a default download location)
    root_dir = Path("./data")
    os.makedirs(root_dir, exist_ok=True)

    # Download every dataset in the registry
    for dataset_id in DATASET_REGISTRY.keys():
        print(f"Starting download of dataset: {dataset_id}")
        download_extract(dataset_id, root_dir)

    print("All datasets downloaded!")
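
Note: `convert_to_jpg` is defined above but never called from `__main__`, even though the OCR-VQA archive ships a mix of GIFs and PNGs. A minimal sketch of the post-download step this presumably implies; the path is derived from the registry's `ocr_vqa/images` entry, not confirmed by the author:

# Hypothetical post-download cleanup for OCR-VQA images:
# convert_to_jpg(Path("./data/download/llava-v1.5-instruct/ocr_vqa/images"))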
deepspeed_pretrain.py
ADDED
@@ -0,0 +1,152 @@
import os
import json
import shutil
import argparse

os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"

import torch
import torch.nn as nn
import deepspeed
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm

from qwenva import tokenizer
from qwenva import processor
from qwenva import qwenva

images_file_path = "/root/autodl-tmp/images"

with open('/root/autodl-tmp/chat.json', 'r', encoding='utf-8') as f:
    chat_data = json.load(f)

image_token = tokenizer.encode('<image>')[0]
pad_token = tokenizer.pad_token_id


def process_data(sample, max_len=8012):
    """Tokenize one multi-turn conversation, masking user turns out of the loss."""
    conversations = sample['conversations']
    labels = []
    input_ids = []
    flag = 0
    messages = []
    image_index = -100  # fallback so the tensor conversion below never sees an unbound name
    for index, item in enumerate(conversations):
        if item['from'] == 'human':
            old_input_ids = input_ids
            messages.append({'role': 'user', 'content': item['value']})
            input_ids = tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True
            )
            # User tokens are masked with -100 (ignored by CrossEntropyLoss)
            labels += [-100] * (len(input_ids) - len(old_input_ids))
            if index == flag:
                try:
                    image_index = input_ids.index(image_token)
                    labels[image_index] = image_token
                except ValueError:
                    print("image token not found")
                flag = index + 1
                continue
        elif item['from'] == 'gpt':
            old_input_ids = input_ids
            messages.append({'role': 'assistant', 'content': item['value']})
            input_ids = tokenizer.apply_chat_template(
                messages
            )
            flag = index + 1
            # Assistant tokens are supervised
            labels += input_ids[len(old_input_ids):]
    # Pad or truncate so every sample has the same length
    if len(input_ids) > max_len:
        input_ids = input_ids[:max_len]
        labels = labels[:max_len]
        attention_mask = [1] * len(input_ids)
    else:
        attention_mask = [1] * len(input_ids) + [0] * (max_len - len(input_ids))
        input_ids += [pad_token] * (max_len - len(input_ids))
        labels += [-100] * (max_len - len(labels))
    # Convert to tensors
    input_ids = torch.tensor(input_ids)
    attention_mask = torch.tensor(attention_mask)
    labels = torch.tensor(labels)
    image_index = torch.tensor(image_index)
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'image_idx': image_index
    }


class MyDataset(Dataset):
    def __init__(self, images_file_path, data, max_len=1024):
        self.images_file_path = images_file_path
        self.data = data
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        output_ = process_data(self.data[index], max_len=self.max_len)
        img_path = os.path.join(self.images_file_path, self.data[index]['image'])
        img = Image.open(img_path)
        input_pixel = processor(images=img, return_tensors="pt")
        output_['input_pixel'] = input_pixel['pixel_values'].squeeze()
        return output_


dataset = MyDataset(images_file_path, chat_data, max_len=360)
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
qwenva = qwenva.to(device)
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=qwenva,
    args=argparse.Namespace(),
    model_parameters=qwenva.parameters(),
    config_params="./deepspeed_config.json"
)
# checkpoint_path = "/root/autodl-tmp/best_model_2"
# model_engine.load_checkpoint(checkpoint_path)

loss_fn = nn.CrossEntropyLoss()
accumulation_steps = 2


# Training loop
def train(model_engine, train_dataloader, optimizer, loss_fn, device, epochs):
    model_engine.train()
    for epoch in range(epochs):
        # tqdm progress bar over the epoch
        with tqdm(total=len(train_dataloader), desc=f'Epoch {epoch + 1}/{epochs}', unit='batch') as pbar:
            optimizer.zero_grad()
            for batch_idx, batch in enumerate(train_dataloader):
                # Move the batch onto the GPU
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                input_pixel = batch['input_pixel'].to(device)
                labels = batch['labels'].to(device)
                image_idx = batch['image_idx'].to(device)
                logits = model_engine(input_ids, attention_mask, input_pixel, image_idx)
                # Next-token prediction loss: shift logits left, labels right
                loss = loss_fn(logits[:, :-1, :].reshape(-1, logits.shape[-1]), labels[:, 1:].reshape(-1).clone())
                # Backward pass (DeepSpeed handles loss scaling / accumulation internally)
                model_engine.backward(loss)
                if (batch_idx + 1) % accumulation_steps == 0:
                    model_engine.step()
                pbar.update(1)
                pbar.set_postfix(loss=loss.item())  # show the current loss
                if (batch_idx + 1) % 24807 == 0:
                    # Recreate the checkpoint folder from scratch before saving
                    if os.path.exists("/root/autodl-tmp/best_model_instruct"):
                        shutil.rmtree("/root/autodl-tmp/best_model_instruct")
                    os.makedirs("/root/autodl-tmp/best_model_instruct")
                    model_engine.save_checkpoint("/root/autodl-tmp/best_model_instruct")
                    print(f"model saved at batch {batch_idx + 1}")


train(model_engine, train_loader, optimizer, loss_fn, device, epochs=1)
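
Note: all three training scripts hand `deepspeed.initialize` the path `./deepspeed_config.json`, which is not part of this upload. A minimal sketch of a compatible config, passed as a dict instead of a file; every value below is an assumption, except that `train_batch_size` must equal micro-batch x accumulation steps x world size, which for this script's `batch_size=8` and `accumulation_steps=2` on one GPU gives 16:

# Hypothetical DeepSpeed config (illustrative values, not the author's actual file)
ds_config = {
    "train_batch_size": 16,                  # 8 (micro) * 2 (accumulation) * 1 (GPU)
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 2,
    "fp16": {"enabled": True},               # assumption
    "zero_optimization": {"stage": 2},       # assumption
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},  # assumption
}
# deepspeed.initialize also accepts the dict directly:
# model_engine, optimizer, _, _ = deepspeed.initialize(
#     model=qwenva, model_parameters=qwenva.parameters(), config_params=ds_config)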
deepspeed_train_150k.py
ADDED
@@ -0,0 +1,195 @@
import os
import json
import shutil
import argparse

import deepspeed
# deepspeed.initialize(config="./deepspeed_config.json", log_level='DEBUG')
from tqdm import tqdm

os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"
from qwenva import tokenizer
from qwenva import processor
from qwenva import qwenva

images_file_path = './data/download/llava-v1.5-instruct/coco/train2017'

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image

with open('/root/autodl-tmp/LLaVA-Instruct-150K/llava_instruct_150k.json', 'r', encoding='utf-8') as f:
    chat_data = json.load(f)

image_token = tokenizer.encode('<image>')[0]
pad_token = tokenizer.pad_token_id


def process_data(sample, max_len=8012):
    """Tokenize one multi-turn conversation, masking user turns out of the loss."""
    conversations = sample['conversations']
    labels = []
    input_ids = []
    flag = 0
    messages = []
    image_index = -100  # fallback so the tensor conversion below never sees an unbound name
    try:
        for index, item in enumerate(conversations):
            if item['from'] == 'human':
                old_input_ids = input_ids
                messages.append({'role': 'user', 'content': item['value']})
                input_ids = tokenizer.apply_chat_template(
                    messages,
                    add_generation_prompt=True
                )
                # User tokens are masked with -100 (ignored by CrossEntropyLoss)
                labels += [-100] * (len(input_ids) - len(old_input_ids))
                if index == flag:
                    try:
                        image_index = input_ids.index(image_token)
                        labels[image_index] = image_token
                    except ValueError:
                        print("image token not found")
                    flag = index + 1
                    continue
            elif item['from'] == 'gpt':
                old_input_ids = input_ids
                messages.append({'role': 'assistant', 'content': item['value']})
                input_ids = tokenizer.apply_chat_template(
                    messages
                )
                labels += input_ids[len(old_input_ids):]
    except Exception as e:
        print(f"error in process_data_1: {e}")
        exit()
    # Pad or truncate so every sample has the same length
    try:
        if len(input_ids) > max_len:
            input_ids = input_ids[:max_len]
            labels = labels[:max_len]
            attention_mask = [1] * len(input_ids)
        else:
            attention_mask = [1] * len(input_ids) + [0] * (max_len - len(input_ids))
            input_ids += [pad_token] * (max_len - len(input_ids))
            labels += [-100] * (max_len - len(labels))
    except Exception as e:
        print(f"error in process_data_2: {e}")
        exit()
    # Convert to tensors
    try:
        input_ids = torch.tensor(input_ids)
        attention_mask = torch.tensor(attention_mask)
        labels = torch.tensor(labels)
        image_index = torch.tensor(image_index)
    except Exception as e:
        print(f"error in tensor: {e}")
        exit()
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'image_idx': image_index
    }


class MyDataset(Dataset):
    def __init__(self, images_file_path, data, max_len=1024):
        self.images_file_path = images_file_path
        self.data = data
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        output_ = process_data(self.data[index], max_len=self.max_len)
        img_path = os.path.join(self.images_file_path, self.data[index]['image'])
        try:
            img = Image.open(img_path)
        except Exception:
            print(f"image {img_path} not found")
            output_['labels'] = torch.tensor([-100] * self.max_len)
            # Placeholder image (assumption) so the processor call below still succeeds;
            # the labels above are fully masked, so this sample contributes no loss.
            img = Image.new("RGB", (224, 224))
        input_pixel = processor(images=img, return_tensors="pt")
        output_['input_pixel'] = input_pixel['pixel_values'].squeeze()
        return output_


dataset = MyDataset(images_file_path, chat_data, max_len=2048)
train_loader = DataLoader(dataset, batch_size=8, shuffle=True)

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
qwenva = qwenva.to(device)
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=qwenva,
    args=argparse.Namespace(),
    model_parameters=qwenva.parameters(),
    config_params="./deepspeed_config.json"
)
# checkpoint_path = "./best_model_2"
# model_engine.load_checkpoint(checkpoint_path)
# Save the compiled model weights:
# torch.save(model_engine.module.state_dict(), "./compiled_model.pth")

# Unfreeze the text embedding, LM head, and transformer blocks
# (`_orig_mod` is the attribute torch.compile places on a wrapped module)
for name, param in model_engine.module._orig_mod.text_embedding.named_parameters():
    param.requires_grad = True
# for name, param in model_engine.module._orig_mod.align_layer.named_parameters():
#     param.requires_grad = True

for name, param in model_engine.module._orig_mod.lm_head.named_parameters():
    param.requires_grad = True

for name, param in model_engine.module._orig_mod.transformer.named_parameters():
    param.requires_grad = True

for name, param in model_engine.module._orig_mod.named_parameters():
    if param.requires_grad:
        print(f"Layer: {name}, Requires Grad: {param.requires_grad}")


loss_fn = nn.CrossEntropyLoss()
accumulation_steps = 1


# Training loop
def train(model_engine, train_dataloader, loss_fn, device, epochs):
    model_engine.train()
    for epoch in range(epochs):
        # tqdm progress bar over the epoch
        with tqdm(total=len(train_dataloader), desc=f'Epoch {epoch + 1}/{epochs}', unit='batch') as pbar:
            try:
                for batch_idx, batch in enumerate(train_dataloader):
                    # Move the batch onto the GPU
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    input_pixel = batch['input_pixel'].to(device)
                    labels = batch['labels'].to(device)
                    image_idx = batch['image_idx'].to(device)
                    logits = model_engine(input_ids, attention_mask, input_pixel, image_idx)
                    # Subtract the row-wise max for numerical stability, then compute the shifted loss
                    max_logits = logits.max(dim=-1, keepdim=True)[0]
                    stable_logits = logits - max_logits
                    loss = loss_fn(stable_logits[:, :-1, :].reshape(-1, stable_logits.shape[-1]), labels[:, 1:].reshape(-1).clone())
                    model_engine.backward(loss)
                    if (batch_idx + 1) % accumulation_steps == 0:
                        model_engine.step()
                    pbar.update(1)
                    pbar.set_postfix(loss=loss.item())  # show the current loss
                    if (batch_idx + 1) % 6000 == 0:
                        # Recreate the checkpoint folder from scratch before saving
                        if os.path.exists("./best_model_2"):
                            shutil.rmtree("./best_model_2")
                        os.makedirs("./best_model_2")
                        model_engine.save_checkpoint("./best_model_2")
                        torch.save(model_engine.module.state_dict(), "./compiled_model_2.pth")
                        print(f"model saved at batch {batch_idx + 1}")
            except Exception as e:
                print(f"error in train: {e}")


train(model_engine, train_loader, loss_fn, device, epochs=2)
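
Note: the `model_engine.module._orig_mod` accesses above only resolve if `qwenva` was wrapped with `torch.compile` before being handed to DeepSpeed; `_orig_mod` is the attribute `torch.compile` places on the resulting `OptimizedModule`. A self-contained illustration:

import torch
import torch.nn as nn

m = nn.Linear(4, 4)
compiled = torch.compile(m)
# torch.compile wraps the module; the original is reachable via `_orig_mod`,
# which is what the unfreezing loops above traverse
assert compiled._orig_mod is m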
deepspeed_train_665k.py
ADDED
@@ -0,0 +1,189 @@
import os
import json
import shutil
import argparse

import deepspeed
# deepspeed.initialize(config="./deepspeed_config.json", log_level='DEBUG')
from tqdm import tqdm

os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"
from qwenva import tokenizer
from qwenva import processor
from qwenva import qwenva

images_file_path = './data/download/llava-v1.5-instruct'

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image

with open('/root/autodl-tmp/LLaVA-Instruct-150K/qwenva_mix665k.json', 'r', encoding='utf-8') as f:
    chat_data = json.load(f)

image_token = tokenizer.encode('<image>')[0]
pad_token = tokenizer.pad_token_id


def process_data(sample, max_len=8012):
    """Tokenize one multi-turn conversation, masking user turns out of the loss."""
    conversations = sample['conversations']
    labels = []
    input_ids = []
    flag = 0
    messages = []
    image_index = -100  # fallback so the tensor conversion below never sees an unbound name
    try:
        for index, item in enumerate(conversations):
            if item['from'] == 'human':
                old_input_ids = input_ids
                messages.append({'role': 'user', 'content': item['value']})
                input_ids = tokenizer.apply_chat_template(
                    messages,
                    add_generation_prompt=True
                )
                # User tokens are masked with -100 (ignored by CrossEntropyLoss)
                labels += [-100] * (len(input_ids) - len(old_input_ids))
                if index == flag:
                    if image_token in input_ids:
                        image_index = input_ids.index(image_token)
                        labels[image_index] = image_token
                    else:
                        # Text-only sample: mark it so __getitem__ can skip the image load
                        image_index = -100
            elif item['from'] == 'gpt':
                old_input_ids = input_ids
                messages.append({'role': 'assistant', 'content': item['value']})
                input_ids = tokenizer.apply_chat_template(
                    messages
                )
                labels += input_ids[len(old_input_ids):]
    except Exception as e:
        print(f"error in process_data_1: {e}")
        exit()
    # Pad or truncate so every sample has the same length
    try:
        if len(input_ids) > max_len:
            input_ids = input_ids[:max_len]
            labels = labels[:max_len]
            attention_mask = [1] * len(input_ids)
        else:
            attention_mask = [1] * len(input_ids) + [0] * (max_len - len(input_ids))
            input_ids += [pad_token] * (max_len - len(input_ids))
            labels += [-100] * (max_len - len(labels))
    except Exception as e:
        print(f"error in process_data_2: {e}")
        exit()
    # Convert to tensors
    try:
        input_ids = torch.tensor(input_ids)
        attention_mask = torch.tensor(attention_mask)
        labels = torch.tensor(labels)
        image_index = torch.tensor(image_index)
    except Exception as e:
        print(f"error in tensor: {e}")
        exit()
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'image_idx': image_index
    }


class MyDataset(Dataset):
    def __init__(self, images_file_path, data, max_len=1024):
        self.images_file_path = images_file_path
        self.data = data
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        output_ = process_data(self.data[index], max_len=self.max_len)
        if output_['image_idx'] != -100:
            img_path = os.path.join(self.images_file_path, self.data[index]['image'])
            img = Image.open(img_path)
            input_pixel = processor(images=img, return_tensors="pt")
            output_['input_pixel'] = input_pixel['pixel_values'].squeeze()
        else:
            # Text-only sample: feed an all-zero image. Keep float dtype to match the
            # processor's pixel output (the original cast to input_ids' integer dtype,
            # which would break the vision tower).
            output_['input_pixel'] = torch.zeros(3, 224, 224, dtype=torch.float32)
        return output_


dataset = MyDataset(images_file_path, chat_data, max_len=2048)
train_loader = DataLoader(dataset, batch_size=16, shuffle=True)

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=qwenva,
    args=argparse.Namespace(),
    model_parameters=qwenva.parameters(),
    config_params="./deepspeed_config.json"
)
# checkpoint_path = "./best_model_2"
# model_engine.load_checkpoint(checkpoint_path)
# Save the compiled model weights:
# torch.save(model_engine.module.state_dict(), "./compiled_model.pth")

# Unfreeze the text embedding, LM head, and transformer blocks
# (`_orig_mod` is the attribute torch.compile places on a wrapped module)
for name, param in model_engine.module._orig_mod.text_embedding.named_parameters():
    param.requires_grad = True
# for name, param in model_engine.module._orig_mod.align_layer.named_parameters():
#     param.requires_grad = True

for name, param in model_engine.module._orig_mod.lm_head.named_parameters():
    param.requires_grad = True

for name, param in model_engine.module._orig_mod.transformer.named_parameters():
    param.requires_grad = True

for name, param in model_engine.module._orig_mod.named_parameters():
    if param.requires_grad:
        print(f"Layer: {name}, Requires Grad: {param.requires_grad}")


loss_fn = nn.CrossEntropyLoss()
accumulation_steps = 1


# Training loop
def train(model_engine, train_dataloader, loss_fn, device, epochs):
    model_engine.train()
    for epoch in range(epochs):
        # tqdm progress bar over the epoch
        with tqdm(total=len(train_dataloader), desc=f'Epoch {epoch + 1}/{epochs}', unit='batch') as pbar:
            try:
                for batch_idx, batch in enumerate(train_dataloader):
                    # Move the batch onto the GPU
                    input_ids = batch['input_ids'].to(device)
                    attention_mask = batch['attention_mask'].to(device)
                    input_pixel = batch['input_pixel'].to(device)
                    labels = batch['labels'].to(device)
                    image_idx = batch['image_idx'].to(device)
                    logits = model_engine(input_ids, attention_mask, input_pixel, image_idx)
                    # Next-token prediction loss: shift logits left, labels right
                    loss = loss_fn(logits[:, :-1, :].reshape(-1, logits.shape[-1]), labels[:, 1:].reshape(-1).clone())
                    model_engine.backward(loss)
                    if (batch_idx + 1) % accumulation_steps == 0:
                        model_engine.step()
                    pbar.update(1)
                    pbar.set_postfix(loss=loss.item())  # show the current loss
                    if (batch_idx + 1) % 4100 == 0:
                        # Recreate the checkpoint folder from scratch before saving
                        if os.path.exists("./best_model_2"):
                            shutil.rmtree("./best_model_2")
                        os.makedirs("./best_model_2")
                        model_engine.save_checkpoint("./best_model_2")
                        torch.save(model_engine.module.state_dict(), "./compiled_model_3.pth")
                        print(f"model saved at batch {batch_idx + 1}")
            except Exception as e:
                print(f"error in train: {e}")


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train(model_engine, train_loader, loss_fn, device, epochs=2)
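
Note: `process_data` in all three scripts assumes LLaVA-style conversation records. For reference, the shape of one sample as the code reads it; field values are illustrative, not taken from the actual JSON files:

sample = {
    "image": "coco/train2017/000000000009.jpg",  # joined onto images_file_path in __getitem__
    "conversations": [
        {"from": "human", "value": "<image>\nWhat is shown in the picture?"},
        {"from": "gpt", "value": "A plate of food."},
    ],
}
# In this 665k script, text-only samples simply lack an `<image>` token;
# process_data then returns image_idx == -100 and __getitem__ substitutes a zero image.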
download2.py
ADDED
@@ -0,0 +1,217 @@
import os
import shutil
from pathlib import Path
from typing import Dict, List, TypedDict
from zipfile import ZipFile

import requests
from PIL import Image
from rich.progress import BarColumn, DownloadColumn, MofNCompleteColumn, Progress, TextColumn, TransferSpeedColumn
from tqdm import tqdm

# Registry entries that have already been downloaded, parked here for reference:
"""
{
    "name": "coco/train2017",  # Visual Instruct Tuning images are all sourced from COCO Train 2017
    "extract": True,
    "extract_type": "directory",
    "url": "http://images.cocodataset.org/zips/train2017.zip",
    "do_rename": True,
},
{
    "name": "gqa/images",
    "extract": True,
    "extract_type": "directory",
    "url": "https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip",
    "do_rename": True,
},
{
    "name": "ocr_vqa/images",
    "extract": True,
    "extract_type": "directory",
    "url": "https://hf-mirror.com/datasets/qnguyen3/ocr_vqa/resolve/main/ocr_vqa.zip",
    "do_rename": True,
},
{
    "name": "textvqa/train_images",
    "extract": True,
    "extract_type": "directory",
    "url": "https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip",
    "do_rename": True,
},
{
    "name": "vg/VG_100K_2",
    "extract": True,
    "extract_type": "directory",
    "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip",
    "do_rename": True,
},
"""

# === Dataset Registry w/ Links ===
# fmt: off
DatasetComponent = TypedDict(
    "DatasetComponent",
    {"name": str, "extract": bool, "extract_type": str, "url": str, "do_rename": bool},
    total=False
)

DATASET_REGISTRY: Dict[str, List[DatasetComponent]] = {
    # === LLaVa v1.5 Dataset(s) ===
    "llava-v1.5-instruct": [
        {
            "name": "vg/VG_100K",
            "extract": True,
            "extract_type": "directory",
            "url": "https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip",
            "do_rename": True,
        }
    ]
}
# fmt: on


def convert_to_jpg(image_dir: Path) -> None:
    """Handling for OCR-VQA Images specifically; iterates through directory, converts all GIFs/PNGs."""
    print(f"Converting all Images in `{image_dir}` to JPG")

    for image_fn in tqdm(list(image_dir.iterdir())):
        jpg_fn = image_dir / f"{image_fn.stem}.jpg"  # build the target JPG filename
        if image_fn.suffix in {".jpg", ".jpeg"} or jpg_fn.exists():
            continue

        if image_fn.suffix == ".gif":
            gif = Image.open(image_fn)
            gif.seek(0)
            gif.convert("RGB").save(jpg_fn)
        elif image_fn.suffix == ".png":
            Image.open(image_fn).convert("RGB").save(jpg_fn)
        else:
            raise ValueError(f"Unexpected image format `{image_fn.suffix}`")


# DatasetComponent and DATASET_REGISTRY are defined above and remain unchanged.

def download_with_progress(url: str, download_dir: Path, chunk_size_bytes: int = 1024) -> Path:
    """Utility function for downloading files from the internet, with a handy Rich-based progress bar."""
    print(f"Downloading {url}")

    dest_path = download_dir / Path(url).name
    resume_header = {}  # declared but never populated; see the Range-header sketch at the end of this file

    if dest_path.exists():
        return dest_path

    max_retries = 5
    for attempt in range(max_retries):
        try:
            response = requests.get(url, headers=resume_header, stream=True)
            if response.status_code not in (200, 206):
                raise Exception(f"Failed to download. Status code: {response.status_code}")

            # Download progress bar
            with Progress(
                TextColumn("[bold]{task.description} - {task.fields[fname]}"),
                BarColumn(bar_width=None),
                "[progress.percentage]{task.percentage:>3.1f}%",
                "•",
                DownloadColumn(),
                "•",
                TransferSpeedColumn(),
                transient=True,
            ) as dl_progress:
                # Guard against a missing Content-Length header (the original `int("None")` fallback would raise)
                content_length = response.headers.get("content-length")
                dl_tid = dl_progress.add_task(
                    "Downloading", fname=dest_path.name, total=int(content_length) if content_length else None
                )
                with open(dest_path, "ab") as f:  # open the file in binary-append mode
                    for data in response.iter_content(chunk_size=chunk_size_bytes):
                        f.write(data)
                        # Advance by the bytes actually received (the final chunk may be short)
                        dl_progress.advance(dl_tid, len(data))

            return dest_path

        except Exception as e:
            print(f"Attempt {attempt + 1}/{max_retries} failed: {e}")
            if attempt < max_retries - 1:
                print("Retrying...")
            else:
                raise

# The remaining functions and the `__main__` entry point are unchanged.


def extract_with_progress(archive_path: Path, download_dir: Path, extract_type: str, cleanup: bool = False) -> Path:
    """Utility function for extracting compressed archives, with a handy Rich-based progress bar."""
    assert archive_path.suffix == ".zip", "Only `.zip` compressed archives are supported for now!"
    print(f"Extracting {archive_path.name} to `{download_dir}`")

    with Progress(
        TextColumn("[bold]{task.description} - {task.fields[aname]}"),
        BarColumn(bar_width=None),
        "[progress.percentage]{task.percentage:>3.1f}%",
        "•",
        MofNCompleteColumn(),
        transient=True,
    ) as ext_progress:
        with ZipFile(archive_path) as zf:
            ext_tid = ext_progress.add_task("Extracting", aname=archive_path.name, total=len(members := zf.infolist()))
            extract_path = Path(zf.extract(members[0], download_dir))
            if extract_type == "file":
                assert len(members) == 1, f"Archive `{archive_path}` with extract type `{extract_type}` has > 1 member!"
            elif extract_type == "directory":
                for member in members[1:]:
                    zf.extract(member, download_dir)
                    ext_progress.advance(ext_tid)
            else:
                raise ValueError(f"Extract type `{extract_type}` for archive `{archive_path}` is not defined!")

    if cleanup:
        archive_path.unlink()

    return extract_path


def download_extract(dataset_id: str, root_dir: Path) -> None:
    """Download all files for a given dataset (querying registry above), extracting archives if necessary."""
    os.makedirs(download_dir := root_dir / "download" / dataset_id, exist_ok=True)

    # Download Files
    dl_tasks = [d for d in DATASET_REGISTRY[dataset_id] if not (download_dir / d["name"]).exists()]
    for dl_task in dl_tasks:
        dl_path = download_with_progress(dl_task["url"], download_dir)

        if dl_task["extract"]:
            dl_path = extract_with_progress(dl_path, download_dir, dl_task["extract_type"])
            dl_path = dl_path.parent if dl_path.is_file() else dl_path

        if dl_task["do_rename"]:
            shutil.move(dl_path, download_dir / dl_task["name"])


if __name__ == "__main__":
    # Set the root directory (a default download location)
    root_dir = Path("./data")
    os.makedirs(root_dir, exist_ok=True)

    # Download every dataset in the registry
    for dataset_id in DATASET_REGISTRY.keys():
        print(f"Starting download of dataset: {dataset_id}")
        download_extract(dataset_id, root_dir)

    print("All datasets downloaded!")
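
Note: `download_with_progress` above declares `resume_header` but never fills it, and opens the destination in append mode, so a retried download restarts the stream from byte 0 while appending to a partial file. A minimal sketch of genuine HTTP resume via the standard `Range` header; it assumes the server supports range requests, and the function name and defaults are illustrative:

import requests
from pathlib import Path

def resume_download(url: str, dest_path: Path, chunk_size: int = 1 << 20) -> Path:
    # Resume from however many bytes are already on disk
    pos = dest_path.stat().st_size if dest_path.exists() else 0
    headers = {"Range": f"bytes={pos}-"} if pos else {}
    with requests.get(url, headers=headers, stream=True, timeout=60) as r:
        r.raise_for_status()  # 200 = fresh download, 206 = partial content
        mode = "ab" if r.status_code == 206 else "wb"
        with open(dest_path, mode) as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                f.write(chunk)
    return dest_path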