#!/usr/bin/env python
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""pytorch dataset and dataloader implementation for chainer training."""
import torch
import torch.utils.data


class TransformDataset(torch.utils.data.Dataset):
    """Transform Dataset for pytorch backend.

    Args:
        data: list object from make_batchset
        transform: transform function

    """

    def __init__(self, data, transform):
        """Init function."""
        super(TransformDataset, self).__init__()
        self.data = data
        self.transform = transform

    def __len__(self):
        """Len function."""
        return len(self.data)

    def __getitem__(self, idx):
        """[] operator."""
        return self.transform(self.data[idx])


class ChainerDataLoader(object):
    """Pytorch dataloader in chainer style.

    Args:
        all args for torch.utils.data.dataloader.DataLoader

    """

    def __init__(self, **kwargs):
        """Init function."""
        self.loader = torch.utils.data.dataloader.DataLoader(**kwargs)
        self.len = len(kwargs["dataset"])
        self.current_position = 0
        self.epoch = 0
        self.iter = None
        self.kwargs = kwargs

    def next(self):
        """Implement next function."""
        if self.iter is None:
            self.iter = iter(self.loader)
        try:
            ret = next(self.iter)
        except StopIteration:
            # The underlying iterator is exhausted: drop it and restart
            # transparently from a fresh pass over the loader.
            self.iter = None
            return self.next()
        self.current_position += 1
        if self.current_position == self.len:
            self.epoch = self.epoch + 1
            self.current_position = 0
        return ret

    def __iter__(self):
        """Implement iter function."""
        for batch in self.loader:
            yield batch

    @property
    def epoch_detail(self):
        """Epoch_detail required by chainer."""
        return self.epoch + self.current_position / self.len

    def serialize(self, serializer):
        """Serialize and deserialize function."""
        epoch = serializer("epoch", self.epoch)
        current_position = serializer("current_position", self.current_position)
        self.epoch = epoch
        self.current_position = current_position

    def start_shuffle(self):
        """Shuffle function for sortagrad."""
        self.kwargs["shuffle"] = True
        self.loader = torch.utils.data.dataloader.DataLoader(**self.kwargs)

    def finalize(self):
        """Implement finalize function."""
        del self.loader
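

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API: the toy batch list and
    # identity transform below are hypothetical stand-ins for the output of
    # make_batchset and a real feature-loading transform; the collate_fn shown
    # is one plausible way to drive this loader, not the only one.
    toy_batches = [[("utt%d" % i, {"input": [float(i)]})] for i in range(4)]

    def toy_transform(batch):
        # A real transform would load features and build arrays/tensors here;
        # this one just passes the minibatch through unchanged.
        return batch

    dataset = TransformDataset(toy_batches, toy_transform)
    loader = ChainerDataLoader(
        dataset=dataset,
        batch_size=1,
        shuffle=False,
        # Each dataset item is already a full minibatch, so unwrap the extra
        # list that batch_size=1 adds around it.
        collate_fn=lambda x: x[0],
    )
    for _ in range(len(dataset)):
        print(loader.next())
    print("epoch_detail:", loader.epoch_detail)
    loader.finalize()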