|
import torch |
|
|
|
|
|
class Sampler(object):
    """Base class for all Samplers.

    Every Sampler subclass has to provide an ``__iter__`` method, giving a
    way to iterate over indices of dataset elements, and a ``__len__`` method
    that returns the length of the returned iterators.

    Arguments:
        data_source (Dataset): dataset to sample from. The base class
            accepts it only to fix the constructor signature; it is not
            stored here.
    """

    def __init__(self, data_source):
        # Intentionally empty: the base class keeps no state.
        pass

    def __iter__(self):
        # Must be overridden by subclasses.
        raise NotImplementedError

    def __len__(self):
        # Must be overridden by subclasses.
        raise NotImplementedError
|
|
|
|
|
class SequentialSampler(Sampler):
    """Samples elements sequentially, always in the same order.

    Yields the indices ``0, 1, ..., len(data_source) - 1``.

    Arguments:
        data_source (Dataset): dataset to sample from; only its length
            is used.
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # The length is re-read on every call, so each new iterator reflects
        # the current size of the data source.
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)
|
|
|
|
|
class RandomSampler(Sampler):
    """Samples elements randomly, without replacement.

    Each iterator yields a fresh random permutation of the indices
    ``0, 1, ..., len(data_source) - 1`` as plain Python ints.

    Arguments:
        data_source (Dataset): dataset to sample from; only its length
            is used.
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Convert the permutation to a list so that callers receive Python
        # ints rather than zero-dimensional tensors.
        permutation = torch.randperm(len(self.data_source))
        return iter(permutation.tolist())

    def __len__(self):
        return len(self.data_source)
|
|
|
|
|
class SubsetRandomSampler(Sampler):
    """Samples elements randomly from a given list of indices, without replacement.

    Each iterator yields every entry of ``indices`` exactly once, in a
    freshly shuffled order.

    Arguments:
        indices (list): a list of indices
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        # Draw the permutation up front and convert it to Python ints once,
        # instead of indexing with a zero-dimensional tensor per element.
        order = torch.randperm(len(self.indices)).tolist()
        return (self.indices[position] for position in order)

    def __len__(self):
        return len(self.indices)
|
|
|
|
|
class WeightedRandomSampler(Sampler):
    """Samples elements from [0,..,len(weights)-1] with given probabilities (weights).

    Arguments:
        weights (list): a list of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that
            once an index has been drawn it cannot be drawn again within
            the same iterator.

    Raises:
        ValueError: if ``num_samples`` is not positive.
    """

    def __init__(self, weights, num_samples, replacement=True):
        # Fail fast here instead of when the first iterator is created.
        if num_samples <= 0:
            raise ValueError(
                "num_samples should be a positive integer value, "
                "but got num_samples={}".format(num_samples)
            )
        # as_tensor accepts lists, arrays and tensors of any dtype and always
        # produces a double-precision tensor.
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        drawn = torch.multinomial(self.weights, self.num_samples, self.replacement)
        # Yield Python ints rather than zero-dimensional tensors.
        return iter(drawn.tolist())

    def __len__(self):
        return self.num_samples
|
|
|
|
|
class BatchSampler(object):
    """Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler. Any iterable works for iteration;
            ``len()`` additionally requires it to define ``__len__``.
        batch_size (int): Size of mini-batch. Must be positive.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Raises:
        ValueError: if ``batch_size`` is not positive.

    Example:
        >>> list(BatchSampler(range(10), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(range(10), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        # A non-positive batch size would never fill a batch during iteration
        # and would divide by zero in __len__, so reject it up front.
        if batch_size <= 0:
            raise ValueError(
                "batch_size should be a positive integer value, "
                "but got batch_size={}".format(batch_size)
            )
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                # Start a new list so the batch just yielded is not mutated.
                batch = []
        # Emit the trailing partial batch unless the caller asked to drop it.
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        total = len(self.sampler)
        if self.drop_last:
            # Only complete batches count.
            return total // self.batch_size
        # Ceiling division: a trailing partial batch counts as one batch.
        return (total + self.batch_size - 1) // self.batch_size
|
|