# Scraped file-viewer header (not part of the original source code):
# author: baiyanlali-zhao; commit eaf2e33 ("init"); file size 2.19 kB
import numpy as np
import torch
from torch.utils.data import Dataset, Sampler
# TODO: move this to a more reasonable place
from rlkit.data_management.obs_dict_replay_buffer import normalize_image
class ImageDataset(Dataset):
    """Map-style dataset over an in-memory array of images.

    Indexing selects entries along the first axis (``images[idxs, :]``) and
    always returns them as float32. When ``should_normalize`` is True, the
    selected rows are first passed through rlkit's ``normalize_image``
    (defined outside this file; presumably rescales uint8 pixels — verify).

    Args:
        images: array-like exposing ``.dtype``, ``len()`` and 2-axis
            indexing (assumed to be a numpy array with >= 2 dims — TODO
            confirm against callers).
        should_normalize: must be True exactly when ``images`` is uint8.
    """

    def __init__(self, images, should_normalize=True):
        super().__init__()
        self.dataset = images
        self.dataset_len = len(self.dataset)
        stored_as_uint8 = images.dtype == np.uint8
        # Normalization is required for (and only for) raw uint8 storage.
        assert should_normalize == stored_as_uint8
        self.should_normalize = should_normalize

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, idxs):
        rows = self.dataset[idxs, :]
        prepared = normalize_image(rows) if self.should_normalize else rows
        return np.float32(prepared)
class InfiniteRandomSampler(Sampler):
    """Sampler that yields indices into ``data_source`` without end.

    Indices come from successive random permutations of
    ``range(len(data_source))`` drawn with ``torch.randperm``. Each index
    appears once per pass, and a new permutation is drawn whenever the
    current one is used up. The sampler is its own (stateful) iterator.
    If ``data_source`` is empty, ``next()`` raises StopIteration.
    """

    def __init__(self, data_source):
        self.data_source = data_source
        self.iter = self._fresh_permutation()

    def _fresh_permutation(self):
        """Return an iterator over a new random ordering of all indices."""
        order = torch.randperm(len(self.data_source)).tolist()
        return iter(order)

    def __iter__(self):
        return self

    def __next__(self):
        # Indices are ints, so None is a safe "pass exhausted" sentinel.
        idx = next(self.iter, None)
        if idx is None:
            self.iter = self._fresh_permutation()
            idx = next(self.iter)
        return idx

    def __len__(self):
        # Nominal huge length so len()-based consumers treat it as endless.
        return 2 ** 62
class InfiniteWeightedRandomSampler(Sampler):
    """Sampler that yields weighted random indices into ``data_source`` forever.

    Indices are drawn with replacement via ``torch.multinomial``, in chunks of
    ``len(weights)``. A new chunk is drawn whenever the current one runs out.
    The sampler is its own (stateful) iterator.

    Args:
        data_source: the indexed collection; only ``len(data_source)`` is
            used here.
        weights: 1-D per-index sampling weights, one per element of
            ``data_source``. They are given as a numpy array or a torch tensor
            and are stored as a CPU tensor. Following ``torch.multinomial``
            semantics, they are relative weights and need not sum to 1.
    """

    def __init__(self, data_source, weights):
        self.data_source = data_source
        self._weights = self._as_cpu_weights(weights)
        self.iter = self._create_iterator()

    def _as_cpu_weights(self, weights):
        """Validate ``weights`` and return them as a 1-D CPU tensor.

        Shared by ``__init__`` and ``update_weights`` so both accept the same
        inputs. Raises AssertionError when the length does not match
        ``data_source`` or the weights are not 1-D.
        """
        # Always use CPU.
        weights = torch.as_tensor(weights, device="cpu")
        assert len(self.data_source) == len(weights)
        assert weights.dim() == 1
        return weights

    def update_weights(self, weights):
        """Replace the sampling weights and discard any pending indices.

        Accepts the same inputs as the constructor (numpy array or tensor).
        """
        self._weights = self._as_cpu_weights(weights)
        self.iter = self._create_iterator()

    def _create_iterator(self):
        """Draw a fresh chunk of ``len(weights)`` indices, with replacement."""
        return iter(
            torch.multinomial(
                self._weights, len(self._weights), replacement=True
            ).tolist()
        )

    def __iter__(self):
        return self

    def __next__(self):
        try:
            idx = next(self.iter)
        except StopIteration:
            # Current chunk exhausted: draw another one and continue.
            self.iter = self._create_iterator()
            idx = next(self.iter)
        return idx

    def __len__(self):
        # Nominal huge length so len()-based consumers treat it as endless.
        return 2 ** 62