import os
from concurrent.futures import ProcessPoolExecutor

import pandas as pd
from torch.utils.data import Dataset, ConcatDataset

def add_data_args(parent_args):
    parser = parent_args.add_argument_group('taiyi stable diffusion data args')
    # Multiple dataset paths may be passed in; each is loaded separately.
    parser.add_argument(
        "--datasets_path", type=str, default=None, required=True, nargs='+',
        help="One or more folders containing the training data of instance images.",
    )
    parser.add_argument(
        "--datasets_type", type=str, default=None, required=True, choices=['txt', 'csv', 'fs_datasets'], nargs='+',
        help="Dataset type (txt, csv, or fs_datasets); must have the same length as datasets_path.",
    )
    parser.add_argument(
        "--resolution", type=int, default=512,
        help=(
            "The resolution for input images: all images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", default=False,
        help="Whether to center-crop images before resizing them to the target resolution."
    )
    # Score threshold; kept for API compatibility (the score files are not yet released, see TXTDataset).
    parser.add_argument("--thres", type=float, default=0.2)
    return parent_args
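
# A minimal usage sketch (the argument values below are placeholders, not real paths):
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser = add_data_args(parser)
#   args = parser.parse_args(
#       ['--datasets_path', '/data/zero23m', '--datasets_type', 'txt'])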

class TXTDataset(Dataset):
    # Reads txt-style datasets; mainly intended for the Zero23m dataset.
    def __init__(self,
                 folder_name,
                 thres=0.2):
        super().__init__()
        # print(f'Loading folder data from {folder_name}.')
        self.image_paths = []
        '''
        These files have not been open-sourced yet:
        score_data = pd.read_csv(os.path.join(folder_name, 'score.csv'))
        img_path2score = {score_data['image_path'][i]: score_data['score'][i]
                          for i in range(len(score_data))}
        '''
        # print(img_path2score)
        # Only file paths are stored here, to keep initialization cheap.
        for each_file in os.listdir(folder_name):
            if each_file.endswith('.jpg'):
                self.image_paths.append(os.path.join(folder_name, each_file))
        # print('Done loading data. Len of images:', len(self.image_paths))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = str(self.image_paths[idx])
        caption_path = img_path.replace('.jpg', '.txt')  # The image and its caption share the same base name.
        with open(caption_path, 'r') as f:
            caption = f.read()
        return {'img_path': img_path, 'caption': caption}
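
# Example (assuming a flat folder of paired files such as 0001.jpg / 0001.txt;
# the path is a placeholder):
#   ds = TXTDataset('/data/zero23m/part-000')
#   sample = ds[0]  # {'img_path': '.../0001.jpg', 'caption': '...'}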

# NOTE: To speed up data loading, keep this original single-file reader and apply the
# parallel-read strategy externally (see process_pool_read_csv_dataset below). 30 min -> 3 min.
class CSVDataset(Dataset):
    def __init__(self,
                 input_filename,
                 image_root,
                 img_key,
                 caption_key,
                 thres=0.2):
        super().__init__()
        # logging.debug(f'Loading csv data from {input_filename}.')
        print(f'Loading csv data from {input_filename}.')
        self.images = []
        self.captions = []
        if input_filename.endswith('.csv'):
            # print(f"Load Data from {input_filename}")
            df = pd.read_csv(input_filename, index_col=0, on_bad_lines='skip')
            print(f'file {input_filename} datalen {len(df)}')
            # The image path may need small adjustments depending on the dataset layout.
            self.images.extend(df[img_key].tolist())
            self.captions.extend(df[caption_key].tolist())
        self.image_root = image_root

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_root, str(self.images[idx]))
        return {'img_path': img_path, 'caption': self.captions[idx]}
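
# Example (assuming a CSV with 'name' and 'caption' columns; paths are placeholders):
#   ds = CSVDataset('/data/laion/release/part-0.csv', '/data/laion/images',
#                   img_key='name', caption_key='caption')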

def if_final_dir(path: str) -> bool:
    # A directory counts as a leaf directory if it directly contains at least one file.
    for f in os.scandir(path):
        if f.is_file():
            return True
    return False

def process_pool_read_txt_dataset(args,
                                  input_root=None,
                                  thres=0.2):
    p = ProcessPoolExecutor(max_workers=20)
    all_datasets = []
    res = []

    # Recursively walk every subdirectory under input_root and submit each
    # leaf directory to the process pool as one TXTDataset.
    def traversal_files(path: str):
        list_subfolders_with_paths = [f.path for f in os.scandir(path) if f.is_dir()]
        for dir_path in list_subfolders_with_paths:
            if if_final_dir(dir_path):
                res.append(p.submit(TXTDataset,
                                    dir_path,
                                    thres))
            else:
                traversal_files(dir_path)

    traversal_files(input_root)
    p.shutdown()
    for future in res:
        all_datasets.append(future.result())
    dataset = ConcatDataset(all_datasets)
    return dataset
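
# Example (input_root is a placeholder; expects nested folders of .jpg/.txt pairs).
# Note that p.shutdown() blocks until all submitted constructors finish, so the
# datasets are built fully in parallel before concatenation:
#   ds = process_pool_read_txt_dataset(args, input_root='/data/zero23m', thres=0.2)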

def process_pool_read_csv_dataset(args,
                                  input_root,
                                  thres=0.20):
    # input_root is a directory whose 'release' subfolder holds the CSV files
    # and whose 'images' subfolder holds the image files.
    all_csvs = os.listdir(os.path.join(input_root, 'release'))
    image_root = os.path.join(input_root, 'images')
    # csv_with_score = [each for each in all_csvs if 'score' in each]
    all_datasets = []
    res = []
    p = ProcessPoolExecutor(max_workers=150)
    for path in all_csvs:
        each_csv_path = os.path.join(input_root, 'release', path)
        res.append(p.submit(CSVDataset,
                            each_csv_path,
                            image_root,
                            img_key="name",
                            caption_key="caption",
                            thres=thres))
    p.shutdown()
    for future in res:
        all_datasets.append(future.result())
    dataset = ConcatDataset(all_datasets)
    return dataset
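
# Example (input_root is a placeholder; expects input_root/release/*.csv and input_root/images/):
#   ds = process_pool_read_csv_dataset(args, input_root='/data/laion')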

def load_data(args, global_rank=0):
    assert len(args.datasets_path) == len(args.datasets_type), \
        "datasets_path and datasets_type must have the same length"
    all_datasets = []
    for path, dataset_type in zip(args.datasets_path, args.datasets_type):
        if dataset_type == 'txt':
            all_datasets.append(process_pool_read_txt_dataset(
                args, input_root=path, thres=args.thres))
        elif dataset_type == 'csv':
            all_datasets.append(process_pool_read_csv_dataset(
                args, input_root=path, thres=args.thres))
        elif dataset_type == 'fs_datasets':
            from fengshen.data.fs_datasets import load_dataset
            all_datasets.append(load_dataset(path, num_proc=args.num_workers,
                                             thres=args.thres, global_rank=global_rank)['train'])
        else:
            raise ValueError('unsupported dataset type: %s' % dataset_type)
        print(f'load dataset {dataset_type} {path} len {len(all_datasets[-1])}')
    return {'train': ConcatDataset(all_datasets)}
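

if __name__ == '__main__':
    # Minimal smoke-test sketch: builds the argument parser defined above and
    # loads whatever datasets are passed on the command line. The fs_datasets
    # branch additionally requires an args.num_workers attribute, which the
    # calling training script is assumed to define.
    import argparse
    parser = argparse.ArgumentParser()
    parser = add_data_args(parser)
    args = parser.parse_args()
    datasets = load_data(args)
    print(f"train dataset size: {len(datasets['train'])}")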