# RingMo-SAM / datasets / wrappers.py
# (uploaded by AI-Cyber, commit 8d7921b, 7.84 kB)
import functools
import random
import math
from PIL import Image
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision
from datasets import register
import cv2
from math import pi
from torchvision.transforms import InterpolationMode
import torch.nn.functional as F
def to_mask(mask):
    """Convert an image (tensor/ndarray) into a 1-channel float mask tensor.

    Round-trips through PIL so torchvision's Grayscale can collapse the
    channels, then returns a ``ToTensor`` result in [0, 1].
    """
    as_pil = transforms.ToPILImage()(mask)
    single_channel = transforms.Grayscale(num_output_channels=1)(as_pil)
    return transforms.ToTensor()(single_channel)
def resize_fn(img, size):
    """Resize an image (tensor/ndarray) to ``size`` and return it as a tensor.

    Uses a PIL round-trip so torchvision's default (bilinear) resize applies.
    """
    as_pil = transforms.ToPILImage()(img)
    resized = transforms.Resize(size)(as_pil)
    return transforms.ToTensor()(resized)
@register('val')
class ValDataset(Dataset):
    """Validation wrapper around a paired (image, mask-path) dataset.

    Expects ``dataset[idx]`` to yield ``(PIL.Image, mask_path)`` where
    ``mask_path`` is a filesystem path readable by ``cv2.imread``
    (assumption inferred from the cv2 call below — confirm against the
    wrapped dataset). The image is normalized (ImageNet stats) but NOT
    resized; the mask is returned at its stored resolution.
    """

    def __init__(self, dataset, inp_size=None, augment=False):
        self.dataset = dataset
        self.inp_size = inp_size
        self.augment = augment
        # ImageNet mean/std normalization; no resize on the validation image.
        self.img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        # Kept for compatibility with external users; __getitem__ loads the
        # mask with cv2 instead of applying this transform.
        self.mask_transform = transforms.Compose([
            transforms.Resize((inp_size, inp_size), interpolation=Image.NEAREST),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, mask_path = self.dataset[idx]
        # Transform once and reuse (the original computed this twice).
        inp = self.img_transform(img)
        # IMREAD_UNCHANGED preserves the stored label values and channels.
        gt = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        if gt is None:
            # cv2.imread returns None (no exception) on an unreadable path;
            # fail loudly instead of letting torch.tensor raise obscurely.
            raise FileNotFoundError(f'cannot read mask file: {mask_path}')
        return {
            'inp': inp,
            'gt': torch.tensor(gt),
            'name': mask_path,
            # NOTE(review): key spelling 'filp' (sic) kept for compatibility
            # with downstream consumers.
            'filp': False
        }
@register('train')
class TrainDataset(Dataset):
    """Training wrapper: joint random horizontal flip, image resize + ImageNet
    normalization, and raw (unresized) mask loading via cv2.

    ``dataset[idx]`` must yield ``(PIL.Image, mask_path)``.
    """

    def __init__(self, dataset, size_min=None, size_max=None, inp_size=None,
                 augment=False, gt_resize=None):
        self.dataset = dataset
        self.size_min = size_min
        self.size_max = size_min if size_max is None else size_max
        self.augment = augment
        self.gt_resize = gt_resize
        self.inp_size = inp_size
        # Resize + ImageNet mean/std normalization for the input image.
        self.img_transform = transforms.Compose([
            transforms.Resize((self.inp_size, self.inp_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        # Inverse of the normalization above (handy for visualization).
        self.inverse_transform = transforms.Compose([
            transforms.Normalize(mean=[0., 0., 0.],
                                 std=[1/0.229, 1/0.224, 1/0.225]),
            transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                                 std=[1, 1, 1])
        ])
        # Kept for compatibility; __getitem__ loads masks with cv2 instead.
        self.mask_transform = transforms.Compose([
            transforms.Resize((self.inp_size, self.inp_size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, mask_path = self.dataset[idx]
        # IMREAD_UNCHANGED keeps label values/channels exactly as stored.
        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        # Random horizontal flip, applied to image and mask together.
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = cv2.flip(mask, 1)
        # Only the image is resized; the mask stays at its native resolution.
        img = transforms.Resize((self.inp_size, self.inp_size))(img)
        return {
            'inp': self.img_transform(img),
            'gt': torch.from_numpy(mask),
            'name': mask_path,
        }
@register('train_multi_task')
class TrainDataset(Dataset):
    """Multi-task training wrapper.

    Same pipeline as the 'train' wrapper: joint random horizontal flip,
    resize + ImageNet normalization on the image, raw cv2-loaded mask.
    ``dataset[idx]`` must yield ``(PIL.Image, mask_path)``.
    """

    def __init__(self, dataset, size_min=None, size_max=None, inp_size=None,
                 augment=False, gt_resize=None):
        self.dataset = dataset
        self.size_min = size_min
        self.size_max = size_min if size_max is None else size_max
        self.augment = augment
        self.gt_resize = gt_resize
        self.inp_size = inp_size
        # Image pipeline: resize, to-tensor, ImageNet normalization.
        self.img_transform = transforms.Compose([
            transforms.Resize((self.inp_size, self.inp_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        # Reverses the normalization above for inspection/visualization.
        self.inverse_transform = transforms.Compose([
            transforms.Normalize(mean=[0., 0., 0.],
                                 std=[1/0.229, 1/0.224, 1/0.225]),
            transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                                 std=[1, 1, 1])
        ])
        # Retained for compatibility; masks are loaded with cv2 below.
        self.mask_transform = transforms.Compose([
            transforms.Resize((self.inp_size, self.inp_size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, mask_path = self.dataset[idx]
        # Read the mask exactly as stored (no dtype/channel conversion).
        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        # 50% chance of a horizontal flip, mirrored on both image and mask.
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = cv2.flip(mask, 1)
        # Resize the image only; the mask keeps its original size.
        img = transforms.Resize((self.inp_size, self.inp_size))(img)
        return {
            'inp': self.img_transform(img),
            'gt': torch.from_numpy(mask),
            'name': mask_path
        }
@register('val_multi_task')
class ValDataset(Dataset):
    """Multi-task validation wrapper around a paired (image, mask-path) dataset.

    Expects ``dataset[idx]`` to yield ``(PIL.Image, mask_path)`` readable by
    ``cv2.imread`` (assumption inferred from the cv2 call below — confirm
    against the wrapped dataset). Unlike the plain 'val' wrapper, the image
    IS resized to ``inp_size`` before normalization; the mask is returned
    at its stored resolution.
    """

    def __init__(self, dataset, inp_size=None, augment=False):
        self.dataset = dataset
        self.inp_size = inp_size
        self.augment = augment
        # Resize + ImageNet mean/std normalization for the input image.
        self.img_transform = transforms.Compose([
            transforms.Resize((inp_size, inp_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        # Kept for compatibility with external users; __getitem__ loads the
        # mask with cv2 instead of applying this transform.
        self.mask_transform = transforms.Compose([
            transforms.Resize((inp_size, inp_size), interpolation=Image.NEAREST),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, mask_path = self.dataset[idx]
        # Transform once and reuse (the original computed this twice).
        inp = self.img_transform(img)
        # IMREAD_UNCHANGED preserves the stored label values and channels.
        gt = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        if gt is None:
            # cv2.imread returns None (no exception) on an unreadable path;
            # fail loudly instead of letting torch.tensor raise obscurely.
            raise FileNotFoundError(f'cannot read mask file: {mask_path}')
        return {
            'inp': inp,
            'gt': torch.tensor(gt),
            'name': mask_path
        }