|
import os |
|
from PIL import Image |
|
import cv2 |
|
import torch |
|
from torch.utils import data |
|
from torchvision import transforms |
|
from torchvision.transforms import functional as F |
|
import numbers |
|
import numpy as np |
|
import random |
|
|
|
|
|
|
|
|
|
class ImageDataTrain(data.Dataset):
    """Saliency training set yielding (image, saliency label, edge label) samples.

    The list file ``sal_source`` holds one sample per non-blank line: three
    whitespace-separated paths, relative to ``sal_root``, naming the image,
    the saliency label and the edge label, in that order.
    """

    def __init__(self, sal_root='/home/liuj/dataset/DUTS/DUTS-TR',
                 sal_source='/home/liuj/dataset/DUTS/DUTS-TR/train_pair_edge.lst'):
        """Read the sample list.

        Args:
            sal_root: directory the relative paths in the list file are joined to.
            sal_source: path of the list file.
        """
        self.sal_root = sal_root
        self.sal_source = sal_source
        with open(self.sal_source, 'r') as f:
            # Blank lines (e.g. extra trailing newlines) carry no sample and
            # would make __getitem__ fail when indexing the split fields.
            self.sal_list = [line.strip() for line in f if line.strip()]
        self.sal_num = len(self.sal_list)

    def __getitem__(self, item):
        """Load, randomly flip and tensorize the sample at index ``item``.

        ``item`` wraps around modulo the number of samples.

        Returns:
            dict with keys 'sal_image', 'sal_label' and 'sal_edge', holding
            tensors built from load_image, load_sal_label and load_edge_label.
        """
        fields = self.sal_list[item % self.sal_num].split()
        sal_image = load_image(os.path.join(self.sal_root, fields[0]))
        sal_label = load_sal_label(os.path.join(self.sal_root, fields[1]))
        sal_edge = load_edge_label(os.path.join(self.sal_root, fields[2]))
        # One call flips all three arrays together so image and labels stay aligned.
        sal_image, sal_label, sal_edge = cv_random_flip(sal_image, sal_label, sal_edge)
        return {
            'sal_image': torch.Tensor(sal_image),
            'sal_label': torch.Tensor(sal_label),
            'sal_edge': torch.Tensor(sal_edge),
        }

    def __len__(self):
        """Return the number of samples read from the list file."""
        return self.sal_num
