import os
import json
from PIL import Image
import pickle
import imageio
import numpy as np
import torch
from torch.utils.data import Dataset, BatchSampler
from torchvision import transforms
import random
from datasets import register
import math
import torch.distributed as dist
from torch.utils.data._utils.collate import default_collate
class ImageFolder(Dataset):

    def __init__(self, path, split_file=None, split_key=None, first_k=None, size=None,
                 repeat=1, cache='none', mask=False):
        self.repeat = repeat
        self.cache = cache
        self.path = path
        self.Train = False
        self.split_key = split_key

        self.size = size
        self.mask = mask
        # note: img_transform is constructed here but is not applied in __getitem__
        if self.mask:
            self.img_transform = transforms.Compose([
                transforms.Resize((self.size, self.size), interpolation=Image.NEAREST),
                transforms.ToTensor(),
            ])
        else:
            self.img_transform = transforms.Compose([
                transforms.Resize((self.size, self.size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

        if split_file is None:
            filenames = sorted(os.listdir(path))
        else:
            with open(split_file, 'r') as f:
                filenames = json.load(f)[split_key]
        if first_k is not None:
            filenames = filenames[:first_k]

        self.files = []
        for filename in filenames:
            file = os.path.join(path, filename)
            self.append_file(file)

    def append_file(self, file):
        if self.cache == 'none':
            self.files.append(file)
        elif self.cache == 'in_memory':
            self.files.append(self.img_process(file))

    def __len__(self):
        return len(self.files) * self.repeat

    def __getitem__(self, idx):
        x = self.files[idx % len(self.files)]
        if self.cache == 'none':
            return self.img_process(x)
        elif self.cache == 'in_memory':
            return x

    def img_process(self, file):
        # mask entries are returned as file paths; decoding is left to the consumer
        if self.mask:
            # return Image.open(file).convert('L')
            return file
        else:
            return Image.open(file).convert('RGB')

class PairedImageFolders(Dataset):

    def __init__(self, root_path_1, root_path_2, **kwargs):
        self.dataset_1 = ImageFolder(root_path_1, **kwargs)
        self.dataset_2 = ImageFolder(root_path_2, **kwargs, mask=True)

    def __len__(self):
        return len(self.dataset_1)

    def __getitem__(self, idx):
        return self.dataset_1[idx], self.dataset_2[idx]
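
# Minimal usage sketch for PairedImageFolders (the directory names below are placeholders,
# not paths from this repo). Each item is a (RGB PIL image, mask file path) pair, because
# ImageFolder defers both the image transform and the mask decoding to the consumer:
#
#   paired = PairedImageFolders('data/images', 'data/masks', size=256, cache='none')
#   img, mask_path = paired[0]
#   img_tensor = paired.dataset_1.img_transform(img)   # normalized 3xHxW tensor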

class ImageFolder_multi_task(Dataset):

    def __init__(self, path, split_file=None, split_key=None, first_k=None, size=None,
                 repeat=1, cache='none', mask=False):
        self.repeat = repeat
        self.cache = cache
        self.path = path
        self.Train = False
        self.split_key = split_key

        self.size = size
        self.mask = mask
        if self.mask:
            self.img_transform = transforms.Compose([
                transforms.Resize((self.size, self.size), interpolation=Image.NEAREST),
                transforms.ToTensor(),
            ])
        else:
            self.img_transform = transforms.Compose([
                transforms.Resize((self.size, self.size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

        if split_file is None:
            filenames = sorted(os.listdir(path))
        else:
            with open(split_file, 'r') as f:
                filenames = json.load(f)[split_key]
        if first_k is not None:
            filenames = filenames[:first_k]

        self.files = []
        for filename in filenames:
            file = os.path.join(path, filename)
            self.append_file(file)

    def append_file(self, file):
        if self.cache == 'none':
            self.files.append(file)
        elif self.cache == 'in_memory':
            self.files.append(self.img_process(file))

    def __len__(self):
        return len(self.files) * self.repeat

    def __getitem__(self, idx):
        x = self.files[idx % len(self.files)]
        if self.cache == 'none':
            return self.img_process(x)
        elif self.cache == 'in_memory':
            return x

    def img_process(self, file):
        if self.mask:
            # return Image.open(file).convert('L')
            return file
        else:
            return Image.open(file).convert('RGB')

class PairedImageFolders_multi_task(Dataset):

    def __init__(self, root_path_1, root_path_2, model=None, **kwargs):
        # `model` is accepted for interface compatibility but is not used here
        self.dataset_1 = ImageFolder_multi_task(root_path_1, **kwargs)
        self.dataset_2 = ImageFolder_multi_task(root_path_2, **kwargs, mask=True)

    def __len__(self):
        return len(self.dataset_1)

    def __getitem__(self, idx):
        return self.dataset_1[idx], self.dataset_2[idx]

# class MultiTaskDataset(Dataset):
#     """
#     usage example:
#         train_datasets = [SemData_Single(), SemData_Single()]
#         multi_task_train_dataset = MultiTaskDataset(train_datasets)
#         multi_task_batch_sampler = MultiTaskBatchSampler(train_datasets, batch_size=4, mix_opt=0, extra_task_ratio=0, drop_last=True)
#         multi_task_train_data = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler)
#         for i, (task_id, input, target) in enumerate(multi_task_train_data):
#             pre = model(input)
#     """
#     def __init__(self, datasets_image, datasets_gt):
#         self._datasets = datasets_image
#         task_id_2_image_set_dic = {}
#         for i, dataset in enumerate(datasets_image):
#             task_id = i
#             assert task_id not in task_id_2_image_set_dic, "Duplicate task_id %s" % task_id
#             task_id_2_image_set_dic[task_id] = dataset
#         self.datasets_1 = task_id_2_image_set_dic
#
#         task_id_2_gt_set_dic = {}
#         for i, dataset in enumerate(datasets_gt):
#             task_id = i
#             assert task_id not in task_id_2_gt_set_dic, "Duplicate task_id %s" % task_id
#             task_id_2_gt_set_dic[task_id] = dataset
#         self.dataset_2 = task_id_2_gt_set_dic
#
#     def __len__(self):
#         return sum(len(dataset) for dataset in self._datasets)
#
#     def __getitem__(self, idx):
#         task_id, sample_id = idx
#         # return self._task_id_2_data_set_dic[task_id][sample_id]
#         return self.dataset_1[task_id][sample_id], self.dataset_2[task_id][sample_id]

class MultiTaskDataset(Dataset):
    """
    usage example:
        train_datasets = [SemData_Single(), SemData_Single()]
        multi_task_train_dataset = MultiTaskDataset(train_datasets)
        multi_task_batch_sampler = MultiTaskBatchSampler(train_datasets, batch_size=4, mix_opt=0, extra_task_ratio=0, drop_last=True)
        multi_task_train_data = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler)
        for i, (task_id, input, target) in enumerate(multi_task_train_data):
            pre = model(input)
    """
    def __init__(self, datasets):
        self._datasets = datasets
        task_id_2_data_set_dic = {}
        for i, dataset in enumerate(datasets):
            task_id = i
            assert task_id not in task_id_2_data_set_dic, "Duplicate task_id %s" % task_id
            task_id_2_data_set_dic[task_id] = dataset
        self._task_id_2_data_set_dic = task_id_2_data_set_dic

    def __len__(self):
        return sum(len(dataset) for dataset in self._datasets)

    def __getitem__(self, idx):
        # idx is a (task_id, sample_id) tuple produced by the batch sampler below
        task_id, sample_id = idx
        return self._task_id_2_data_set_dic[task_id][sample_id]
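
# Quick sketch of the indexing convention (dataset_a / dataset_b are illustrative names;
# any map-style Dataset works as a task dataset):
#
#   mtd = MultiTaskDataset([dataset_a, dataset_b])
#   sample = mtd[(1, 0)]   # first sample of task 1, i.e. dataset_b[0]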

def collate_fn(batch):
    """Collate samples whose first element may be a tuple containing ``None`` entries."""
    # batch = list(filter(lambda x: x[0][0] is not None, batch))
    # if len(batch) == 0: return torch.Tensor()
    if not isinstance(batch[0][0], tuple):
        return default_collate(batch)
    else:
        batch_num = len(batch)
        ret = []
        for item_idx in range(len(batch[0][0])):
            # keep None placeholders as-is; collate everything else with default_collate
            if batch[0][0][item_idx] is None:
                ret.append(None)
            else:
                ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)]))
        ret.append(default_collate([batch[i][1] for i in range(batch_num)]))
        return ret
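
# Toy illustration of collate_fn on a batch of ((image, optional_extra), target) samples
# (the sample layout here is only for illustration):
#
#   batch = [((torch.zeros(3, 4, 4), None), 0),
#            ((torch.ones(3, 4, 4), None), 1)]
#   out = collate_fn(batch)
#   # out[0] -> 2x3x4x4 stacked images, out[1] -> None, out[2] -> tensor([0, 1])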

class DistrubutedMultiTaskBatchSampler(BatchSampler):
    """
    datasets: list of per-task Dataset objects to sample from
    batch_size (int): number of samples per batch
    num_replicas (int, optional): number of processes participating in distributed training
    rank (int, optional): rank of the current process within ``num_replicas``
    mix_opt (int): 0 shuffles the batches of all tasks together; 1 shuffles only the extra tasks
    extra_task_ratio (float, optional): sampling ratio between the first task and the extra tasks
    drop_last (bool, optional): set to ``True`` to drop the last incomplete batch of each task
        if its size is not divisible by the batch size. If ``False``, the last batch
        will be smaller. (default: ``True``)
    shuffle (bool, optional): whether to reshuffle the task schedule every epoch (default: ``True``)
    """
    def __init__(self, datasets, batch_size, num_replicas, rank, mix_opt=0, extra_task_ratio=0, drop_last=True,
                 shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        self.num_replicas = num_replicas
        self.rank = rank
        assert mix_opt in [0, 1], 'mix_opt must equal 0 or 1'
        assert extra_task_ratio >= 0, 'extra_task_ratio must be >= 0'
        # self._datasets = datasets
        self._batch_size = batch_size
        self._mix_opt = mix_opt
        self._extra_task_ratio = extra_task_ratio
        self._drop_last = drop_last
        self.shuffle = shuffle

        train_data_list = []
        for dataset in datasets:
            train_data_list.append(self._get_index_batches(len(dataset), batch_size, self._drop_last))
        # one entry per dataset: each entry is that dataset's sample indices split into batches
        self._train_data_list = train_data_list
        self.total_len = sum(len(train_data) for train_data in self._train_data_list)

        ######### DDP ######################
        if self._drop_last and self.total_len % self.num_replicas != 0:  # type: ignore[arg-type]
            self.num_samples = math.ceil(
                (self.total_len - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(self.total_len / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.epoch = 0
        self.seed = 0
    def set_epoch(self, epoch):
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch

    @staticmethod
    def _get_index_batches(dataset_len, batch_size, drop_last):
        # split [0, dataset_len) into consecutive batches of sample indices
        # index_batches = [list(range(i, min(i+batch_size, dataset_len))) for i in range(0, dataset_len, batch_size)]
        index = list(range(dataset_len))
        if drop_last and dataset_len % batch_size:
            del index[-(dataset_len % batch_size):]
        index_batches = [index[i:i + batch_size] for i in range(0, len(index), batch_size)]
        return index_batches

    def __len__(self):
        # number of batches this replica yields per epoch
        # return sum(len(train_data) for train_data in self._train_data_list)
        return self.num_samples
    def __iter__(self):
        all_iters = [iter(item) for item in self._train_data_list]
        all_indices = self._gen_task_indices(self._train_data_list, self._mix_opt, self._extra_task_ratio)

        ######### DDP ######################
        if self.shuffle:
            # seed with (seed + epoch) so every replica shuffles the task schedule identically
            random.Random(self.seed + self.epoch).shuffle(all_indices)
        if not self._drop_last:
            # pad the schedule so it divides evenly across replicas
            padding_size = self.total_size - len(all_indices)
            if padding_size > 0:
                all_indices += all_indices[:padding_size]
        all_indices = all_indices[self.rank:self.total_size:self.num_replicas]
        assert len(all_indices) == self.num_samples

        for local_task_idx in all_indices:
            # task_id = self._datasets[local_task_idx].get_task_id()
            try:
                batch = next(all_iters[local_task_idx])
            except StopIteration:
                # padding can request one batch more than a task holds; restart that task's batches
                all_iters[local_task_idx] = iter(self._train_data_list[local_task_idx])
                batch = next(all_iters[local_task_idx])
            # batch = batch[self.rank:len(batch):self.num_replicas]
            yield [(local_task_idx, sample_id) for sample_id in batch]

    @staticmethod
    def _gen_task_indices(train_data_list, mix_opt, extra_task_ratio):
        # one schedule entry per batch of each task; mix_opt and extra_task_ratio are currently unused here
        all_indices = []
        for i in range(len(train_data_list)):
            all_indices += [i] * len(train_data_list[i])
        return all_indices
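

if __name__ == '__main__':
    # Single-process sanity-check sketch (not part of the training pipeline): tiny synthetic
    # TensorDatasets stand in for the real per-task datasets, and num_replicas=1 / rank=0
    # avoids having to initialize torch.distributed.
    from torch.utils.data import DataLoader, TensorDataset

    task_a = TensorDataset(torch.randn(10, 3, 8, 8), torch.zeros(10, dtype=torch.long))
    task_b = TensorDataset(torch.randn(6, 3, 8, 8), torch.ones(6, dtype=torch.long))

    multi_task_dataset = MultiTaskDataset([task_a, task_b])
    batch_sampler = DistrubutedMultiTaskBatchSampler(
        [task_a, task_b], batch_size=4, num_replicas=1, rank=0)
    loader = DataLoader(multi_task_dataset, batch_sampler=batch_sampler, collate_fn=collate_fn)

    for epoch in range(2):
        batch_sampler.set_epoch(epoch)  # reshuffle the task schedule each epoch
        for images, labels in loader:
            # every batch comes from a single task
            assert labels.unique().numel() == 1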