import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image, ImageReadMode
import numpy as np
def denorm_img(img: torch.Tensor) -> torch.Tensor:
    """Undo ImageNet normalization and clip the result back into [0, 1]."""
    std = torch.tensor([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
    mean = torch.tensor([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
    return torch.clip(img * std + mean, min=0, max=1)
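
# A quick round-trip sanity check for denorm_img (illustrative sketch, not
# part of the original code): normalizing with the same ImageNet statistics
# and then denormalizing should recover the original image up to clipping.
from torchvision import transforms

_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])
_img = torch.rand(3, 224, 224)  # random image with values in [0, 1)
assert torch.allclose(denorm_img(_normalize(_img)), _img, atol=1e-5)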
class StyleContentDataset(Dataset):
    """Pairs a style image with a content image by shared index."""

    def __init__(self, style_imgs, content_imgs, transform=None, normalize=None):
        self.style_imgs = style_imgs
        self.content_imgs = content_imgs
        self.transform = transform
        self.normalize = normalize

    def __len__(self):
        # The shorter of the two lists bounds the number of usable pairs.
        return min(len(self.style_imgs), len(self.content_imgs))

    def __getitem__(self, idx):
        try:
            style = read_image(self.style_imgs[idx], ImageReadMode.RGB).float() / 255.0
            content = read_image(self.content_imgs[idx], ImageReadMode.RGB).float() / 255.0
        except RuntimeError:
            # Log the unreadable paths and fall back to the first pair,
            # so a single corrupt file does not abort the whole epoch.
            print(self.style_imgs[idx])
            print(self.content_imgs[idx])
            style = read_image(self.style_imgs[0], ImageReadMode.RGB).float() / 255.0
            content = read_image(self.content_imgs[0], ImageReadMode.RGB).float() / 255.0
        if self.normalize:
            style = self.normalize(style)
            content = self.normalize(content)
        if self.transform:
            style = self.transform(style)
            content = self.transform(content)
        return style, content
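
# A minimal usage sketch for StyleContentDataset. The glob patterns and the
# crop size are assumptions, not part of the original code; adjust them to
# your own data layout. Note the dataset applies normalize before transform.
import glob
from torchvision import transforms

style_paths = sorted(glob.glob('data/style/*.jpg'))      # hypothetical paths
content_paths = sorted(glob.glob('data/content/*.jpg'))  # hypothetical paths

dataset = StyleContentDataset(
    style_paths,
    content_paths,
    transform=transforms.RandomCrop(256),  # assumes images are at least 256x256
    normalize=transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225]),
)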
class DataStore:
    """Wraps a DataLoader so batches can be pulled on demand, restarting
    the iterator automatically when an epoch is exhausted."""

    def __init__(self, dataset: StyleContentDataset, batch_size, shuffle=False):
        self.dataset = dataset
        self.dataloader = DataLoader(self.dataset, batch_size=batch_size,
                                     shuffle=shuffle, num_workers=2)
        self.iterator = iter(self.dataloader)

    def get(self):
        try:
            style, content = next(self.iterator)
        except StopIteration:
            # The epoch is exhausted: rebuild the iterator and keep going.
            # print('| Repeating |')
            # np.random.shuffle(self.dataset.style_imgs)
            self.iterator = iter(self.dataloader)
            style, content = next(self.iterator)
        return style, content
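
# A minimal sketch of how DataStore might drive a step-based training loop.
# The batch size and step count here are assumptions, not from the original
# code; `dataset` is the StyleContentDataset constructed above.
store = DataStore(dataset, batch_size=8, shuffle=True)
for step in range(1000):
    style, content = store.get()  # each of shape (8, 3, H, W)
    # ... forward pass, style/content losses, optimizer step would go here ...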