|
|
|
class ImageDataTest(data.Dataset):
    """Test-time dataset yielding one preprocessed image per list entry.

    ``test_mode`` picks the list file: 0 uses the HED-BSDS test list, 1 uses
    the saliency test set selected by ``sal_mode`` (keys of
    ``_SAL_DATASETS``), 2 uses the SK-LARGE test list.
    """

    # sal_mode -> (image_root, image_source, test_fold). test_fold is None
    # for sets with no configured result folder.
    _SAL_DATASETS = {
        'e': ('/home/liuj/dataset/saliency_test/ECSSD/Imgs/',
              '/home/liuj/dataset/saliency_test/ECSSD/test.lst',
              '/media/ubuntu/disk/Result/saliency/ECSSD/'),
        'p': ('/home/liuj/dataset/saliency_test/PASCALS/Imgs/',
              '/home/liuj/dataset/saliency_test/PASCALS/test.lst',
              '/media/ubuntu/disk/Result/saliency/PASCALS/'),
        'd': ('/home/liuj/dataset/saliency_test/DUTOMRON/Imgs/',
              '/home/liuj/dataset/saliency_test/DUTOMRON/test.lst',
              '/media/ubuntu/disk/Result/saliency/DUTOMRON/'),
        'h': ('/home/liuj/dataset/saliency_test/HKU-IS/Imgs/',
              '/home/liuj/dataset/saliency_test/HKU-IS/test.lst',
              '/media/ubuntu/disk/Result/saliency/HKU-IS/'),
        's': ('/home/liuj/dataset/saliency_test/SOD/Imgs/',
              '/home/liuj/dataset/saliency_test/SOD/test.lst',
              '/media/ubuntu/disk/Result/saliency/SOD/'),
        'm': ('/home/liuj/dataset/saliency_test/MSRA/Imgs/',
              '/home/liuj/dataset/saliency_test/MSRA/test.lst',
              None),
        'o': ('/home/liuj/dataset/saliency_test/SOC/TestSet/Imgs/',
              '/home/liuj/dataset/saliency_test/SOC/TestSet/test.lst',
              '/media/ubuntu/disk/Result/saliency/SOC/'),
        't': ('/home/liuj/dataset/DUTS/DUTS-TE/DUTS-TE-Image/',
              '/home/liuj/dataset/DUTS/DUTS-TE/test.lst',
              '/media/ubuntu/disk/Result/saliency/DUTS/'),
    }

    def __init__(self, test_mode=1, sal_mode='e'):
        """Resolve the list file for the requested test set and read it.

        Args:
            test_mode: 0, 1 or 2 (see the class docstring).
            sal_mode: key of ``_SAL_DATASETS``; only used when ``test_mode == 1``.

        Raises:
            ValueError: if ``test_mode`` is not 0, 1 or 2, or if
                ``test_mode == 1`` and ``sal_mode`` is not a known key.
        """
        # Always define test_fold so save_folder() never raises AttributeError.
        self.test_fold = None
        if test_mode == 0:
            self.image_root = '/home/liuj/dataset/HED-BSDS_PASCAL/HED-BSDS/test/'
            self.image_source = '/home/liuj/dataset/HED-BSDS_PASCAL/HED-BSDS/test.lst'
        elif test_mode == 1:
            try:
                self.image_root, self.image_source, self.test_fold = self._SAL_DATASETS[sal_mode]
            except KeyError:
                raise ValueError('unknown sal_mode {!r}; expected one of {}'.format(
                    sal_mode, ', '.join(sorted(self._SAL_DATASETS)))) from None
        elif test_mode == 2:
            self.image_root = '/home/liuj/dataset/SK-LARGE/images/test/'
            self.image_source = '/home/liuj/dataset/SK-LARGE/test.lst'
        else:
            raise ValueError('unknown test_mode {!r}; expected 0, 1 or 2'.format(test_mode))

        with open(self.image_source, 'r') as f:
            # Skip blank lines; they name no image.
            self.image_list = [line.strip() for line in f if line.strip()]
        self.image_num = len(self.image_list)

    def __getitem__(self, item):
        """Return the preprocessed image at ``item``, wrapping modulo the list length.

        Returns:
            dict with 'image' (tensor built from load_image_test's array),
            'name' (the list entry, used for both loading and reporting) and
            'size' (the size tuple load_image_test reports).
        """
        name = self.image_list[item % self.image_num]
        image, im_size = load_image_test(os.path.join(self.image_root, name))
        return {'image': torch.Tensor(image), 'name': name, 'size': im_size}

    def save_folder(self):
        """Return the configured result folder, or None if this test set has none."""
        return self.test_fold

    def __len__(self):
        """Return the number of entries read from the list file."""
        return self.image_num
|
|
|
|
|
|
|
def get_loader(batch_size, mode='train', num_thread=1, test_mode=0, sal_mode='e'):
    """Create the dataset for ``mode`` and wrap it in a DataLoader.

    Args:
        batch_size: samples per batch.
        mode: 'train' builds a shuffled ImageDataTrain; any other value
            builds an unshuffled ImageDataTest.
        num_thread: passed to the DataLoader as ``num_workers``.
        test_mode, sal_mode: forwarded to ImageDataTest (unused in training).

    Returns:
        (data_loader, dataset) tuple.
    """
    training = mode == 'train'
    if training:
        dataset = ImageDataTrain()
    else:
        dataset = ImageDataTest(test_mode=test_mode, sal_mode=sal_mode)
    data_loader = data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=training,
        num_workers=num_thread,
    )
    return data_loader, dataset
