# yeq6x's picture
# keypoints
# 19d010a
import operator
import os
import random

import pandas as pd
import torch
import torchvision
from PIL import Image
from torchvision import transforms

from utils import RandomAffineAndRetMat
def load_filenames(data_dir):
    """Return the names of image files found directly inside ``data_dir``.

    A file counts as an image when its extension (compared case-insensitively)
    is one of a fixed set of common formats. The returned names are bare file
    names (not joined with ``data_dir``) in ``os.listdir`` order.
    """
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.ppm', '.pgm', '.tif', '.tiff')
    matches = []
    for name in os.listdir(data_dir):
        _, extension = os.path.splitext(name)
        if extension.lower() in image_extensions:
            matches.append(name)
    return matches
def load_keypoints(label_path, max_samples=1000):
    """Load a fixed 25-keypoint subset from a JSON label file.

    The file is read with ``pd.read_json`` and sorted by index. Each value of
    its ``data`` column must be a mapping whose ``'points'`` item maps string
    point ids (``'0'``, ``'1'``, ...) to ``{'x': ..., 'y': ...}`` dicts.

    Args:
        label_path: Path to the JSON label file.
        max_samples: Maximum number of entries to load, counted from the start
            after sorting. Defaults to 1000, the cap that used to be hard-coded.
            Pass ``None`` to load every entry.

    Returns:
        DataFrame with one row per entry and integer columns ``0..49``, laid
        out as ``x, y`` pairs for the selected keypoints in ascending id order
        (``x_a, y_a, x_b, y_b, ...``). An empty label file yields an empty
        frame with the same 50 columns.
    """
    # Keypoints kept from the 60-point annotation: even ids 0-16, odd ids
    # 27-35 and 37-45, plus 28, 30, 34, 38, 40 and 44 (25 points in total).
    # Sorting makes the output columns follow ascending keypoint id.
    keypoint_ids = sorted({*range(0, 17, 2), *range(27, 36, 2), *range(37, 46, 2),
                           28, 30, 34, 38, 40, 44})

    label_data = pd.read_json(label_path).sort_index()
    # Positional slice; ``.iloc[:None]`` keeps every entry.
    entries = label_data['data'].iloc[:max_samples]

    rows = []
    for entry in entries:
        points = entry['points']
        row = []
        for point_id in keypoint_ids:
            point = points[str(point_id)]
            row.extend((point['x'], point['y']))
        rows.append(row)

    return pd.DataFrame(rows, columns=range(2 * len(keypoint_ids)))
class MyDataset:
    """Image-only dataset yielding two randomly augmented views of each image.

    Each item is ``(views, matrices, path)``:

    * ``views`` stacks two independently augmented, normalized copies of the
      same image.
    * ``matrices`` stacks the matrix returned by ``RandomAffineAndRetMat`` for
      each view. When the pair is flipped, each matrix is left-multiplied by
      ``_FLIP_MATRIX``.
    * ``path`` is the image path that was loaded.

    Args:
        X: Sequence of image file names, relative to ``img_dir``.
        valid: Stored on the instance but not used by this class.
        img_dir: Directory joined with each file name.
        img_size: Passed to ``transforms.Resize`` before the random affine.
    """

    # Left-multiplied onto the affine matrix when both views are flipped.
    # NOTE(review): negating x only matches a horizontal image flip if the
    # matrix from RandomAffineAndRetMat is expressed in image-centred
    # coordinates — confirm against utils.
    _FLIP_MATRIX = torch.tensor([[-1., 0., 0.],
                                 [0., 1., 0.],
                                 [0., 0., 1.]])

    def __init__(self, X, valid=False, img_dir='resources/trainB/', img_size=256):
        self.X = X
        self.valid = valid
        self.img_dir = img_dir
        self.img_size = img_size
        # Photometric augmentation runs on the [0, 1] tensor from ToTensor and
        # *before* Normalize. Tensor ColorJitter clamps float images to [0, 1],
        # so running it after Normalize discarded every negative value and
        # left the output in [0, 1] instead of [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.15),
            transforms.RandomGrayscale(0.3),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        path = os.path.join(self.img_dir, self.X[index])
        with Image.open(path) as img:
            # convert() loads the pixels (so the file can be closed here) and
            # guarantees 3 channels to match the 3-channel Normalize.
            resized = transforms.Resize(self.img_size)(img.convert('RGB'))

        # One flip decision shared by both views of the same image.
        is_flip = random.randint(0, 1)
        views = []
        matrices = []
        for _ in range(2):
            affine = RandomAffineAndRetMat(
                degrees=[-30, 30],
                translate=(0.1, 0.1), scale=(0.8, 1.2),
                # Fresh random RGB border colour for each view.
                fill=(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
                shear=[-10, 10],
                interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
            )
            view, affine_matrix = affine(resized)
            if is_flip == 1:
                view = transforms.functional.hflip(view)
                affine_matrix = torch.matmul(self._FLIP_MATRIX, affine_matrix)
            views.append(self.transform(view))
            matrices.append(affine_matrix)
        return torch.stack(views), torch.stack(matrices), path
class ImageKeypointDataset:
    """Dataset of ``(image, keypoints)`` pairs.

    Indexing with an integer returns ``(image, keypoints)``. Indexing with a
    slice returns both stacked along a new leading dimension.

    Args:
        X: Sequence of image file names, relative to ``img_dir``.
        y: DataFrame with one row per image, read positionally with
            ``y.iloc[i]``. Each row is reshaped into ``(-1, 2)`` pairs, so it
            is expected to be laid out as ``x0, y0, x1, y1, ...`` (the layout
            ``load_keypoints`` produces).
        valid: Stored on the instance but not used by this class.
        img_dir: Directory joined with each file name.
        img_size: Passed to ``transforms.Resize``.
    """

    def __init__(self, X, y, valid=False, img_dir='resources/trainB/', img_size=256):
        self.X = X
        self.y = y
        self.valid = valid
        self.img_dir = img_dir
        self.img_size = img_size
        # Deterministic preprocessing only: resize, to [0, 1] tensor, then
        # normalize each channel to [-1, 1].
        self.trans = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if isinstance(index, slice):
            # slice.indices resolves None/negative bounds and clips to len(self),
            # so ds[:8] and ds[-4:] work (range(None, ...) used to raise TypeError).
            positions = range(*index.indices(len(self)))
            images = torch.stack([self.get_one_X(i) for i in positions])
            keypoints = torch.stack([self.get_one_y(i) for i in positions])
            return images, keypoints
        # operator.index accepts any integer type (numpy/torch scalars included)
        # and raises TypeError for anything else, instead of silently
        # returning None.
        position = operator.index(index)
        return self.get_one_X(position), self.get_one_y(position)

    def get_one_X(self, index):
        """Load image ``index`` as a normalized 3-channel tensor."""
        path = os.path.join(self.img_dir, self.X[index])
        with Image.open(path) as img:
            # convert() guarantees 3 channels, matching the 3-channel Normalize.
            return self.trans(img.convert('RGB'))

    def get_one_y(self, index):
        """Return row ``index`` of ``y`` as a float32 tensor of ``(x, y)`` pairs."""
        row = self.y.iloc[index].to_numpy(dtype='float32')
        return torch.tensor(row).reshape(-1, 2)