import os
from concurrent.futures import ProcessPoolExecutor

import pandas as pd
from torch.utils.data import Dataset, ConcatDataset


def add_data_args(parent_args):
    parser = parent_args.add_argument_group('taiyi stable diffusion data args')

    parser.add_argument(
        "--datasets_path", type=str, default=None, required=True, nargs='+',
        help="One or more folders containing the training data of instance images.",
    )
    parser.add_argument(
        "--datasets_type", type=str, default=None, required=True, choices=['txt', 'csv', 'fs_datasets'], nargs='+',
        help="Dataset type for each path (txt, csv or fs_datasets); must have the same length as datasets_path.",
    )
    parser.add_argument(
        "--resolution", type=int, default=512,
        help=(
            "The resolution for input images: all images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", default=False,
        help="Whether to center crop images before resizing them to the target resolution.",
    )
    parser.add_argument(
        "--thres", type=float, default=0.2,
        help="Score threshold passed to the dataset loaders.",
    )
    return parent_args


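# Illustrative sketch of how these arguments are typically wired into a training
# script (the flag values below are assumptions, not shipped defaults):
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser = add_data_args(parser)
#   args = parser.parse_args(['--datasets_path', '/data/taiyi', '--datasets_type', 'csv'])
#   datasets = load_data(args)  # defined at the bottom of this file

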
class TXTDataset(Dataset):
    """Folder-based image/caption dataset: every `*.jpg` under `folder_name` is paired
    with a sibling `*.txt` caption file. `thres` is reserved for score-based filtering;
    the score file has not been open-sourced yet."""

    def __init__(self,
                 folder_name,
                 thres=0.2):
        super().__init__()

        self.image_paths = []
        '''
        The score file below has not been open-sourced yet:
        score_data = pd.read_csv(os.path.join(folder_name, 'score.csv'))
        img_path2score = {score_data['image_path'][i]: score_data['score'][i]
                          for i in range(len(score_data))}
        '''
        # Collect every .jpg in the folder; captions are resolved lazily in __getitem__.
        for each_file in os.listdir(folder_name):
            if each_file.endswith('.jpg'):
                self.image_paths.append(os.path.join(folder_name, each_file))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = str(self.image_paths[idx])
        caption_path = img_path.replace('.jpg', '.txt')
        with open(caption_path, 'r', encoding='utf-8') as f:
            caption = f.read()
        return {'img_path': img_path, 'caption': caption}


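# Illustrative usage sketch for TXTDataset (the folder path is an assumption): the
# folder is expected to hold `xxx.jpg` images with sibling `xxx.txt` caption files.
#
#   dataset = TXTDataset('/data/txt_shards/part-000')
#   sample = dataset[0]  # {'img_path': '.../xxx.jpg', 'caption': '...'}

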
class CSVDataset(Dataset):
    """Image/caption dataset read from a CSV file whose rows reference image files under
    `image_root` (column `img_key`) together with their captions (column `caption_key`)."""

    def __init__(self,
                 input_filename,
                 image_root,
                 img_key,
                 caption_key,
                 thres=0.2):
        super().__init__()

        print(f'Loading csv data from {input_filename}.')
        self.images = []
        self.captions = []

        if input_filename.endswith('.csv'):
            df = pd.read_csv(input_filename, index_col=0, on_bad_lines='skip')
            print(f'file {input_filename} datalen {len(df)}')
            self.images.extend(df[img_key].tolist())
            self.captions.extend(df[caption_key].tolist())
        self.image_root = image_root

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_root, str(self.images[idx]))
        return {'img_path': img_path, 'caption': self.captions[idx]}


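# Illustrative usage sketch for CSVDataset (file paths are assumptions; the
# "name"/"caption" keys match what process_pool_read_csv_dataset below uses):
#
#   dataset = CSVDataset('/data/release/part-000.csv', '/data/images',
#                        img_key='name', caption_key='caption')
#   sample = dataset[0]  # {'img_path': '/data/images/xxx.jpg', 'caption': '...'}

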
def if_final_dir(path: str) -> bool:
    """Return True if `path` directly contains at least one file, i.e. it is a leaf data folder."""
    for f in os.scandir(path):
        if f.is_file():
            return True
    return False


def process_pool_read_txt_dataset(args,
                                  input_root=None,
                                  thres=0.2):
    """Recursively walk `input_root`, build one TXTDataset per leaf folder in a process
    pool, and concatenate them into a single dataset."""
    p = ProcessPoolExecutor(max_workers=20)
    all_datasets = []
    res = []

    def traversal_files(path: str):
        list_subfolders_with_paths = [f.path for f in os.scandir(path) if f.is_dir()]
        for dir_path in list_subfolders_with_paths:
            if if_final_dir(dir_path):
                res.append(p.submit(TXTDataset,
                                    dir_path,
                                    thres))
            else:
                traversal_files(dir_path)

    traversal_files(input_root)
    p.shutdown()  # waits for all submitted TXTDataset constructions to finish
    for future in res:
        all_datasets.append(future.result())
    dataset = ConcatDataset(all_datasets)
    return dataset


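# Assumed on-disk layout for the txt reader (folder names are illustrative): the tree
# is walked recursively and one TXTDataset is built per leaf folder that contains files.
#
#   input_root/
#       group_a/
#           part-000/   # *.jpg + *.txt pairs
#           part-001/
#       group_b/
#           part-000/

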
def process_pool_read_csv_dataset(args,
                                  input_root,
                                  thres=0.20):
    """Build one CSVDataset per CSV shard under `input_root/release` in a process pool
    and concatenate them; image paths are resolved against `input_root/images`."""
    all_csvs = os.listdir(os.path.join(input_root, 'release'))
    image_root = os.path.join(input_root, 'images')

    all_datasets = []
    res = []
    p = ProcessPoolExecutor(max_workers=150)
    for path in all_csvs:
        each_csv_path = os.path.join(input_root, 'release', path)
        res.append(p.submit(CSVDataset,
                            each_csv_path,
                            image_root,
                            img_key="name",
                            caption_key="caption",
                            thres=thres))
    p.shutdown()  # waits for all submitted CSVDataset constructions to finish
    for future in res:
        all_datasets.append(future.result())
    dataset = ConcatDataset(all_datasets)
    return dataset


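# Assumed on-disk layout for the csv reader (derived from the paths used above):
#
#   input_root/
#       release/   # one CSV shard per file, with "name" and "caption" columns
#       images/    # image files referenced by the "name" column

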
def load_data(args, global_rank=0):
    assert len(args.datasets_path) == len(args.datasets_type), \
        "datasets_path and datasets_type must have the same length"
    all_datasets = []
    for path, dataset_type in zip(args.datasets_path, args.datasets_type):
        if dataset_type == 'txt':
            all_datasets.append(process_pool_read_txt_dataset(
                args, input_root=path, thres=args.thres))
        elif dataset_type == 'csv':
            all_datasets.append(process_pool_read_csv_dataset(
                args, input_root=path, thres=args.thres))
        elif dataset_type == 'fs_datasets':
            from fengshen.data.fs_datasets import load_dataset
            all_datasets.append(load_dataset(path, num_proc=args.num_workers,
                                             thres=args.thres, global_rank=global_rank)['train'])
        else:
            raise ValueError('unsupported dataset type: %s' % dataset_type)
        print(f'load dataset {dataset_type} {path} len {len(all_datasets[-1])}')
    return {'train': ConcatDataset(all_datasets)}
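

if __name__ == '__main__':
    # Illustrative entry point for manually inspecting the loaders; training scripts are
    # expected to call add_data_args / load_data themselves. The 'fs_datasets' branch
    # additionally needs an args.num_workers attribute supplied by other argument groups.
    import argparse
    arg_parser = argparse.ArgumentParser()
    arg_parser = add_data_args(arg_parser)
    cli_args = arg_parser.parse_args()
    datasets = load_data(cli_args)
    print(f"train dataset size: {len(datasets['train'])}")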