Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
from torch.utils.data import Dataset, Sampler | |
# TODO: move this import to a more reasonable place
from rlkit.data_management.obs_dict_replay_buffer import normalize_image | |
class ImageDataset(Dataset):
    """Dataset wrapper around an in-memory array of images.

    Items are fetched by indexing the first axis of ``images`` and are
    always returned as ``float32``.

    Args:
        images: Array-like exposing ``len()``, ``.dtype`` and
            ``images[idxs, :]`` indexing (e.g. a NumPy array with at
            least two dimensions).
        should_normalize: Whether fetched items are passed through
            ``normalize_image`` before the ``float32`` conversion. Must be
            ``True`` exactly when ``images.dtype`` is ``np.uint8``.

    Raises:
        AssertionError: If ``should_normalize`` disagrees with whether
            ``images`` is stored as ``uint8``.
    """

    def __init__(self, images, should_normalize=True):
        super().__init__()
        self.dataset = images
        self.dataset_len = len(images)
        # The flag has to agree with the storage dtype: uint8 data is
        # normalized on access, every other dtype is passed through.
        stored_as_uint8 = images.dtype == np.uint8
        assert should_normalize == stored_as_uint8
        self.should_normalize = should_normalize

    def __len__(self):
        """Return the number of images, cached at construction time."""
        return self.dataset_len

    def __getitem__(self, idxs):
        """Return the image(s) selected by ``idxs`` as ``float32``."""
        selected = self.dataset[idxs, :]
        if not self.should_normalize:
            return np.float32(selected)
        # normalize_image is a project helper; presumably it rescales raw
        # uint8 pixel values -- confirm against its definition.
        return np.float32(normalize_image(selected))
class InfiniteRandomSampler(Sampler):
    """Endless sampler that yields indices one shuffled pass at a time.

    Every index of ``data_source`` is produced once per pass in random
    order; when a pass is exhausted a new permutation is drawn and
    iteration continues.

    Args:
        data_source: Sized collection whose indices are sampled.
    """

    def __init__(self, data_source):
        self.data_source = data_source
        self.iter = self._new_permutation()

    def _new_permutation(self):
        """Return an iterator over a fresh random permutation of indices."""
        num_items = len(self.data_source)
        return iter(torch.randperm(num_items).tolist())

    def __iter__(self):
        return self

    def __next__(self):
        # Indices are plain ints, so None can safely mark an exhausted pass.
        idx = next(self.iter, None)
        if idx is None:
            # Current pass is used up: reshuffle and keep going.
            self.iter = self._new_permutation()
            idx = next(self.iter)
        return idx

    def __len__(self):
        # The stream never ends for a non-empty source; report a very
        # large constant that still fits in a signed 64-bit length.
        return 2 ** 62
class InfiniteWeightedRandomSampler(Sampler):
    """Endless sampler that draws indices with probability given by weights.

    Indices are drawn in batches of ``len(weights)`` using
    ``torch.multinomial`` with replacement; when a batch is exhausted a
    new one is drawn, so iteration never stops.

    Args:
        data_source: Sized collection whose indices are sampled.
        weights: One-dimensional, non-negative sampling weights, one per
            item of ``data_source``. A NumPy array or a torch tensor.

    Raises:
        AssertionError: If ``weights`` is not one-dimensional or its
            length differs from ``len(data_source)``.
    """

    def __init__(self, data_source, weights):
        weights = self._as_weight_tensor(weights)
        assert len(data_source) == len(weights)
        assert weights.dim() == 1
        self.data_source = data_source
        # NumPy input always ends up as a CPU tensor.
        self._weights = weights
        self.iter = self._create_iterator()

    @staticmethod
    def _as_weight_tensor(weights):
        """Convert ``weights`` into a tensor ``torch.multinomial`` accepts."""
        tensor = torch.as_tensor(weights)
        # torch.multinomial only supports floating-point probabilities.
        if not tensor.is_floating_point():
            tensor = tensor.float()
        return tensor

    def update_weights(self, weights):
        """Replace the sampling weights and restart the current batch.

        Accepts the same input types as the constructor. Previously a
        NumPy array was stored unconverted here and made the next draw
        fail inside ``torch.multinomial``.
        """
        self._weights = self._as_weight_tensor(weights)
        self.iter = self._create_iterator()

    def _create_iterator(self):
        """Draw ``len(weights)`` indices with replacement and iterate them."""
        return iter(
            torch.multinomial(
                self._weights, len(self._weights), replacement=True
            ).tolist()
        )

    def __iter__(self):
        return self

    def __next__(self):
        try:
            idx = next(self.iter)
        except StopIteration:
            # Current batch is used up: draw a fresh one and keep going.
            self.iter = self._create_iterator()
            idx = next(self.iter)
        return idx

    def __len__(self):
        # The stream never ends; report a very large constant that still
        # fits in a signed 64-bit length.
        return 2 ** 62