OMG_Seg / seg /datasets /samplers /batch_sampler.py
Haobo Yuan
add omg code
b34d1d6
raw
history blame
6.16 kB
from typing import Sequence
import torch
import torch.distributed as torch_dist
from mmengine.dist import get_dist_info, get_default_group, get_comm_device
from torch._C._distributed_c10d import ReduceOp
from torch.utils.data import Sampler, BatchSampler
from mmdet.datasets.samplers.batch_sampler import AspectRatioBatchSampler
from mmdet.registry import DATA_SAMPLERS
@DATA_SAMPLERS.register_module()
class VideoSegAspectRatioBatchSampler(AspectRatioBatchSampler):
    """A sampler wrapper for grouping samples with similar aspect ratio.

    Samples are split into two groups (``width < height`` and
    ``width >= height``) and every yielded batch contains indices from a
    single group, except for the final leftover batches, which may mix both
    groups.

    The sample size is read from ``dataset.get_data_info(idx)``. When the
    returned dict has an ``'images'`` entry (video-style data info), the size
    of its first element is used; otherwise the dict itself must provide
    ``'width'`` and ``'height'``.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
    """

    # NOTE(review): ``_aspect_ratio_buckets``, ``batch_size`` and
    # ``drop_last`` are presumably initialised by the parent class, which is
    # not visible in this file -- confirm against mmdet.

    def __iter__(self) -> Sequence[int]:
        """Yield lists of sample indices, one list per batch."""
        for idx in self.sampler:
            # hard code to solve TrackImgSampler: the sampled index is used
            # directly as the video index.
            data_info = self.sampler.dataset.get_data_info(idx)
            # Video entries look like {video_id, images, video_length}; the
            # first image stands in for the size of the whole video.
            if 'images' in data_info:
                img_data_info = data_info['images'][0]
            else:
                img_data_info = data_info
            width, height = img_data_info['width'], img_data_info['height']

            bucket = self._aspect_ratio_buckets[0 if width < height else 1]
            bucket.append(idx)
            # yield a batch of indices in the same aspect ratio group
            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]

        # yield the rest data and reset the bucket
        left_data = (self._aspect_ratio_buckets[0] +
                     self._aspect_ratio_buckets[1])
        self._aspect_ratio_buckets = [[] for _ in range(2)]
        while len(left_data) > 0:
            # Only a remainder strictly smaller than ``batch_size`` is an
            # incomplete batch. A remainder of exactly ``batch_size`` is a
            # full batch and must be yielded even when ``drop_last`` is set.
            if len(left_data) < self.batch_size:
                if not self.drop_last:
                    yield left_data[:]
                left_data = []
            else:
                yield left_data[:self.batch_size]
                left_data = left_data[self.batch_size:]
@DATA_SAMPLERS.register_module()
class MultiDataAspectRatioBatchSampler(BatchSampler):
    """A sampler wrapper for grouping samples with similar aspect ratio
    (``width < height`` or ``width >= height``) into a same batch for
    multi-source datasets.

    Every batch contains indices from a single source dataset. The number of
    batches yielded per iteration is capped at ``len(self)``; in distributed
    runs the cap is the minimum of ``len(self)`` over all ranks, so every
    rank yields at most the same number of batches.

    Args:
        sampler (Sampler): Base sampler. Its ``dataset`` must provide
            ``get_data_info(idx)`` and ``get_dataset_source(idx)``.
        batch_size (Sequence(int) | int): Size of mini-batch for each of the
            multi-source datasets. A single int is applied to every dataset.
        num_datasets(int): Number of multi-source datasets.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.

    Raises:
        TypeError: If ``sampler`` is not an instance of ``Sampler``.
    """

    def __init__(self,
                 sampler: Sampler,
                 batch_size: Sequence[int],
                 num_datasets: int,
                 drop_last: bool = True) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        self.sampler = sampler
        if isinstance(batch_size, int):
            self.batch_size = [batch_size] * num_datasets
        else:
            self.batch_size = batch_size
        self.num_datasets = num_datasets
        self.drop_last = drop_last
        # two groups for w < h and w >= h for each dataset --> 2 * num_datasets
        self._buckets = self._empty_buckets()

    def _empty_buckets(self) -> list:
        """Return a fresh set of empty buckets (two per source dataset)."""
        return [[] for _ in range(2 * self.num_datasets)]

    def _generate_batches(self):
        """Yield every batch of one pass over the sampler, without a cap."""
        for idx in self.sampler:
            dataset = self.sampler.dataset
            data_info = dataset.get_data_info(idx)
            # Missing sizes default to 0, which puts the sample in the
            # ``w >= h`` group.
            width = data_info.get('width', 0)
            height = data_info.get('height', 0)
            dataset_source_idx = dataset.get_dataset_source(idx)
            aspect_ratio_bucket_id = 0 if width < height else 1
            bucket = self._buckets[dataset_source_idx * 2 +
                                   aspect_ratio_bucket_id]
            bucket.append(idx)
            # yield a batch of indices in the same aspect ratio group
            if len(bucket) == self.batch_size[dataset_source_idx]:
                yield bucket[:]
                del bucket[:]

        # yield the rest data, merging both aspect-ratio groups per dataset
        for i in range(self.num_datasets):
            batch_size = self.batch_size[i]
            left_data = self._buckets[i * 2] + self._buckets[i * 2 + 1]
            while len(left_data) > 0:
                if len(left_data) < batch_size:
                    if not self.drop_last:
                        yield left_data[:]
                    left_data = []
                else:
                    yield left_data[:batch_size]
                    left_data = left_data[batch_size:]

    def __iter__(self) -> Sequence[int]:
        """Yield lists of sample indices, one list per batch."""
        num_batch = len(self)
        _, world_size = get_dist_info()
        if world_size > 1:
            # Agree on the smallest batch count across ranks so that no rank
            # yields more batches than any other.
            group = get_default_group()
            backend_device = get_comm_device(group)
            synced = torch.tensor(num_batch, device='cpu')
            synced = synced.to(device=backend_device)
            torch_dist.all_reduce(synced, op=ReduceOp.MIN, group=group)
            num_batch = synced.to('cpu').item()

        # Nothing may be yielded when the agreed quota is already zero.
        if num_batch <= 0:
            return

        # Always start from clean buckets and clean up on every exit path.
        # The quota normally runs out exactly on the last batch, so without
        # this the early return would leave stale indices behind for the
        # next epoch.
        self._buckets = self._empty_buckets()
        try:
            for batch in self._generate_batches():
                yield batch
                num_batch -= 1
                if num_batch <= 0:
                    return
        finally:
            self._buckets = self._empty_buckets()

    def __len__(self) -> int:
        """Return the number of batches of one pass over the sampler.

        Note that this iterates over the whole base sampler to count the
        samples belonging to each source dataset.
        """
        sizes = [0 for _ in range(self.num_datasets)]
        for idx in self.sampler:
            dataset_source_idx = self.sampler.dataset.get_dataset_source(idx)
            sizes[dataset_source_idx] += 1

        if self.drop_last:
            # Incomplete batches are dropped: floor division per dataset.
            return sum(sizes[i] // self.batch_size[i]
                       for i in range(self.num_datasets))
        # Incomplete batches are kept: ceiling division per dataset.
        return sum((sizes[i] + self.batch_size[i] - 1) // self.batch_size[i]
                   for i in range(self.num_datasets))