from __future__ import division

import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler

# A checkpoint, where given, is expected to be a dict with at least the keys
# 'dataset_perm' (the index order of the interrupted epoch, or None),
# 'batch_size' and 'batch_idx' (the number of batches already consumed).


class RandomSampler(Sampler):
    """Random sampler that supports resuming iteration from a checkpoint."""

    def __init__(self, data_source, checkpoint):
        self.data_source = data_source
        if checkpoint is not None and checkpoint['dataset_perm'] is not None:
            # Resume: reuse the permutation stored in the checkpoint and skip
            # the batches that were already consumed.
            self.dataset_perm = checkpoint['dataset_perm']
            self.perm = self.dataset_perm[checkpoint['batch_size'] * checkpoint['batch_idx']:]
        else:
            # Fresh epoch: draw one permutation and serve exactly that order.
            # perm must alias dataset_perm so that the permutation saved in a
            # checkpoint matches the indices actually served; drawing two
            # independent permutations here would make a resumed run replay
            # the wrong samples.
            self.dataset_perm = torch.randperm(len(self.data_source)).tolist()
            self.perm = self.dataset_perm

    def __iter__(self):
        return iter(self.perm)

    def __len__(self):
        return len(self.perm)


class SequentialSampler(Sampler):
    """Sequential sampler that supports resuming iteration from a checkpoint."""

    def __init__(self, data_source, checkpoint):
        self.data_source = data_source
        if checkpoint is not None and checkpoint['dataset_perm'] is not None:
            # Resume: reuse the saved order and skip the consumed batches.
            self.dataset_perm = checkpoint['dataset_perm']
            self.perm = self.dataset_perm[checkpoint['batch_size'] * checkpoint['batch_idx']:]
        else:
            # Fresh epoch: serve the dataset in its natural order.
            self.dataset_perm = list(range(len(self.data_source)))
            self.perm = self.dataset_perm

    def __iter__(self):
        return iter(self.perm)

    def __len__(self):
        return len(self.perm)


class CheckpointDataLoader(DataLoader):
    """Extends torch.utils.data.DataLoader to handle resuming training from an
    arbitrary point within an epoch.
    """
    def __init__(
        self,
        dataset,
        checkpoint=None,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=False,
        drop_last=True,
        timeout=0,
        worker_init_fn=None
    ):
        # Pick the sampler variant that matches the requested ordering; both
        # know how to fast-forward past the batches recorded in `checkpoint`.
        if shuffle:
            sampler = RandomSampler(dataset, checkpoint)
        else:
            sampler = SequentialSampler(dataset, checkpoint)
        # Index of the first batch of the resumed epoch; training loops can
        # use it to offset their step counter.
        if checkpoint is not None:
            self.checkpoint_batch_idx = checkpoint['batch_idx']
        else:
            self.checkpoint_batch_idx = 0

        super(CheckpointDataLoader, self).__init__(
            dataset,
            sampler=sampler,
            shuffle=False,
            batch_size=batch_size,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            timeout=timeout,
            worker_init_fn=worker_init_fn
        )
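

# Minimal usage sketch: assumes a toy TensorDataset and illustrates the
# checkpoint dict the samplers above read, plus how a resumed loader skips
# the batches that were already consumed.
if __name__ == '__main__':
    from torch.utils.data import TensorDataset

    dataset = TensorDataset(torch.arange(10))
    loader = CheckpointDataLoader(dataset, batch_size=2, shuffle=True)

    # Pretend training stopped after the first batch of the epoch.
    for batch_idx, _batch in enumerate(loader):
        checkpoint = {
            'dataset_perm': loader.sampler.dataset_perm,
            'batch_size': loader.batch_size,
            'batch_idx': batch_idx + 1,
        }
        break

    # A loader built from the checkpoint serves only the remaining samples.
    resumed = CheckpointDataLoader(dataset, checkpoint=checkpoint,
                                   batch_size=2, shuffle=True)
    assert len(resumed.sampler) == len(dataset) - 2

    # checkpoint_batch_idx lets a training loop keep its step counter aligned
    # with the interrupted epoch.
    for step, _batch in enumerate(resumed, start=resumed.checkpoint_batch_idx):
        pass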