# NOTE(review): removed stray page-scrape header ("Spaces / Running / Running")
# that preceded the imports — it was not valid Python and would break import.
import functools | |
import random | |
import math | |
from PIL import Image | |
import cv2 | |
import numpy as np | |
import torch | |
from torch.utils.data import Dataset | |
from torchvision import transforms | |
import torchvision | |
from datasets import register | |
import cv2 | |
from math import pi | |
from torchvision.transforms import InterpolationMode | |
import torch.nn.functional as F | |
def to_mask(mask):
    """Convert *mask* to a single-channel float tensor in [0, 1].

    The input is first wrapped as a PIL image, collapsed to one grayscale
    channel, then converted back to a tensor.
    """
    as_pil = transforms.ToPILImage()(mask)
    single_channel = transforms.Grayscale(num_output_channels=1)(as_pil)
    return transforms.ToTensor()(single_channel)
def resize_fn(img, size):
    """Resize *img* to *size* (via PIL) and return the result as a tensor."""
    as_pil = transforms.ToPILImage()(img)
    resized = transforms.Resize(size)(as_pil)
    return transforms.ToTensor()(resized)
class ValDataset(Dataset):
    """Validation wrapper around a ``(PIL image, mask path)`` dataset.

    Each item of ``dataset`` is assumed to be ``(img, mask)`` where *img* is
    a PIL image and *mask* is a filesystem path to the ground-truth mask
    (loaded here with OpenCV at its native bit depth) — TODO confirm against
    the registered dataset implementations.

    NOTE(review): this class is shadowed by a second ``ValDataset``
    definition later in this file; only the later definition takes effect
    at import time.
    """

    def __init__(self, dataset, inp_size=None, augment=False):
        self.dataset = dataset
        self.inp_size = inp_size
        self.augment = augment
        # ImageNet normalization; no resize here (image kept at native size).
        self.img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        # Kept for attribute compatibility; __getitem__ does not use it.
        self.mask_transform = transforms.Compose([
            transforms.Resize((inp_size, inp_size), interpolation=Image.NEAREST),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, mask_path = self.dataset[idx]
        gt = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread returns None on failure; fail loudly instead of handing
        # a None ground truth downstream.
        if gt is None:
            raise FileNotFoundError(f'cannot read mask file: {mask_path}')
        return {
            # Transform applied once (the original computed it twice per item).
            'inp': self.img_transform(img),
            'gt': torch.from_numpy(gt),
            'name': mask_path,
            'filp': False,  # sic: typo for 'flip'; key kept for caller compatibility
        }
class TrainDataset(Dataset):
    """Training wrapper around a ``(PIL image, mask path)`` dataset.

    Applies a random horizontal flip (p=0.5) jointly to image and mask,
    resizes the image to ``inp_size`` and normalizes it; the mask is
    returned as a raw tensor at its native resolution (it is NOT resized).

    NOTE(review): this class is shadowed by a second ``TrainDataset``
    definition later in this file; only the later definition takes effect
    at import time.
    """

    def __init__(self, dataset, size_min=None, size_max=None, inp_size=None,
                 augment=False, gt_resize=None):
        self.dataset = dataset
        self.size_min = size_min
        if size_max is None:
            size_max = size_min
        self.size_max = size_max
        self.augment = augment
        self.gt_resize = gt_resize
        self.inp_size = inp_size
        # Resize + ImageNet normalization for the input image.
        self.img_transform = transforms.Compose([
            transforms.Resize((self.inp_size, self.inp_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        # Undoes the normalization above (for visualization/debugging).
        self.inverse_transform = transforms.Compose([
            transforms.Normalize(mean=[0., 0., 0.],
                                 std=[1/0.229, 1/0.224, 1/0.225]),
            transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                                 std=[1, 1, 1])
        ])
        # Kept for attribute compatibility; __getitem__ does not use it.
        self.mask_transform = transforms.Compose([
            transforms.Resize((self.inp_size, self.inp_size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, mask_path = self.dataset[idx]
        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread returns None on failure; fail loudly instead of crashing
        # later inside cv2.flip / torch.from_numpy.
        if mask is None:
            raise FileNotFoundError(f'cannot read mask file: {mask_path}')
        # Random horizontal flip, applied consistently to both image and mask.
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = cv2.flip(mask, 1)
        # The original also resized `img` here before img_transform; that was
        # redundant — img_transform already resizes to (inp_size, inp_size).
        return {
            'inp': self.img_transform(img),
            'gt': torch.from_numpy(mask),
            'name': mask_path,
        }
class TrainDataset(Dataset):
    """Training wrapper around a ``(PIL image, mask path)`` dataset.

    Applies a random horizontal flip (p=0.5) jointly to image and mask,
    resizes the image to ``inp_size`` and normalizes it; the mask is
    returned as a raw tensor at its native resolution (it is NOT resized).

    NOTE(review): this is the second ``TrainDataset`` definition in the
    file — it shadows the earlier, nearly identical one. The duplicate
    should be removed in a follow-up.
    """

    def __init__(self, dataset, size_min=None, size_max=None, inp_size=None,
                 augment=False, gt_resize=None):
        self.dataset = dataset
        self.size_min = size_min
        if size_max is None:
            size_max = size_min
        self.size_max = size_max
        self.augment = augment
        self.gt_resize = gt_resize
        self.inp_size = inp_size
        # Resize + ImageNet normalization for the input image.
        self.img_transform = transforms.Compose([
            transforms.Resize((self.inp_size, self.inp_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        # Undoes the normalization above (for visualization/debugging).
        self.inverse_transform = transforms.Compose([
            transforms.Normalize(mean=[0., 0., 0.],
                                 std=[1/0.229, 1/0.224, 1/0.225]),
            transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                                 std=[1, 1, 1])
        ])
        # Kept for attribute compatibility; __getitem__ does not use it.
        self.mask_transform = transforms.Compose([
            transforms.Resize((self.inp_size, self.inp_size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, mask_path = self.dataset[idx]
        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread returns None on failure; fail loudly instead of crashing
        # later inside cv2.flip / torch.from_numpy.
        if mask is None:
            raise FileNotFoundError(f'cannot read mask file: {mask_path}')
        # Random horizontal flip, applied consistently to both image and mask.
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = cv2.flip(mask, 1)
        # The original also resized `img` here before img_transform; that was
        # redundant — img_transform already resizes to (inp_size, inp_size).
        return {
            'inp': self.img_transform(img),
            'gt': torch.from_numpy(mask),
            'name': mask_path,
        }
class ValDataset(Dataset):
    """Validation wrapper around a ``(PIL image, mask path)`` dataset.

    Resizes and normalizes the input image; the ground-truth mask is read
    from its path with OpenCV at native bit depth and returned un-resized.

    NOTE(review): this is the second ``ValDataset`` definition in the file —
    it shadows the earlier one (which, unlike this one, did not resize the
    image and returned an extra ``'filp'`` key). The duplicate should be
    resolved in a follow-up.
    """

    def __init__(self, dataset, inp_size=None, augment=False):
        self.dataset = dataset
        self.inp_size = inp_size
        self.augment = augment
        # Resize + ImageNet normalization for the input image.
        self.img_transform = transforms.Compose([
            transforms.Resize((inp_size, inp_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        # Kept for attribute compatibility; __getitem__ does not use it.
        self.mask_transform = transforms.Compose([
            transforms.Resize((inp_size, inp_size), interpolation=Image.NEAREST),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, mask_path = self.dataset[idx]
        gt = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread returns None on failure; fail loudly instead of handing
        # a None ground truth downstream.
        if gt is None:
            raise FileNotFoundError(f'cannot read mask file: {mask_path}')
        return {
            # Transform applied once (the original computed it twice per item).
            'inp': self.img_transform(img),
            'gt': torch.from_numpy(gt),
            'name': mask_path,
        }