|
|
|
def load_image(pah):
    """Read an image with OpenCV and return it mean-subtracted, channels first.

    Args:
        pah: path to the image file.

    Returns:
        float32 numpy array of shape (3, H, W). It holds the image as
        ``cv2.imread`` returns it (BGR channel order) minus a fixed
        per-channel mean.

    Raises:
        FileNotFoundError: if ``pah`` does not exist.
        OSError: if OpenCV cannot decode the file.
    """
    if not os.path.exists(pah):
        raise FileNotFoundError('File Not Exists: {}'.format(pah))
    im = cv2.imread(pah)
    if im is None:
        # cv2.imread signals an unreadable file by returning None instead of raising.
        raise OSError('Could not decode image: {}'.format(pah))
    in_ = np.array(im, dtype=np.float32)
    # Per-channel means, in the channel order cv2.imread produces. The values
    # look like the usual Caffe/VGG BGR means -- NOTE(review): confirm.
    in_ -= np.array((104.00699, 116.66877, 122.67892))
    # HWC -> CHW layout.
    in_ = in_.transpose((2, 0, 1))
    return in_
|
|
|
def load_image_test(pah):
    """Read a test image with OpenCV; return it mean-subtracted (channels first) plus its size.

    Args:
        pah: path to the image file.

    Returns:
        Tuple ``(in_, im_size)``:
        - ``in_``: float32 array of shape (3, H, W). It is the image as
          ``cv2.imread`` returns it (BGR order) minus a fixed per-channel mean.
        - ``im_size``: ``(H, W)`` of the original image.

    Raises:
        FileNotFoundError: if ``pah`` does not exist.
        OSError: if OpenCV cannot decode the file.
    """
    if not os.path.exists(pah):
        raise FileNotFoundError('File Not Exists: {}'.format(pah))
    im = cv2.imread(pah)
    if im is None:
        # cv2.imread signals an unreadable file by returning None instead of raising.
        raise OSError('Could not decode image: {}'.format(pah))
    in_ = np.array(im, dtype=np.float32)
    # Record the size before the channel-first transpose reorders the axes.
    im_size = tuple(in_.shape[:2])
    # Per-channel means, in cv2.imread channel order. The values look like the
    # usual Caffe/VGG BGR means -- NOTE(review): confirm.
    in_ -= np.array((104.00699, 116.66877, 122.67892))
    in_ = in_.transpose((2, 0, 1))
    return in_, im_size
|
|
|
def load_edge_label(pah):
    """Load an edge label as a float32 array of shape (1, H, W).

    Pixel values are divided by 255 (8-bit maps land in [0, 1]). Values
    above 0.5 are then snapped to exactly 1.0; values at or below 0.5 keep
    their scaled value. Only the first channel of a multi-channel image is
    used. The leading singleton dimension is required by the loss.

    Args:
        pah: path to the label image.

    Raises:
        FileNotFoundError: if ``pah`` does not exist.
    """
    if not os.path.exists(pah):
        raise FileNotFoundError('File Not Exists: {}'.format(pah))
    # Context manager closes the file handle PIL keeps open.
    with Image.open(pah) as im:
        label = np.array(im, dtype=np.float32)
    if label.ndim == 3:
        label = label[:, :, 0]
    label = label / 255.
    label[label > 0.5] = 1.
    return label[np.newaxis, ...]
|
|
|
def load_skel_label(pah):
    """Load a skeleton label as a binary float32 array of shape (1, H, W).

    Pixel values are divided by 255. Every non-zero pixel then becomes
    exactly 1.0, so the result contains only 0.0 and 1.0. Only the first
    channel of a multi-channel image is used. The leading singleton
    dimension is required by the loss.

    Args:
        pah: path to the label image.

    Raises:
        FileNotFoundError: if ``pah`` does not exist.
    """
    if not os.path.exists(pah):
        raise FileNotFoundError('File Not Exists: {}'.format(pah))
    # Context manager closes the file handle PIL keeps open.
    with Image.open(pah) as im:
        label = np.array(im, dtype=np.float32)
    if label.ndim == 3:
        label = label[:, :, 0]
    label = label / 255.
    label[label > 0.] = 1.
    return label[np.newaxis, ...]
