import os
import random

import pandas as pd
import torch
import torchvision
from PIL import Image
from torchvision import transforms

from utils import RandomAffineAndRetMat


def load_filenames(data_dir):
    # Keep only files with image extensions.
    img_exts = ['.jpg', '.jpeg', '.png', '.bmp', '.ppm', '.pgm', '.tif', '.tiff']
    filenames = [f for f in os.listdir(data_dir) if os.path.splitext(f)[1].lower() in img_exts]
    return filenames
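
# Usage sketch: the returned names are relative, so join them with the same
# directory when opening files. 'resources/trainB/' is just the default used by
# the dataset classes below.
#   filenames = load_filenames('resources/trainB/')
#   filenames.sort()  # fix an ordering if it matters downstream (os.listdir order is arbitrary)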


def load_keypoints(label_path):
    # Labels are stored as JSON; each record holds 60 (x, y) points.
    label_data = pd.read_json(label_path)
    label_data = label_data.sort_index()

    tmp_points = []
    for o in label_data.data[0:1000]:  # first 1000 records
        tmps = []
        for i in range(60):
            tmps.append(o['points'][str(i)]['x'])
            tmps.append(o['points'][str(i)]['y'])
        tmp_points.append(tmps)  # one flat row of 120 values per sample
    df_points = pd.DataFrame(tmp_points)

    # Keep a 25-point subset: every other landmark from the ranges 0-16, 27-36
    # and 37-46 (stride 4 over the flattened x/y columns), plus six extra points.
    df_points = df_points.iloc[:, [
        *list(range(0, 16 * 2 + 1, 4)), *list(range(1, 16 * 2 + 2, 4)),
        *list(range(27 * 2, 36 * 2 + 1, 4)), *list(range(27 * 2 + 1, 36 * 2 + 2, 4)),
        *list(range(37 * 2, 46 * 2 + 1, 4)), *list(range(37 * 2 + 1, 46 * 2 + 2, 4)),
        # 49 * 2, 49 * 2 + 1,
        # *list(range(50 * 2, 55 * 2 + 1, 4)), *list(range(50 * 2 + 1, 55 * 2 + 2, 4)),
        28 * 2, 28 * 2 + 1,
        30 * 2, 30 * 2 + 1,
        34 * 2, 34 * 2 + 1,
        38 * 2, 38 * 2 + 1,
        40 * 2, 40 * 2 + 1,
        44 * 2, 44 * 2 + 1,
    ]]
    # Restore the x/y interleaving, then renumber the columns 0..49.
    df_points = df_points.sort_index(axis=1)
    df_points.columns = list(range(len(df_points.columns)))
    return df_points
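
# Sketch of the resulting frame (the label path is illustrative):
#   df_points = load_keypoints('resources/trainB_label.json')
#   df_points.shape    # (n_samples, 50): 25 landmarks x (x, y), flattened per row
#   df_points.iloc[0]  # x0, y0, x1, y1, ... for the first sample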


class MyDataset:
    """Pairs of differently augmented views of the same unlabeled image."""

    def __init__(self, X, valid=False, img_dir='resources/trainB/', img_size=256):
        self.X = X
        self.valid = valid
        self.img_dir = img_dir
        self.img_size = img_size

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        # Load the image and apply the transforms.
        f = self.img_dir + self.X[index]
        original_X = Image.open(f)
        # Photometric augmentation; Normalize is applied last so that ColorJitter
        # and RandomGrayscale operate on values in [0, 1].
        trans = [
            transforms.ToTensor(),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.15),
            transforms.RandomGrayscale(0.3),
            # transforms.Normalize(mean=means, std=stds),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        transform = transforms.Compose(trans)

        xlist = []
        matlist = []
        is_flip = random.randint(0, 1)  # apply the same flip to both views of this image
        for i in range(2):
            af = RandomAffineAndRetMat(
                degrees=[-30, 30],
                translate=(0.1, 0.1), scale=(0.8, 1.2),
                # fill=(random.random(), random.random(), random.random()),
                fill=(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
                shear=[-10, 10],
                interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
            )
            X, affine_matrix = af(transforms.Resize(self.img_size)(original_X))

            # Random horizontal flip, folded into the affine matrix so the full
            # geometric transform stays recoverable.
            if is_flip == 1:
                X = transforms.RandomHorizontalFlip(1.)(X)
                flip_matrix = torch.tensor([[-1., 0., 0.],
                                            [0., 1., 0.],
                                            [0., 0., 1.]])
                affine_matrix = torch.matmul(flip_matrix, affine_matrix)

            xlist.append(transform(X))
            matlist.append(affine_matrix)

        X = torch.stack(xlist)
        mat = torch.stack(matlist)
        return X, mat, f
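
# Usage sketch with a DataLoader (batch size and paths are illustrative; stacking
# into a batch assumes all images end up the same size after Resize):
#   filenames = load_filenames('resources/trainB/')
#   dataset = MyDataset(filenames, img_dir='resources/trainB/', img_size=256)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
#   X, mat, paths = next(iter(loader))
#   # X: (8, 2, 3, H, W) view pairs, mat: (8, 2, 3, 3) affine matrices per view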


class ImageKeypointDataset:
    """Images paired with their 25 (x, y) keypoints."""

    def __init__(self, X, y, valid=False, img_dir='resources/trainB/', img_size=256):
        self.X = X
        self.y = y
        self.valid = valid
        self.img_dir = img_dir
        self.img_size = img_size
        # if not valid:
        trans = [
            transforms.Resize(self.img_size),
            transforms.ToTensor(),
            # transforms.Normalize(mean=means, std=stds),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        ]
        self.trans = transforms.Compose(trans)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if isinstance(index, slice):
            # Slicing returns stacked batches; slice.indices() also handles
            # missing start/stop/step values.
            start, stop, step = index.indices(len(self))
            return (torch.stack([self.get_one_X(i) for i in range(start, stop, step)]),
                    torch.stack([self.get_one_y(i) for i in range(start, stop, step)]))
        return self.get_one_X(index), self.get_one_y(index)

    def get_one_X(self, index):
        f = self.img_dir + self.X[index]
        X = Image.open(f)
        X = self.trans(X)
        return X

    def get_one_y(self, index):
        y = self.y.iloc[index].copy()
        y = torch.tensor(y.to_numpy(), dtype=torch.float32)
        y = y.reshape(25, 2)
        return y
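

# Minimal smoke test (sketch). The label file name is an assumption; point it at
# wherever the keypoint JSON actually lives.
if __name__ == "__main__":
    img_dir = 'resources/trainB/'               # default img_dir of the datasets above
    label_path = 'resources/trainB_label.json'  # hypothetical path, adjust as needed

    filenames = sorted(load_filenames(img_dir))
    keypoints = load_keypoints(label_path)

    unlabeled = MyDataset(filenames, img_dir=img_dir, img_size=256)
    views, mats, path = unlabeled[0]
    print('views:', views.shape, 'affine mats:', mats.shape, path)

    labeled = ImageKeypointDataset(filenames, keypoints, img_dir=img_dir, img_size=256)
    img, kps = labeled[0]
    print('image:', img.shape, 'keypoints:', kps.shape)  # keypoints -> (25, 2)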