"""Dataloaders."""

import torch


class PretrainingSampler:
    """Sequential sampler: walks the dataset indices in order and yields, for
    this data-parallel rank, its micro-batch slice of every global batch of
    micro_batch_size * data_parallel_size samples."""

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, drop_last=True):
        # Keep a copy of the input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.drop_last = drop_last

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.consumed_samples < self.total_samples, \
            'no samples left to consume: {}, {}'.format(self.consumed_samples,
                                                        self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data parallel size: ' \
            '{}, {}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples // self.micro_batch_times_data_parallel_size

    def get_start_end_idx(self):
        # This rank's slice within a full global batch of indices.
        start_idx = self.data_parallel_rank * self.micro_batch_size
        end_idx = start_idx + self.micro_batch_size
        return start_idx, end_idx

    def __iter__(self):
        batch = []
        # Walk the remaining samples in order; once a full global batch is
        # collected, yield this rank's micro-batch slice of it.
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self.micro_batch_times_data_parallel_size:
                start_idx, end_idx = self.get_start_end_idx()
                yield batch[start_idx:end_idx]
                batch = []

        # Handle the trailing partial global batch unless drop_last is set.
        if len(batch) > 0 and not self.drop_last:
            start_idx, end_idx = self.get_start_end_idx()
            yield batch[start_idx:end_idx]
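

# Illustrative sketch (not part of the original module): PretrainingSampler yields
# complete lists of dataset indices, so it plugs into a torch DataLoader through the
# `batch_sampler` argument. The helper name and the `num_workers`/`pin_memory`
# settings below are assumptions made for this example only.
def _example_pretraining_data_loader(dataset, consumed_samples, micro_batch_size,
                                     data_parallel_rank, data_parallel_size,
                                     num_workers=0):
    """Build a DataLoader driven by PretrainingSampler (example only)."""
    batch_sampler = PretrainingSampler(
        total_samples=len(dataset),
        consumed_samples=consumed_samples,
        micro_batch_size=micro_batch_size,
        data_parallel_rank=data_parallel_rank,
        data_parallel_size=data_parallel_size)
    # With `batch_sampler` supplying whole index lists, `batch_size`, `shuffle`
    # and `drop_last` must not also be passed to the DataLoader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)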


class PretrainingRandomSampler:
    """Shuffling sampler: each data-parallel rank owns a contiguous bucket of
    dataset indices and draws an epoch-seeded random permutation of it, so the
    ranks never see overlapping samples within an epoch."""

    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size, epoch):
        # Keep a copy of the input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.last_batch_size = \
            self.total_samples % self.micro_batch_times_data_parallel_size
        self.epoch = epoch

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data parallel size: ' \
            '{}, {}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples // self.micro_batch_times_data_parallel_size

    def __iter__(self):
        # Samples already consumed in the current epoch; the trailing partial
        # global batch is excluded, so this is always a multiple of the global
        # batch size.
        active_total_samples = self.total_samples - self.last_batch_size
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        # Data sharding: each rank shuffles its own contiguous bucket of
        # indices, skipping past the samples it has already consumed.
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) \
                      * self.micro_batch_size
        bucket_offset = current_epoch_samples // self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size

        # Epoch-seeded generator so the permutation is reproducible across
        # restarts and identical on every rank.
        g = torch.Generator()
        g.manual_seed(self.epoch)
        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

        batch = []
        # Emit micro-batches for this rank; consumed_samples advances by the
        # global batch size since every rank emits one micro-batch in lockstep.
        for idx in idx_range:
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch = []

    def set_epoch(self, epoch):
        self.epoch = epoch
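

# Illustrative smoke test (assumed, not part of the original module): running this
# file directly shows how the two samplers split 10 samples across 2 data-parallel
# ranks with a micro-batch size of 2.
if __name__ == '__main__':
    for rank in range(2):
        sequential = PretrainingSampler(
            total_samples=10, consumed_samples=0, micro_batch_size=2,
            data_parallel_rank=rank, data_parallel_size=2)
        # Rank 0 yields [0, 1] then [4, 5]; rank 1 yields [2, 3] then [6, 7].
        # The trailing partial global batch ([8, 9]) is dropped (drop_last=True).
        print('sequential rank', rank, list(sequential))

        shuffled = PretrainingRandomSampler(
            total_samples=10, consumed_samples=0, micro_batch_size=2,
            data_parallel_rank=rank, data_parallel_size=2, epoch=0)
        # Each rank shuffles its own bucket of 4 indices (rank 0 -> 0..3,
        # rank 1 -> 4..7) with an epoch-seeded permutation and yields two
        # micro-batches of 2; indices 8 and 9 are again skipped.
        print('shuffled rank', rank, list(shuffled))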