|
|
|
def load_sal_label(pah):
    """Load a saliency label as a float32 array of shape (1, H, W).

    Pixel values are divided by 255 (8-bit masks land in [0, 1]) and no
    thresholding is applied. Only the first channel of a multi-channel
    image is used. The leading singleton dimension is required by the loss.

    Args:
        pah: path to the label image.

    Raises:
        FileNotFoundError: if ``pah`` does not exist.
    """
    if not os.path.exists(pah):
        raise FileNotFoundError('File Not Exists: {}'.format(pah))
    # Context manager closes the file handle PIL keeps open.
    with Image.open(pah) as im:
        label = np.array(im, dtype=np.float32)
    if label.ndim == 3:
        label = label[:, :, 0]
    label = label / 255.
    return label[np.newaxis, ...]
|
|
|
def load_sem_label(pah):
    """Load a label image as a float32 array of shape (1, H, W) without rescaling.

    Raw pixel values are kept; they are presumably class indices (e.g. a
    palette PNG) -- NOTE(review): confirm with the dataset. Only the first
    channel of a multi-channel image is used. The leading singleton
    dimension is required by the loss.

    Args:
        pah: path to the label image.

    Raises:
        FileNotFoundError: if ``pah`` does not exist.
    """
    if not os.path.exists(pah):
        raise FileNotFoundError('File Not Exists: {}'.format(pah))
    # Context manager closes the file handle PIL keeps open.
    with Image.open(pah) as im:
        label = np.array(im, dtype=np.float32)
    if label.ndim == 3:
        label = label[:, :, 0]
    return label[np.newaxis, ...]
|
|
|
def edge_thres_transform(x, thres):
    """Set elements at or above ``thres`` to 1 and leave the rest unchanged.

    Args:
        x: input tensor of any shape.
        thres: threshold; elements with ``x >= thres`` become 1.

    Returns:
        New tensor with the same shape, dtype and device as ``x``.
    """
    # ones_like matches x's device and dtype. A plain torch.ones(size) is a
    # CPU tensor of the default dtype, which torch.where rejects for CUDA x.
    return torch.where(x >= thres, torch.ones_like(x), x)
|
|
|
def skel_thres_transform(x, thres):
    """Binarize ``x``: 1 where ``x > thres``, 0 elsewhere.

    Args:
        x: input tensor of any shape.
        thres: threshold; the comparison is strict.

    Returns:
        New tensor of 0s and 1s with the same shape, dtype and device as ``x``.
    """
    # *_like constructors follow x's device and dtype so torch.where works for
    # CUDA inputs too.
    return torch.where(x > thres, torch.ones_like(x), torch.zeros_like(x))
|
|
|
def cv_random_flip(img, label, edge):
    """Mirror an image and its two label arrays together with probability 1/2.

    One ``random.randint(0, 1)`` draw decides for all three arrays. When it
    is 1, each array has its last axis reversed (``[:, :, ::-1]``) and is
    copied into fresh contiguous memory. Otherwise the inputs are returned
    unchanged.
    """
    if random.randint(0, 1) != 1:
        return img, label, edge
    return tuple(arr[:, :, ::-1].copy() for arr in (img, label, edge))
