import math

from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler


def identity(x):
    return x


class SortedSampler(Sampler):
    """
    Samples elements sequentially, always in the same order.

    Args:
        data (`obj`: `Iterable`):
            Iterable data.
        sort_key (`obj`: `Callable`):
            Specifies a function of one argument that is used to extract a numerical
            comparison key from each list element.

    Example:
        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
    """

    def __init__(self, data, sort_key=identity):
        super().__init__(data)
        self.data = data
        self.sort_key = sort_key
        # Pair each index with its comparison key, sort by key, and keep the
        # resulting index order.
        zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)]
        zip_ = sorted(zip_, key=lambda r: r[1])
        self.sorted_indexes = [item[0] for item in zip_]

    def __iter__(self):
        return iter(self.sorted_indexes)

    def __len__(self):
        return len(self.data)


class BucketBatchSampler(BatchSampler):
    """
    `BucketBatchSampler` toggles between `sampler` batches and sorted batches.

    Typically, the `sampler` will be a `RandomSampler`, allowing the user to toggle between
    random batches and sorted batches. A larger `bucket_size_multiplier` yields batches that
    are more sorted overall, and vice versa.

    Background:
        ``BucketBatchSampler`` is similar to a ``BucketIterator`` found in popular libraries
        like ``AllenNLP`` and ``torchtext``. A ``BucketIterator`` pools together examples of
        similar length to reduce the padding required for each batch while maintaining some
        noise through bucketing.

        **AllenNLP Implementation:**
        https://github.com/allenai/allennlp/blob/master/allennlp/data/iterators/bucket_iterator.py

        **torchtext Implementation:**
        https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py#L225

    Args:
        sampler (`obj`: `torch.utils.data.sampler.Sampler`):
            Base sampler that provides the indices to bucket.
        batch_size (`int`):
            Size of mini-batch.
        drop_last (`bool`, optional, defaults to `False`):
            If `True`, the sampler will drop the last batch if its size would be less than
            `batch_size`.
        sort_key (`obj`: `Callable`, optional, defaults to `identity`):
            Callable to specify a comparison key for sorting.
        bucket_size_multiplier (`int`, optional, defaults to `100`):
            Buckets are of size `batch_size * bucket_size_multiplier`.

    Example:
        >>> from torchnlp.random import set_seed
        >>> set_seed(123)
        >>>
        >>> from torch.utils.data.sampler import SequentialSampler
        >>> sampler = SequentialSampler(list(range(10)))
        >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=False))
        [[6, 7, 8], [0, 1, 2], [3, 4, 5], [9]]
        >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(
        self,
        sampler,
        batch_size,
        drop_last: bool = False,
        sort_key=identity,
        bucket_size_multiplier=100,
    ):
        super().__init__(sampler, batch_size, drop_last)
        self.sort_key = sort_key
        _bucket_size = batch_size * bucket_size_multiplier
        # A bucket never needs to be larger than the sampler itself.
        if hasattr(sampler, "__len__"):
            _bucket_size = min(_bucket_size, len(sampler))
        self.bucket_sampler = BatchSampler(sampler, _bucket_size, False)

    def __iter__(self):
        for bucket in self.bucket_sampler:
            # Sort the indices within the bucket, slice the sorted order into
            # mini-batches, then yield those mini-batches in random order so
            # training still sees shuffled batches.
            sorted_sampler = SortedSampler(bucket, self.sort_key)
            for batch in SubsetRandomSampler(
                list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))
            ):
                yield [bucket[i] for i in batch]

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return math.ceil(len(self.sampler) / self.batch_size)
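

if __name__ == "__main__":
    # --- Usage sketch (illustrative only; not part of the original module). ---
    # Shows how `BucketBatchSampler` plugs into a `torch.utils.data.DataLoader`
    # via its `batch_sampler` argument, so each emitted batch groups rows of
    # similar length and the padding a collate function would add stays small.
    # The toy dataset and the `sort_key` lambda below are assumptions made for
    # this demo, not values from the original source.
    from torch.utils.data import DataLoader
    from torch.utils.data.sampler import RandomSampler

    data = [[0] * n for n in (5, 50, 7, 48, 6, 52, 9, 47)]  # variable-length rows
    batch_sampler = BucketBatchSampler(
        RandomSampler(data),
        batch_size=2,
        drop_last=False,
        sort_key=lambda i: len(data[i]),  # `sort_key` receives a dataset index
    )
    loader = DataLoader(data, batch_sampler=batch_sampler, collate_fn=list)
    for batch in loader:
        print([len(row) for row in batch])  # lengths within a batch are close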