import torch
from torch import Tensor, nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
class RandomAffineAndRetMat(torch.nn.Module):
    """Random affine augmentation that also returns the sampled 3x3 affine matrix."""

    def __init__(
        self,
        degrees,
        translate=None,
        scale=None,
        shear=None,
        interpolation=torchvision.transforms.InterpolationMode.NEAREST,
        fill=0,
        center=None,
    ):
        super().__init__()
        self.degrees = degrees
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.interpolation = interpolation
        self.fill = fill
        self.center = center
    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be transformed.
        Returns:
            PIL Image or Tensor: Affine transformed image.
            Tensor: 3x3 affine matrix built from the sampled parameters.
        """
        fill = self.fill
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * transforms.functional.get_image_num_channels(img)
            else:
                fill = [float(f) for f in fill]

        img_size = transforms.functional.get_image_size(img)
        # Sample (angle, translations, scale, shear) and apply them to the image.
        ret = transforms.RandomAffine.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
        transformed_image = transforms.functional.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)

        affine_matrix = self.get_affine_matrix_from_params(ret)
        return transformed_image, affine_matrix
    def get_affine_matrix_from_params(self, params):
        degrees, translate, scale, shear = params
        degrees = torch.tensor(degrees)
        shear = torch.tensor(shear)

        # Convert each sampled parameter into a 3x3 homogeneous transformation matrix.
        rotation_matrix = torch.tensor([[torch.cos(torch.deg2rad(degrees)), -torch.sin(torch.deg2rad(degrees)), 0],
                                        [torch.sin(torch.deg2rad(degrees)), torch.cos(torch.deg2rad(degrees)), 0],
                                        [0, 0, 1]])
        translation_matrix = torch.tensor([[1, 0, translate[0]],
                                           [0, 1, translate[1]],
                                           [0, 0, 1]]).to(torch.float32)
        scaling_matrix = torch.tensor([[scale, 0, 0],
                                       [0, scale, 0],
                                       [0, 0, 1]])
        shearing_matrix = torch.tensor([[1, -torch.tan(torch.deg2rad(shear[0])), 0],
                                        [-torch.tan(torch.deg2rad(shear[1])), 1, 0],
                                        [0, 0, 1]])

        # Compose the matrices: translate @ rotate @ scale @ shear.
        affine_matrix = translation_matrix.mm(rotation_matrix).mm(scaling_matrix).mm(shearing_matrix)
        return affine_matrix
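
# Hypothetical usage sketch (not part of the original app): shows the expected
# call pattern and output shapes; parameter values and the 224x224 size are
# assumptions.
def _demo_random_affine():
    aug = RandomAffineAndRetMat(degrees=(-30, 30), translate=(0.1, 0.1),
                                scale=(0.8, 1.2), shear=(-10, 10))
    img = torch.rand(3, 224, 224)  # assumed CHW float image
    out, mat = aug(img)
    print(out.shape, mat.shape)  # torch.Size([3, 224, 224]) torch.Size([3, 3])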
def norm_img(img):
    # Min-max normalize a tensor image to [0, 1].
    return (img - img.min()) / (img.max() - img.min())
def preprocess_uploaded_image(uploaded_image, image_size):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Convert an ndarray input to a PIL image.
    if isinstance(uploaded_image, np.ndarray):
        uploaded_image = Image.fromarray(uploaded_image)
    uploaded_image = uploaded_image.convert("RGB")
    uploaded_image = uploaded_image.resize((image_size, image_size))
    uploaded_image = np.array(uploaded_image).transpose(2, 0, 1) / 255.0  # HWC -> CHW, scale to [0, 1]
    uploaded_image = torch.tensor(uploaded_image, dtype=torch.float32).unsqueeze(0).to(device)
    return uploaded_image
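
# Hypothetical usage sketch (not part of the original app): the input array
# and the 112 resolution are assumptions.
def _demo_preprocess():
    dummy = (np.random.rand(200, 300, 3) * 255).astype(np.uint8)
    batch = preprocess_uploaded_image(dummy, image_size=112)
    print(batch.shape, batch.dtype)  # torch.Size([1, 3, 112, 112]) torch.float32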
def get_heatmaps(img, feature_map, source_num, x_coords, y_coords, uploaded_image):
    # Note: uploaded_image is accepted but unused here.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image_size = img.size(2)
    batch_size = feature_map.size(0)
    feature_dim = feature_map.size(1)
    target_num = batch_size - 1

    # Sample the feature vector at (x, y) for every image, then keep the one
    # from the source image.
    x_coords = [x_coords] * batch_size
    y_coords = [y_coords] * batch_size
    vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]
    vector = vectors[source_num]

    # L2 distance from the source vector to every spatial location, turned into
    # a similarity map with a sech^2 kernel.
    reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
    batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
    norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2

    source_map = norm_batch_distance_map[source_num].detach().cpu()
    target_map = norm_batch_distance_map[target_num].detach().cpu()

    # Overlay each similarity map on its image as a red channel.
    alpha = 0.7
    blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num] / norm_batch_distance_map[source_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
    blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num] / norm_batch_distance_map[target_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
    blended_source = blended_source.detach().cpu()
    blended_target = blended_target.detach().cpu()
    return source_map, target_map, blended_source, blended_target
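
# Hypothetical usage sketch (not part of the original app): a random tensor
# stands in for a real backbone's per-pixel feature map, which this function
# assumes has the same spatial size as the image. source_num selects the image
# whose (x, y) descriptor is matched everywhere; the last image is the target.
def _demo_heatmaps():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    imgs = torch.rand(2, 3, 112, 112, device=device)
    feats = torch.rand(2, 64, 112, 112, device=device)
    src_map, tgt_map, blend_src, blend_tgt = get_heatmaps(
        imgs, feats, source_num=0, x_coords=56, y_coords=56, uploaded_image=None
    )
    print(src_map.shape, blend_src.shape)  # torch.Size([112, 112]) torch.Size([3, 112, 112])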
def get_mean_vector(feature_map, points):
    keypoints_size = points.size(1)
    mean_vector_list = []
    for i in range(keypoints_size):
        # points is assumed (batch, n_keypoints, 2) in (x, y) pixel order.
        x_coords, y_coords = torch.round(points[:, i].t()).to(torch.long)
        # Gather the feature vector at this keypoint for every image in the batch.
        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]
        # mean_vector = vectors[0:10].mean(0)  # mean vector over the first 10 feature maps
        mean_vector = vectors.mean(0).detach().cpu().numpy()
        mean_vector_list.append(mean_vector)
    return mean_vector_list
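
# Hypothetical usage sketch (not part of the original app): batch size,
# feature dim, and keypoint count are assumptions.
def _demo_mean_vector():
    feats = torch.rand(4, 64, 112, 112)
    points = torch.rand(4, 5, 2) * 111  # 5 keypoints per image
    means = get_mean_vector(feats, points)
    print(len(means), means[0].shape)  # 5 (64,)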
def get_keypoint_heatmaps(feature_map, mean_vector_list, keypoints_size, imgs):
    if len(feature_map.size()) == 3:
        feature_map = feature_map.unsqueeze(0)
    device = feature_map.device
    batch_size = feature_map.size(0)
    feature_dim = feature_map.size(1)
    size = feature_map.size(2)

    norm_batch_distance_map = torch.zeros(batch_size, size, size, device=device)
    for i in range(keypoints_size):
        # get_mean_vector returns numpy arrays; move each back onto the feature map's device.
        vector = torch.as_tensor(mean_vector_list[i], device=device)
        reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
        batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), size, size)
        batch_distance_map = 1 / torch.cosh(40 * (batch_distance_map - batch_distance_map.min())
                                            / (batch_distance_map.max() - batch_distance_map.min())) ** 2
        # Normalize each map by its per-image spatial maximum.
        m = batch_distance_map / batch_distance_map.amax(dim=(1, 2), keepdim=True)
        norm_batch_distance_map += m
    # Clamp values above 1 (maps from overlapping keypoints can sum past it).
    norm_batch_distance_map = norm_batch_distance_map.clamp(max=1.0)

    keypoint_maps = norm_batch_distance_map.detach().cpu()
    alpha = 0.8  # transparency factor for the heatmap overlay
    blended_tensors = (1 - alpha) * imgs + alpha * torch.cat(
        (norm_batch_distance_map.unsqueeze(1), torch.zeros(batch_size, 2, size, size, device=device)),
        dim=1
    )
    blended_tensors = norm_img(blended_tensors).detach().cpu()
    return keypoint_maps, blended_tensors
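
# Hypothetical usage sketch (not part of the original app): chains
# get_mean_vector and get_keypoint_heatmaps on random data; batch size,
# feature dim, and the 112 resolution are all assumptions.
def _demo_keypoint_heatmaps():
    feats = torch.rand(4, 64, 112, 112)
    imgs = torch.rand(4, 3, 112, 112)
    points = torch.rand(4, 5, 2) * 111
    means = get_mean_vector(feats, points)
    maps, blends = get_keypoint_heatmaps(feats, means, keypoints_size=5, imgs=imgs)
    print(maps.shape, blends.shape)  # torch.Size([4, 112, 112]) torch.Size([4, 3, 112, 112])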