r""" Superclass for semantic correspondence datasets """
import os
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import torch
from model.base.geometry import Geometry
class CorrespondenceDataset(Dataset):
r""" Parent class of PFPascal, PFWillow, and SPair """
def __init__(self, benchmark, datapath, thres, split):
r""" CorrespondenceDataset constructor """
super(CorrespondenceDataset, self).__init__()
# {Directory name, Layout path, Image path, Annotation path, PCK threshold}
self.metadata = {
'pfwillow': ('PF-WILLOW',
'test_pairs.csv',
'',
'',
'bbox'),
'pfpascal': ('PF-PASCAL',
'_pairs.csv',
'JPEGImages',
'Annotations',
'img'),
'spair': ('SPair-71k',
'Layout/large',
'JPEGImages',
'PairAnnotation',
'bbox')
}
# Directory path for train, val, or test splits
base_path = os.path.join(os.path.abspath(datapath), self.metadata[benchmark][0])
if benchmark == 'pfpascal':
self.spt_path = os.path.join(base_path, split+'_pairs.csv')
elif benchmark == 'spair':
self.spt_path = os.path.join(base_path, self.metadata[benchmark][1], split+'.txt')
else:
self.spt_path = os.path.join(base_path, self.metadata[benchmark][1])
# Directory path for images
self.img_path = os.path.join(base_path, self.metadata[benchmark][2])
# Directory path for annotations
if benchmark == 'spair':
self.ann_path = os.path.join(base_path, self.metadata[benchmark][3], split)
else:
self.ann_path = os.path.join(base_path, self.metadata[benchmark][3])
# Miscellaneous
self.max_pts = 40
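        # (max_pts caps the number of key-points per sample; get_points pads
        # every sample up to this size)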
        self.split = split
        self.img_size = Geometry.img_size
        self.benchmark = benchmark
        self.range_ts = torch.arange(self.max_pts)
        self.thres = self.metadata[benchmark][4] if thres == 'auto' else thres
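        # Resize to a fixed square resolution and normalize with the ImageNet
        # mean/std used by torchvision's pre-trained backbones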
        self.transform = transforms.Compose([transforms.Resize((self.img_size, self.img_size)),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                  std=[0.229, 0.224, 0.225])])

        # To be initialized in subclass constructors
        self.train_data = []
        self.src_imnames = []
        self.trg_imnames = []
        self.cls = []
        self.cls_ids = []
        self.src_kps = []
        self.trg_kps = []

    def __len__(self):
        r"""Returns the number of pairs"""
        return len(self.train_data)

    def __getitem__(self, idx):
        r"""Constructs and returns a batch"""
        # Image names
        batch = dict()
        batch['src_imname'] = self.src_imnames[idx]
        batch['trg_imname'] = self.trg_imnames[idx]

        # Object category
        batch['category_id'] = self.cls_ids[idx]
        batch['category'] = self.cls[batch['category_id']]

        # Images as PIL (original width, original height)
        src_pil = self.get_image(self.src_imnames, idx)
        trg_pil = self.get_image(self.trg_imnames, idx)
        batch['src_imsize'] = src_pil.size
        batch['trg_imsize'] = trg_pil.size

        # Images as tensors
        batch['src_img'] = self.transform(src_pil)
        batch['trg_img'] = self.transform(trg_pil)

        # Key-points (re-scaled to self.img_size)
        batch['src_kps'], num_pts = self.get_points(self.src_kps, idx, src_pil.size)
        batch['trg_kps'], _ = self.get_points(self.trg_kps, idx, trg_pil.size)
        batch['n_pts'] = torch.tensor(num_pts)

        # Total number of pairs in the current split
        batch['datalen'] = len(self.train_data)

        return batch

    def get_image(self, imnames, idx):
        r"""Reads a PIL image from its path"""
        path = os.path.join(self.img_path, imnames[idx])
        return Image.open(path).convert('RGB')

    def get_pckthres(self, batch, imsize):
        r"""Computes the PCK threshold"""
        if self.thres == 'bbox':
            bbox = batch['trg_bbox'].clone()
            bbox_w = (bbox[2] - bbox[0])
            bbox_h = (bbox[3] - bbox[1])
            pckthres = torch.max(bbox_w, bbox_h)
        elif self.thres == 'img':
            imsize_t = batch['trg_img'].size()
            pckthres = torch.tensor(max(imsize_t[1], imsize_t[2]))
        else:
            raise Exception('Invalid pck threshold type: %s' % self.thres)
        return pckthres.float()

    def get_points(self, pts_list, idx, org_imsize):
        r"""Returns key-points of an image"""
        xy, n_pts = pts_list[idx].size()
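        # Pad with the sentinel value -2 so every sample has a fixed max_pts
        # columns and can be collated into a batch; n_pts marks the valid prefix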
        pad_pts = torch.zeros((xy, self.max_pts - n_pts)) - 2
        x_crds = pts_list[idx][0] * (self.img_size / org_imsize[0])
        y_crds = pts_list[idx][1] * (self.img_size / org_imsize[1])
        kps = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1)
        return kps, n_pts
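

# A minimal sketch (not part of the original module) of the contract a subclass
# must satisfy: its constructor populates the fields that __init__ above leaves
# empty. 'ToyDataset' and its hard-coded pair are hypothetical; the real
# subclasses (PFPascal, PFWillow, SPair) parse their layout files instead.
class ToyDataset(CorrespondenceDataset):
    def __init__(self, benchmark, datapath, thres, split):
        super(ToyDataset, self).__init__(benchmark, datapath, thres, split)
        self.train_data = ['0001-0002']           # one source-target pair
        self.src_imnames = ['0001.jpg']           # must exist under self.img_path
        self.trg_imnames = ['0002.jpg']
        self.cls = ['aeroplane']                  # category names
        self.cls_ids = [0]                        # per-pair index into self.cls
        # Key-points are 2 x N tensors (row 0: x, row 1: y) in original coords
        self.src_kps = [torch.tensor([[10., 50.], [20., 60.]])]
        self.trg_kps = [torch.tensor([[12., 48.], [22., 64.]])]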