taesiri's picture
Initial Commit
8390f90
raw
history blame
2.44 kB
r""" PF-WILLOW dataset """
import os
import pandas as pd
import numpy as np
import torch
from .dataset import CorrespondenceDataset
class PFWillowDataset(CorrespondenceDataset):
    r"""PF-WILLOW correspondence dataset.

    Each row of the pair-list CSV (``self.spt_path``, provided by the parent
    class) holds a source image path, a target image path, then 20 source
    key-point values and the target key-point values. Each group of
    key-point values is NUM_KPS x-coordinates followed by NUM_KPS
    y-coordinates (see ``get_points``).
    """

    # Number of annotated key-points per PF-WILLOW image; shared by
    # get_points (to un-flatten a CSV row) and get_pckthres (to skip padding).
    NUM_KPS = 10

    def __init__(self, benchmark, datapath, thres, split):
        r"""PF-WILLOW dataset constructor.

        Args:
            benchmark, datapath, thres, split: forwarded unchanged to
                ``CorrespondenceDataset.__init__``, which is expected to set
                ``self.spt_path`` (among others) -- not visible here.
        """
        super().__init__(benchmark, datapath, thres, split)

        self.train_data = pd.read_csv(self.spt_path)
        self.src_imnames = np.array(self.train_data.iloc[:, 0])
        self.trg_imnames = np.array(self.train_data.iloc[:, 1])
        self.src_kps = self.train_data.iloc[:, 2:22].values
        self.trg_kps = self.train_data.iloc[:, 22:].values
        self.cls = ['car(G)', 'car(M)', 'car(S)', 'duck(S)',
                    'motorbike(G)', 'motorbike(M)', 'motorbike(S)',
                    'winebottle(M)', 'winebottle(wC)', 'winebottle(woC)']

        # The second '/'-separated component of each source path is its
        # category name; map it to an index into self.cls.
        self.cls_ids = [self.cls.index(name.split('/')[1]) for name in self.src_imnames]

        # Strip the leading directory component from every image path.
        self.src_imnames = [os.path.join(*name.split('/')[1:]) for name in self.src_imnames]
        self.trg_imnames = [os.path.join(*name.split('/')[1:]) for name in self.trg_imnames]

    def __getitem__(self, idx):
        r"""Constructs and returns a batch for PF-WILLOW dataset.

        Builds the sample with the parent ``__getitem__`` and adds
        ``batch['pckthres']`` computed by ``get_pckthres``.
        """
        batch = super().__getitem__(idx)
        batch['pckthres'] = self.get_pckthres(batch)
        return batch

    def get_pckthres(self, batch):
        r"""Computes the PCK threshold for one sample.

        Args:
            batch: sample dict. For ``'bbox'`` it reads ``batch['trg_kps']``
                (assumed to be the (2, max_pts) tensor built by ``get_points``,
                -2-padded after the real points -- TODO confirm in parent).
                For ``'img'`` it reads ``batch['trg_img']`` and uses sizes of
                dims 1 and 2 (presumably H and W of a C x H x W tensor).

        Returns:
            0-d tensor: for ``'bbox'``, the longer side of the tight box around
            the real target key-points; for ``'img'``, the larger of the two
            image dimensions.

        Raises:
            ValueError: if ``self.thres`` is neither ``'bbox'`` nor ``'img'``.
        """
        if self.thres == 'bbox':
            # Only the first NUM_KPS columns are real key-points; the remaining
            # columns are the -2 padding from get_points and would otherwise
            # drag the per-axis minimum down to -2.
            trg_kps = batch['trg_kps'][:, :self.NUM_KPS]
            extent = trg_kps.max(1)[0] - trg_kps.min(1)[0]
            return extent.max()
        elif self.thres == 'img':
            return torch.tensor(max(batch['trg_img'].size()[1], batch['trg_img'].size()[2]))
        else:
            raise ValueError('Invalid pck evaluation level: %s' % self.thres)

    def get_points(self, pts_list, idx, org_imsize):
        r"""Returns the rescaled, padded key-points of one image.

        Args:
            pts_list: 2-D array (``self.src_kps`` or ``self.trg_kps``); row
                ``idx`` holds NUM_KPS x-coordinates followed by NUM_KPS
                y-coordinates.
            idx: row index into ``pts_list``.
            org_imsize: original image extent; x is scaled by
                ``self.img_size / org_imsize[0]`` and y by
                ``self.img_size / org_imsize[1]`` (presumably a PIL-style
                (width, height) -- verify against caller).

        Returns:
            (kps, n_pts): ``kps`` is a float tensor of shape
            (2, self.max_pts) whose first ``n_pts`` columns are the rescaled
            (x, y) key-points and whose remaining columns are -2 padding;
            ``n_pts`` is the number of real key-points.
        """
        point_coords = pts_list[idx, :].reshape(2, self.NUM_KPS)
        point_coords = torch.tensor(point_coords.astype(np.float32))
        xy, n_pts = point_coords.size()

        # Pad to a fixed width so samples can be stacked; -2 marks "no point".
        pad_pts = torch.zeros((xy, self.max_pts - n_pts)) - 2

        x_crds = point_coords[0] * (self.img_size / org_imsize[0])
        y_crds = point_coords[1] * (self.img_size / org_imsize[1])
        kps = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1)
        return kps, n_pts