|
|
|
def cv_random_crop_flip(img, label, resize_size, crop_size, random_flip=True):
    """Resize, randomly crop and optionally randomly mirror an image/label pair.

    Args:
        img: channel-first array (C, H, W); it is transposed to (H, W, C) for OpenCV.
        label: array with a leading singleton axis (1, H, W); ``label[0]`` is used.
        resize_size: ``(height, width)`` both arrays are resized to first.
        crop_size: ``(height, width)`` of the random crop; must not exceed
            ``resize_size``.
        random_flip: if True, mirror both outputs along the width axis with
            probability 1/2. If False, never flip.

    Returns:
        ``(img, label)`` with shapes (C, crop_h, crop_w) and (1, crop_h, crop_w).
    """
    def get_params(img_size, output_size):
        # Pick a random top-left corner so the crop fits inside img_size.
        h, w = img_size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w
        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    # Default to "no flip" so random_flip=False no longer leaves flip_flag unbound.
    # The flip draw still happens before the crop draws, as before.
    flip_flag = random.randint(0, 1) if random_flip else 0

    img = img.transpose((1, 2, 0))
    label = label[0, :, :]
    # cv2.resize takes dsize as (width, height), hence the swapped indices.
    img = cv2.resize(img, (resize_size[1], resize_size[0]), interpolation=cv2.INTER_LINEAR)
    # Nearest-neighbour keeps label values from being blended by interpolation.
    label = cv2.resize(label, (resize_size[1], resize_size[0]), interpolation=cv2.INTER_NEAREST)

    i, j, h, w = get_params(resize_size, crop_size)
    img = img[i:i + h, j:j + w, :].transpose((2, 0, 1))
    label = label[i:i + h, j:j + w][np.newaxis, ...]

    if flip_flag == 1:
        img = img[:, :, ::-1].copy()
        label = label[:, :, ::-1].copy()
    return img, label
|
|
|
def random_crop(img, label, size, padding=None, pad_if_needed=True, fill_img=(123, 116, 103), fill_label=0, padding_mode='constant'):
    """Cut the same random window out of an image and its label, padding first if needed.

    ``img.size`` / ``label.size`` are read as ``(width, height)`` and the
    arrays are handled via ``torchvision.transforms.functional`` (imported as
    ``F``). The inputs are presumably PIL images -- NOTE(review): confirm.

    Args:
        img, label: images to crop with identical windows.
        size: crop size as ``(height, width)``, or a single number for a square.
        padding: optional padding applied to both images before anything else.
        pad_if_needed: pad symmetrically when the image is narrower or shorter
            than the crop.
        fill_img, fill_label: fill values used when padding img / label.
        padding_mode: forwarded to ``F.pad``.

    Returns:
        ``[cropped_img, cropped_label]``.
    """
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))

    if padding is not None:
        img = F.pad(img, padding, fill_img, padding_mode)
        label = F.pad(label, padding, fill_label, padding_mode)

    # Widen: pad left and right so the width reaches at least size[1].
    if pad_if_needed and img.size[0] < size[1]:
        img_pad_w = int((1 + size[1] - img.size[0]) / 2)
        label_pad_w = int((1 + size[1] - label.size[0]) / 2)
        img = F.pad(img, (img_pad_w, 0), fill_img, padding_mode)
        label = F.pad(label, (label_pad_w, 0), fill_label, padding_mode)

    # Heighten: pad top and bottom so the height reaches at least size[0].
    if pad_if_needed and img.size[1] < size[0]:
        img_pad_h = int((1 + size[0] - img.size[1]) / 2)
        label_pad_h = int((1 + size[0] - label.size[1]) / 2)
        img = F.pad(img, (0, img_pad_h), fill_img, padding_mode)
        label = F.pad(label, (0, label_pad_h), fill_label, padding_mode)

    width, height = img.size
    out_h, out_w = size
    if width == out_w and height == out_h:
        top, left, crop_h, crop_w = 0, 0, height, width
    else:
        top = random.randint(0, height - out_h)
        left = random.randint(0, width - out_w)
        crop_h, crop_w = out_h, out_w

    return [F.crop(img, top, left, crop_h, crop_w), F.crop(label, top, left, crop_h, crop_w)]
|
|