stable-fashion / data /base_dataset.py
ovshake
add app.py and related files
6724ca0
raw
history blame
No virus
5.5 kB
import os
from PIL import Image
import cv2
import numpy as np
import random
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
def name(self):
return "BaseDataset"
def initialize(self, opt):
pass
class Rescale_fixed(object):
"""Rescale the input image into given size.
Args:
(w,h) (tuple): output size or x (int) then resized will be done in (x,x).
"""
def __init__(self, output_size):
self.output_size = output_size
def __call__(self, image):
return image.resize(self.output_size, Image.BICUBIC)
class Rescale_custom(object):
"""Rescale the input image and target image into randomly selected size with lower bound of min_size arg.
Args:
min_size (int): Minimum desired output size.
"""
def __init__(self, min_size, max_size):
assert isinstance(min_size, (int, float))
self.min_size = min_size
self.max_size = max_size
def __call__(self, sample):
input_image, target_image = sample["input_image"], sample["target_image"]
assert input_image.size == target_image.size
w, h = input_image.size
# Randomly select size to resize
if min(self.max_size, h, w) > self.min_size:
self.output_size = np.random.randint(
self.min_size, min(self.max_size, h, w)
)
else:
self.output_size = self.min_size
# calculate new size by keeping aspect ratio same
if h > w:
new_h, new_w = self.output_size * h / w, self.output_size
else:
new_h, new_w = self.output_size, self.output_size * w / h
new_w, new_h = int(new_w), int(new_h)
input_image = input_image.resize((new_w, new_h), Image.BICUBIC)
target_image = target_image.resize((new_w, new_h), Image.BICUBIC)
return {"input_image": input_image, "target_image": target_image}
class ToTensor(object):
"""Convert ndarrays in sample to Tensors."""
def __init__(self):
self.totensor = transforms.ToTensor()
def __call__(self, sample):
input_image, target_image = sample["input_image"], sample["target_image"]
return {
"input_image": self.totensor(input_image),
"target_image": self.totensor(target_image),
}
class RandomCrop_custom(object):
"""Crop randomly the image in a sample.
Args:
output_size (tuple or int): Desired output size. If int, square crop
is made.
"""
def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size
self.randomcrop = transforms.RandomCrop(self.output_size)
def __call__(self, sample):
input_image, target_image = sample["input_image"], sample["target_image"]
cropped_imgs = self.randomcrop(torch.cat((input_image, target_image)))
return {
"input_image": cropped_imgs[
:3,
:,
],
"target_image": cropped_imgs[
3:,
:,
],
}
class Normalize_custom(object):
"""Normalize given dict into given mean and standard dev
Args:
mean (tuple or int): Desired mean to substract from dict's tensors
std (tuple or int): Desired std to divide from dict's tensors
"""
def __init__(self, mean, std):
assert isinstance(mean, (float, tuple))
if isinstance(mean, float):
self.mean = (mean, mean, mean)
else:
assert len(mean) == 3
self.mean = mean
if isinstance(std, float):
self.std = (std, std, std)
else:
assert len(std) == 3
self.std = std
self.normalize = transforms.Normalize(self.mean, self.std)
def __call__(self, sample):
input_image, target_image = sample["input_image"], sample["target_image"]
return {
"input_image": self.normalize(input_image),
"target_image": self.normalize(target_image),
}
class Normalize_image(object):
"""Normalize given tensor into given mean and standard dev
Args:
mean (float): Desired mean to substract from tensors
std (float): Desired std to divide from tensors
"""
def __init__(self, mean, std):
assert isinstance(mean, (float))
if isinstance(mean, float):
self.mean = mean
if isinstance(std, float):
self.std = std
self.normalize_1 = transforms.Normalize(self.mean, self.std)
self.normalize_3 = transforms.Normalize([self.mean] * 3, [self.std] * 3)
self.normalize_18 = transforms.Normalize([self.mean] * 18, [self.std] * 18)
def __call__(self, image_tensor):
if image_tensor.shape[0] == 1:
return self.normalize_1(image_tensor)
elif image_tensor.shape[0] == 3:
return self.normalize_3(image_tensor)
elif image_tensor.shape[0] == 18:
return self.normalize_18(image_tensor)
else:
assert "Please set proper channels! Normlization implemented only for 1, 3 and 18"