# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataloaders."""

import torch


class PretrainingSampler:
    """Sequential batch sampler for one data-parallel rank.

    Sample indices ``consumed_samples .. total_samples - 1`` are grouped, in
    order, into global batches of ``micro_batch_size * data_parallel_size``
    indices. For every global batch the sampler yields only the contiguous
    slice of ``micro_batch_size`` indices that belongs to
    ``data_parallel_rank``.

    Args:
        total_samples: Number of samples in the dataset; must be positive.
        consumed_samples: Index of the first sample to emit; must be smaller
            than ``total_samples``.
        micro_batch_size: Number of indices per yielded batch; must be
            positive.
        data_parallel_rank: Rank whose slice is yielded; must be smaller than
            ``data_parallel_size``.
        data_parallel_size: Number of data-parallel ranks; must be positive.
        drop_last: When true (the default) a trailing incomplete global batch
            is discarded. When false, this rank's slice of the incomplete
            batch is still yielded; that slice may hold fewer than
            ``micro_batch_size`` indices, or be empty.

    Raises:
        AssertionError: If any of the sanity checks above fails.
    """

    def __init__(self, total_samples: int, consumed_samples: int,
                 micro_batch_size: int, data_parallel_rank: int,
                 data_parallel_size: int, drop_last: bool = True) -> None:
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.drop_last = drop_last

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.consumed_samples < self.total_samples, \
            'no samples left to consume: {}, {}'.format(self.consumed_samples,
                                                        self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self) -> int:
        """Return the number of complete global batches in ``total_samples``.

        The value does not take ``consumed_samples`` or ``drop_last`` into
        account.
        """
        return self.total_samples // self.micro_batch_times_data_parallel_size

    def get_start_end_idx(self) -> tuple:
        """Return the ``(start, end)`` offsets of this rank's slice.

        The offsets are relative to the start of a global batch and span
        ``micro_batch_size`` positions.
        """
        start_idx = self.data_parallel_rank * self.micro_batch_size
        end_idx = start_idx + self.micro_batch_size
        return start_idx, end_idx

    def __iter__(self):
        """Yield this rank's list of sample indices for each global batch."""
        batch = []
        # Last batch will be dropped if drop_last is not set False
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self.micro_batch_times_data_parallel_size:
                start_idx, end_idx = self.get_start_end_idx()
                yield batch[start_idx:end_idx]
                batch = []

        # Check the last partial batch and see drop_last is set
        if len(batch) > 0 and not self.drop_last:
            # Slicing past the end of the partial batch is what makes the
            # result short or empty for higher ranks; every rank still yields
            # the same number of batches.
            start_idx, end_idx = self.get_start_end_idx()
            yield batch[start_idx:end_idx]


class PretrainingRandomSampler:
    """Shuffled batch sampler for one data-parallel rank.

    The first ``len(self) * micro_batch_size * data_parallel_size`` sample
    indices are split into ``data_parallel_size`` contiguous buckets of equal
    size. Each rank draws only from its own bucket, in an order given by
    ``torch.randperm`` seeded with ``epoch``. The remaining
    ``last_batch_size`` indices at the end of the dataset are never emitted.

    Args:
        total_samples: Number of samples in the dataset; must be positive.
        consumed_samples: Number of samples already consumed across all
            ranks. It is advanced by one global batch each time a batch is
            yielded.
        micro_batch_size: Number of indices per yielded batch; must be
            positive.
        data_parallel_rank: Rank whose bucket is sampled; must be smaller
            than ``data_parallel_size``.
        data_parallel_size: Number of data-parallel ranks; must be positive.
        epoch: Seed for the permutation; see ``set_epoch``.

    Raises:
        AssertionError: If any of the sanity checks above fails.
    """

    def __init__(self, total_samples: int, consumed_samples: int,
                 micro_batch_size: int, data_parallel_rank: int,
                 data_parallel_size: int, epoch: int) -> None:
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        # Samples left over after the last complete global batch.
        self.last_batch_size = \
            self.total_samples % self.micro_batch_times_data_parallel_size
        self.epoch = epoch

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self) -> int:
        """Return the number of complete global batches in ``total_samples``."""
        return self.total_samples // self.micro_batch_times_data_parallel_size

    def __iter__(self):
        """Yield shuffled batches of ``micro_batch_size`` indices.

        Iteration starts ``consumed_samples`` samples into the current pass
        over the data, so a sampler built with the same ``epoch`` and a
        non-zero ``consumed_samples`` continues the same permutation. Nothing
        is yielded when the dataset holds fewer samples than one global
        batch.

        Raises:
            AssertionError: If the consumed samples within the current pass
                are not a whole number of global batches.
        """
        active_total_samples = self.total_samples - self.last_batch_size
        if active_total_samples == 0:
            # Fewer samples than one global batch: no complete batch exists
            # (len(self) == 0) and the modulo below would divide by zero.
            return

        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % \
            self.micro_batch_times_data_parallel_size == 0

        # data sharding and random sampling
        bucket_size = \
            (self.total_samples // self.micro_batch_times_data_parallel_size) \
            * self.micro_batch_size
        # Per-rank count of samples already consumed in this pass.
        bucket_offset = current_epoch_samples // self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size

        g = torch.Generator()
        g.manual_seed(self.epoch)
        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

        batch = []
        # Last batch if not complete will be dropped.
        for idx in idx_range:
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                # Account for the whole global batch before handing out this
                # rank's share of it.
                self.consumed_samples += \
                    self.micro_batch_times_data_parallel_size
                yield batch
                batch = []

    def set_epoch(self, epoch: int) -> None:
        """Set the seed used for the permutation on the next iteration."""
        self.epoch = epoch