RingMo-SAM / datasets /image_folder.py
AI-Cyber's picture
Upload 123 files
8d7921b
raw
history blame contribute delete
13.8 kB
import os
import json
from PIL import Image
import pickle
import imageio
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import random
from datasets import register
import math
import torch.distributed as dist
from torch.utils.data import BatchSampler
from torch.utils.data._utils.collate import default_collate
@register('image-folder')
class ImageFolder(Dataset):
    """Dataset over the files of a single directory.

    Filenames come from ``sorted(os.listdir(path))`` or, when ``split_file`` is
    given, from the JSON list stored under ``split_key`` in that file. Items are
    RGB PIL images; for mask folders (``mask=True``) the item is the file path
    itself, and loading is left to the caller.

    Args:
        path: directory holding the files.
        split_file: optional JSON file mapping split names to filename lists.
        split_key: key into ``split_file`` selecting the filename list.
        first_k: if given, keep only the first ``first_k`` filenames.
        size: side length of the square resize in ``self.img_transform``.
        repeat: how many times the file list is virtually repeated by ``__len__``.
        cache: ``'none'`` stores paths and processes files lazily in
            ``__getitem__``; ``'in_memory'`` processes every file once here.
        mask: treat files as masks (see above).

    Raises:
        ValueError: if ``cache`` is not one of ``_CACHE_MODES``.
    """

    # Supported values of the ``cache`` argument.
    _CACHE_MODES = ('none', 'in_memory')

    def __init__(self, path, split_file=None, split_key=None, first_k=None, size=None,
                 repeat=1, cache='none', mask=False):
        # Fail fast: an unknown mode used to leave ``self.files`` silently empty.
        if cache not in self._CACHE_MODES:
            raise self._cache_error(cache)
        self.repeat = repeat
        self.cache = cache
        self.path = path
        self.Train = False
        self.split_key = split_key
        self.size = size
        self.mask = mask
        # NOTE(review): img_transform is not applied anywhere in this class;
        # presumably a wrapping dataset applies it -- confirm with callers.
        if self.mask:
            # Nearest-neighbour resizing avoids inventing intermediate mask
            # values; masks are not normalized.
            self.img_transform = transforms.Compose([
                transforms.Resize((self.size, self.size), interpolation=Image.NEAREST),
                transforms.ToTensor(),
            ])
        else:
            # Default Resize interpolation plus ImageNet mean/std normalization.
            self.img_transform = transforms.Compose([
                transforms.Resize((self.size, self.size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

        if split_file is None:
            filenames = sorted(os.listdir(path))
        else:
            with open(split_file, 'r') as f:
                filenames = json.load(f)[split_key]
        if first_k is not None:
            filenames = filenames[:first_k]

        self.files = []
        for filename in filenames:
            self.append_file(os.path.join(path, filename))

    @classmethod
    def _cache_error(cls, cache):
        """Build the error raised for an unsupported ``cache`` mode."""
        return ValueError(
            "cache must be one of {}, got {!r}".format(cls._CACHE_MODES, cache))

    def append_file(self, file):
        """Store ``file`` as a path ('none') or as its processed item ('in_memory')."""
        if self.cache == 'none':
            self.files.append(file)
        elif self.cache == 'in_memory':
            self.files.append(self.img_process(file))
        else:
            raise self._cache_error(self.cache)

    def __len__(self):
        # The file list is virtually repeated ``repeat`` times.
        return len(self.files) * self.repeat

    def __getitem__(self, idx):
        # Wrap the index so repeated epochs map back onto the real files.
        x = self.files[idx % len(self.files)]
        if self.cache == 'none':
            return self.img_process(x)
        if self.cache == 'in_memory':
            return x
        raise self._cache_error(self.cache)

    def img_process(self, file):
        """Return the item for ``file``: the path for masks, else an RGB PIL image."""
        if self.mask:
            # Masks are returned as paths; loading is deferred to the caller.
            return file
        # convert() loads the pixels into a new image, so the file can be
        # closed immediately instead of relying on PIL's lazy close.
        with Image.open(file) as img:
            return img.convert('RGB')
@register('paired-image-folders')
class PairedImageFolders(Dataset):
    """Pair the items of an image folder with those of a mask folder by index.

    ``root_path_1`` is wrapped as a regular ``ImageFolder`` and ``root_path_2``
    as a mask ``ImageFolder`` (``mask=True``). All other keyword arguments are
    forwarded unchanged to both folders.
    """

    def __init__(self, root_path_1, root_path_2, **kwargs):
        images = ImageFolder(root_path_1, **kwargs)
        masks = ImageFolder(root_path_2, **kwargs, mask=True)
        self.dataset_1 = images
        self.dataset_2 = masks

    def __len__(self):
        # The image folder alone determines the length.
        return len(self.dataset_1)

    def __getitem__(self, idx):
        image = self.dataset_1[idx]
        mask = self.dataset_2[idx]
        return image, mask
class ImageFolder_multi_task(Dataset):
    """Dataset over the files of a single directory (multi-task variant).

    Filenames come from ``sorted(os.listdir(path))`` or, when ``split_file`` is
    given, from the JSON list stored under ``split_key`` in that file. Items are
    RGB PIL images; for mask folders (``mask=True``) the item is the file path
    itself, and loading is left to the caller.

    Args:
        path: directory holding the files.
        split_file: optional JSON file mapping split names to filename lists.
        split_key: key into ``split_file`` selecting the filename list.
        first_k: if given, keep only the first ``first_k`` filenames.
        size: side length of the square resize in ``self.img_transform``.
        repeat: how many times the file list is virtually repeated by ``__len__``.
        cache: ``'none'`` stores paths and processes files lazily in
            ``__getitem__``; ``'in_memory'`` processes every file once here.
        mask: treat files as masks (see above).

    Raises:
        ValueError: if ``cache`` is not one of ``_CACHE_MODES``.
    """

    # Supported values of the ``cache`` argument.
    _CACHE_MODES = ('none', 'in_memory')

    def __init__(self, path, split_file=None, split_key=None, first_k=None, size=None,
                 repeat=1, cache='none', mask=False):
        # Fail fast: an unknown mode used to leave ``self.files`` silently empty.
        if cache not in self._CACHE_MODES:
            raise self._cache_error(cache)
        self.repeat = repeat
        self.cache = cache
        self.path = path
        self.Train = False
        self.split_key = split_key
        self.size = size
        self.mask = mask
        # NOTE(review): img_transform is not applied anywhere in this class;
        # presumably a wrapping dataset applies it -- confirm with callers.
        if self.mask:
            # Nearest-neighbour resizing avoids inventing intermediate mask
            # values; masks are not normalized.
            self.img_transform = transforms.Compose([
                transforms.Resize((self.size, self.size), interpolation=Image.NEAREST),
                transforms.ToTensor(),
            ])
        else:
            # Default Resize interpolation plus ImageNet mean/std normalization.
            self.img_transform = transforms.Compose([
                transforms.Resize((self.size, self.size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

        if split_file is None:
            filenames = sorted(os.listdir(path))
        else:
            with open(split_file, 'r') as f:
                filenames = json.load(f)[split_key]
        if first_k is not None:
            filenames = filenames[:first_k]

        self.files = []
        for filename in filenames:
            self.append_file(os.path.join(path, filename))

    @classmethod
    def _cache_error(cls, cache):
        """Build the error raised for an unsupported ``cache`` mode."""
        return ValueError(
            "cache must be one of {}, got {!r}".format(cls._CACHE_MODES, cache))

    def append_file(self, file):
        """Store ``file`` as a path ('none') or as its processed item ('in_memory')."""
        if self.cache == 'none':
            self.files.append(file)
        elif self.cache == 'in_memory':
            self.files.append(self.img_process(file))
        else:
            raise self._cache_error(self.cache)

    def __len__(self):
        # The file list is virtually repeated ``repeat`` times.
        return len(self.files) * self.repeat

    def __getitem__(self, idx):
        # Wrap the index so repeated epochs map back onto the real files.
        x = self.files[idx % len(self.files)]
        if self.cache == 'none':
            return self.img_process(x)
        if self.cache == 'in_memory':
            return x
        raise self._cache_error(self.cache)

    def img_process(self, file):
        """Return the item for ``file``: the path for masks, else an RGB PIL image."""
        if self.mask:
            # Masks are returned as paths; loading is deferred to the caller.
            return file
        # convert() loads the pixels into a new image, so the file can be
        # closed immediately instead of relying on PIL's lazy close.
        with Image.open(file) as img:
            return img.convert('RGB')
@register('paired-image-folders-multi-task')
class PairedImageFolders_multi_task(Dataset):
    """Pair an image folder with a mask folder by index (multi-task variant).

    ``root_path_1`` is wrapped as a regular ``ImageFolder_multi_task`` and
    ``root_path_2`` as a mask one (``mask=True``). ``model`` is accepted but not
    used here; all other keyword arguments are forwarded to both folders.
    """

    def __init__(self, root_path_1, root_path_2, model=None, **kwargs):
        images = ImageFolder_multi_task(root_path_1, **kwargs)
        masks = ImageFolder_multi_task(root_path_2, **kwargs, mask=True)
        self.dataset_1 = images
        self.dataset_2 = masks

    def __len__(self):
        # The image folder alone determines the length.
        return len(self.dataset_1)

    def __getitem__(self, idx):
        image = self.dataset_1[idx]
        mask = self.dataset_2[idx]
        return image, mask
# class MultiTaskDataset(Dataset):
# """
# useage example:
# train_datasets = [SemData_Single(), SemData_Single()]
# multi_task_train_dataset = MultiTaskDataset(train_datasets)
# multi_task_batch_sampler = MultiTaskBatchSampler(train_datasets, batch_size=4, mix_opt=0, extra_task_ratio=0, drop_last=True)
# multi_task_train_data = DataLoader(multi_task_train_dataset, batch_sampler=multi_task_batch_sampler)
# for i, (task_id, input, target) in enumerate(multi_task_train_data):
# pre = model(input)
# """
# def __init__(self, datasets_image, datasets_gt):
# self._datasets = datasets_image
# task_id_2_image_set_dic = {}
# for i, dataset in enumerate(datasets_image):
# task_id = i
# assert task_id not in task_id_2_image_set_dic, "Duplicate task_id %s" % task_id
# task_id_2_image_set_dic[task_id] = dataset
# self.datasets_1 = task_id_2_image_set_dic
#
# task_id_2_gt_set_dic = {}
# for i, dataset in enumerate(datasets_gt):
# task_id = i
# assert task_id not in task_id_2_gt_set_dic, "Duplicate task_id %s" % task_id
# task_id_2_gt_set_dic[task_id] = dataset
# self.dataset_2 = task_id_2_gt_set_dic
#
#
# def __len__(self):
# return sum(len(dataset) for dataset in self._datasets)
#
# def __getitem__(self, idx):
# task_id, sample_id = idx
# # return self._task_id_2_data_set_dic[task_id][sample_id]
# return self.dataset_1[task_id][sample_id], self.dataset_2[task_id][sample_id]
class MultiTaskDataset(Dataset):
    """Dispatch ``(task_id, sample_id)`` indices to one of several datasets.

    The task id of each dataset is its position in ``datasets``. Indices must be
    2-tuples, which matches the lists of ``(task_id, sample_id)`` tuples yielded
    by ``DistrubutedMultiTaskBatchSampler``.

    Usage example:
        train_datasets = [dataset_a, dataset_b]
        multi_task_train_dataset = MultiTaskDataset(train_datasets)
        sampler = DistrubutedMultiTaskBatchSampler(
            train_datasets, batch_size=4, num_replicas=None, rank=None)
        loader = DataLoader(multi_task_train_dataset, batch_sampler=sampler)
        for batch in loader:
            ...
    """

    def __init__(self, datasets):
        self._datasets = datasets
        # enumerate() yields unique positions, so every task id is distinct.
        self._task_id_2_data_set_dic = {
            position: dataset for position, dataset in enumerate(datasets)
        }

    def __len__(self):
        # Total number of samples across all tasks.
        return sum(map(len, self._datasets))

    def __getitem__(self, idx):
        task_id, sample_id = idx
        chosen = self._task_id_2_data_set_dic[task_id]
        return chosen[sample_id]
def collate_fn(batch):
    """Collate a batch of ``(input, target)`` samples.

    If the first sample's input is not a tuple, the whole batch is handed to
    ``default_collate``. Otherwise each sample is ``(input_tuple, target)``.
    Each position of the input tuple is collated separately. A position whose
    value in the FIRST sample is ``None`` is emitted as ``None`` rather than
    collated. The collated targets are appended as the last element.

    Args:
        batch: non-empty list of samples as produced by the dataset.

    Returns:
        Whatever ``default_collate`` returns for the non-tuple case; otherwise a
        list ``[field_0, ..., field_k, targets]``.
    """
    # The debug print that dumped the whole batch on every call was removed.
    first_input = batch[0][0]
    if not isinstance(first_input, tuple):
        return default_collate(batch)
    ret = [
        None if first_input[k] is None
        else default_collate([sample[0][k] for sample in batch])
        for k in range(len(first_input))
    ]
    ret.append(default_collate([sample[1] for sample in batch]))
    return ret
class DistrubutedMultiTaskBatchSampler(BatchSampler):
    """Distributed batch sampler that draws whole batches from several datasets.

    Each dataset is split once, in index order, into batches of ``batch_size``
    sample indices. Every batch becomes a ``(task_id, batch)`` pair, where
    ``task_id`` is the dataset's position in ``datasets``. Each epoch the pairs
    of all datasets are pooled and (optionally) shuffled with a seed shared by
    every rank. They are then padded or truncated to a multiple of
    ``num_replicas`` and dealt out round-robin, so each rank receives a
    disjoint share of the batches.

    Each yielded item is a list of ``(task_id, sample_id)`` tuples, the index
    format accepted by ``MultiTaskDataset.__getitem__``.

    Args:
        datasets: sequence of datasets; only ``len(dataset)`` is used.
        batch_size (int): number of samples per batch.
        num_replicas (int or None): number of processes; queried from
            ``torch.distributed`` when ``None``.
        rank (int or None): rank of this process; queried from
            ``torch.distributed`` when ``None``.
        mix_opt (int): must be 0 or 1. Validated but currently unused.
        extra_task_ratio (float): must be >= 0. Validated but currently unused.
        drop_last (bool): when ``True``, drop each dataset's trailing incomplete
            batch and truncate the pooled batches to a multiple of
            ``num_replicas``. When ``False``, keep incomplete batches and pad the
            pooled list by repeating batches from its start. (default: ``True``)
        shuffle (bool): shuffle the order of the pooled batches. The permutation
            is seeded with ``seed + epoch`` so all ranks agree; call
            :meth:`set_epoch` every epoch to get a new order. (default: ``True``)
    """

    def __init__(self, datasets, batch_size, num_replicas, rank, mix_opt=0, extra_task_ratio=0, drop_last=True,
                 shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                "Invalid rank {}, rank should be in the interval"
                " [0, {}]".format(rank, num_replicas - 1))
        self.num_replicas = num_replicas
        self.rank = rank
        assert mix_opt in [0, 1], 'mix_opt must equal 0 or 1'
        assert extra_task_ratio >= 0, 'extra_task_ratio must greater than 0'
        self._batch_size = batch_size
        self._mix_opt = mix_opt
        self._extra_task_ratio = extra_task_ratio
        self._drop_last = drop_last
        self.shuffle = shuffle
        # One entry per dataset; each entry is that dataset's list of batches,
        # every batch being a list of sample indices.
        self._train_data_list = [
            self._get_index_batches(len(dataset), batch_size, self._drop_last)
            for dataset in datasets
        ]
        self.total_len = sum(len(batches) for batches in self._train_data_list)
        # Per-rank batch count, following DistributedSampler: truncate when
        # drop_last, otherwise round up (the missing batches are padded).
        if self._drop_last and self.total_len % self.num_replicas != 0:
            self.num_samples = math.ceil(
                (self.total_len - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(self.total_len / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas
        self.epoch = 0
        self.seed = 0

    def set_epoch(self, epoch):
        """Set the epoch used (with ``seed``) to seed the shared shuffle."""
        self.epoch = epoch

    @staticmethod
    def _get_index_batches(dataset_len, batch_size, drop_last):
        """Split ``range(dataset_len)`` into consecutive batches of ``batch_size``.

        With ``drop_last`` the trailing incomplete batch is discarded.
        """
        index = list(range(dataset_len))
        if drop_last and dataset_len % batch_size:
            del index[-(dataset_len % batch_size):]
        return [index[i:i + batch_size] for i in range(0, len(index), batch_size)]

    def __len__(self):
        # Number of batches yielded to THIS rank per epoch.
        return self.num_samples

    def __iter__(self):
        # Pool every batch of every task as a distinct (task_id, batch) item.
        # Dealing these items out (instead of each rank pulling batches from
        # its own per-task iterators) guarantees ranks never share a batch.
        pairs = [
            (task_id, batch)
            for task_id, batches in enumerate(self._train_data_list)
            for batch in batches
        ]
        if self.shuffle:
            # Same seed on every rank -> same permutation, which the strided
            # slice below relies on to partition the batches across ranks.
            random.Random(self.seed + self.epoch).shuffle(pairs)
        if self.total_size > len(pairs):
            # drop_last=False: pad by repeating batches from the start so every
            # rank yields exactly num_samples batches.
            padding = self.total_size - len(pairs)
            if pairs:
                pairs += (pairs * math.ceil(padding / len(pairs)))[:padding]
        else:
            pairs = pairs[:self.total_size]
        for task_id, batch in pairs[self.rank:self.total_size:self.num_replicas]:
            yield [(task_id, sample_id) for sample_id in batch]

    @staticmethod
    def _gen_task_indices(train_data_list, mix_opt, extra_task_ratio):
        """Return one task id per batch, in task order.

        Kept for backward compatibility; ``__iter__`` no longer uses it.
        ``mix_opt`` and ``extra_task_ratio`` are ignored.
        """
        all_indices = []
        for i in range(len(train_data_list)):
            all_indices += [i] * len(train_data_list[i])
        return all_indices
# def set_epoch(self, epoch)
# r"""
# Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
# use a different random ordering for each epoch. Otherwise, the next iteration of this
# sampler will yield the same ordering.
# Args:
# epoch (int): Epoch number.
# """
# self.epoch = epoch