r"""PF-WILLOW dataset"""

import os

import pandas as pd
import numpy as np
import torch

from .dataset import CorrespondenceDataset


class PFWillowDataset(CorrespondenceDataset):
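    r"""Dataset class for the PF-WILLOW benchmark"""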
    def __init__(self, benchmark, datapath, thres, split):
        r"""PF-WILLOW dataset constructor"""
        super(PFWillowDataset, self).__init__(benchmark, datapath, thres, split)

        # Each row of the split CSV lists the source image path, the target image path,
        # the 20 source keypoint coordinates (x1..x10, y1..y10) and the 20 target
        # keypoint coordinates in the same order
        self.train_data = pd.read_csv(self.spt_path)
        self.src_imnames = np.array(self.train_data.iloc[:, 0])
        self.trg_imnames = np.array(self.train_data.iloc[:, 1])
        self.src_kps = self.train_data.iloc[:, 2:22].values
        self.trg_kps = self.train_data.iloc[:, 22:].values
        self.cls = ['car(G)', 'car(M)', 'car(S)', 'duck(S)',
                    'motorbike(G)', 'motorbike(M)', 'motorbike(S)',
                    'winebottle(M)', 'winebottle(wC)', 'winebottle(woC)']
        # The class name is the second component of each stored image path
        self.cls_ids = list(map(lambda names: self.cls.index(names.split('/')[1]), self.src_imnames))
        # Drop the leading directory so image names are relative to the dataset root
        self.src_imnames = list(map(lambda x: os.path.join(*x.split('/')[1:]), self.src_imnames))
        self.trg_imnames = list(map(lambda x: os.path.join(*x.split('/')[1:]), self.trg_imnames))

    def __getitem__(self, idx):
        r"""Constructs and returns a batch for PF-WILLOW dataset"""
        batch = super(PFWillowDataset, self).__getitem__(idx)
        # Attach the PCK evaluation threshold for this image pair
        batch['pckthres'] = self.get_pckthres(batch)

        return batch

    def get_pckthres(self, batch):
        r"""Computes PCK threshold"""
        # A transferred keypoint is counted as correct when its error is within
        # alpha * threshold; alpha is chosen by the evaluation code, not here
        if self.thres == 'bbox':
            # Larger side of the box spanned by the target keypoint coordinates
            return max(batch['trg_kps'].max(1)[0] - batch['trg_kps'].min(1)[0]).clone()
        elif self.thres == 'img':
            # Larger side of the target image
            return torch.tensor(max(batch['trg_img'].size()[1], batch['trg_img'].size()[2]))
        else:
            raise Exception('Invalid pck evaluation level: %s' % self.thres)

    def get_points(self, pts_list, idx, org_imsize):
        r"""Returns key-points of an image"""
        # PF-WILLOW annotates 10 keypoints per image, stored as (x1..x10, y1..y10)
        point_coords = pts_list[idx, :].reshape(2, 10)
        point_coords = torch.tensor(point_coords.astype(np.float32))
        xy, n_pts = point_coords.size()
        # Pad with -2 so every sample carries self.max_pts keypoint columns
        pad_pts = torch.zeros((xy, self.max_pts - n_pts)) - 2
        # Rescale keypoint coordinates from the original image size to the resized image size
        x_crds = point_coords[0] * (self.img_size / org_imsize[0])
        y_crds = point_coords[1] * (self.img_size / org_imsize[1])
        kps = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1)

        return kps, n_pts
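

# Minimal usage sketch (illustration only): how this dataset class is typically
# constructed and iterated. The benchmark name, data path, threshold type and
# DataLoader settings below are assumed example values, not constants defined
# in this module.
#
#     from torch.utils.data import DataLoader
#
#     dataset = PFWillowDataset(benchmark='pfwillow', datapath='Datasets',
#                               thres='bbox', split='test')
#     loader = DataLoader(dataset, batch_size=1, shuffle=False)
#     for batch in loader:
#         trg_kps = batch['trg_kps']      # padded target keypoints
#         pckthres = batch['pckthres']    # per-pair PCK threshold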