import torch
from torch import Tensor, nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
class RandomAffineAndRetMat(torch.nn.Module):
def __init__(
self,
degrees,
translate=None,
scale=None,
shear=None,
interpolation=torchvision.transforms.InterpolationMode.NEAREST,
fill=0,
center=None,
):
super().__init__()
self.degrees = degrees
self.translate = translate
self.scale = scale
self.shear = shear
self.interpolation = interpolation
self.fill = fill
self.center = center
def forward(self, img):
"""
img (PIL Image or Tensor): Image to be transformed.
Returns:
PIL Image or Tensor: Affine transformed image.
"""
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * transforms.functional.get_image_num_channels(img)
else:
fill = [float(f) for f in fill]
img_size = transforms.functional.get_image_size(img)
ret = transforms.RandomAffine.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
transformed_image = transforms.functional.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)
affine_matrix = self.get_affine_matrix_from_params(ret)
return transformed_image, affine_matrix
def get_affine_matrix_from_params(self, params):
degrees, translate, scale, shear = params
degrees = torch.tensor(degrees)
shear = torch.tensor(shear)
        # Convert each parameter into a 3x3 homogeneous transformation matrix
rotation_matrix = torch.tensor([[torch.cos(torch.deg2rad(degrees)), -torch.sin(torch.deg2rad(degrees)), 0],
[torch.sin(torch.deg2rad(degrees)), torch.cos(torch.deg2rad(degrees)), 0],
[0, 0, 1]])
translation_matrix = torch.tensor([[1, 0, translate[0]],
[0, 1, translate[1]],
[0, 0, 1]]).to(torch.float32)
scaling_matrix = torch.tensor([[scale, 0, 0],
[0, scale, 0],
[0, 0, 1]])
shearing_matrix = torch.tensor([[1, -torch.tan(torch.deg2rad(shear[0])), 0],
[-torch.tan(torch.deg2rad(shear[1])), 1, 0],
[0, 0, 1]])
        # Compose the individual matrices into a single affine matrix
affine_matrix = translation_matrix.mm(rotation_matrix).mm(scaling_matrix).mm(shearing_matrix)
return affine_matrix
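
# Example usage (a minimal sketch; "sample.png" is a hypothetical path):
#   img = Image.open("sample.png")
#   transform = RandomAffineAndRetMat(
#       degrees=(-30, 30), translate=(0.1, 0.1), scale=(0.8, 1.2), shear=(-10, 10))
#   transformed, matrix = transform(img)  # `matrix` is the composed 3x3 affine matrix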
def norm_img(img):
    # Min-max normalize to the [0, 1] range
    return (img - img.min()) / (img.max() - img.min())
def preprocess_uploaded_image(uploaded_image, image_size):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Convert ndarray inputs to a PIL image
    if isinstance(uploaded_image, np.ndarray):
uploaded_image = Image.fromarray(uploaded_image)
uploaded_image = uploaded_image.convert("RGB")
uploaded_image = uploaded_image.resize((image_size, image_size))
uploaded_image = np.array(uploaded_image).transpose(2, 0, 1) / 255.0
uploaded_image = torch.tensor(uploaded_image, dtype=torch.float32).unsqueeze(0).to(device)
return uploaded_image
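
# Example (sketch): convert a hypothetical upload to a (1, 3, S, S) float
# tensor in [0, 1] on the available device; "upload.png" is a placeholder.
#   tensor = preprocess_uploaded_image(Image.open("upload.png"), image_size=112)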
def get_heatmaps(img, feature_map, source_num, x_coords, y_coords, uploaded_image):
    # NOTE: `uploaded_image` is accepted but not used in the function body.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_size = img.size(2)
batch_size = feature_map.size(0)
feature_dim = feature_map.size(1)
target_num = batch_size - 1
x_coords = [x_coords] * batch_size
y_coords = [y_coords] * batch_size
vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]
vector = vectors[source_num]
reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
    # Map distances to (0, 1] with a sech^2 kernel: small distances score near 1
    norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2
source_map = norm_batch_distance_map[source_num].detach().cpu()
target_map = norm_batch_distance_map[target_num].detach().cpu()
alpha = 0.7
blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num] / norm_batch_distance_map[source_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num] / norm_batch_distance_map[target_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
blended_source = blended_source.detach().cpu()
blended_target = blended_target.detach().cpu()
return source_map, target_map, blended_source, blended_target
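
# Example (sketch): `encoder` is a hypothetical model whose feature map keeps
# the spatial size of `img`, i.e. (B, C, H, W) with H == W == img.size(2).
#   feature_map = encoder(img)
#   src_map, tgt_map, blend_src, blend_tgt = get_heatmaps(
#       img, feature_map, source_num=0, x_coords=56, y_coords=56, uploaded_image=None)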
def get_mean_vector(feature_map, points):
keypoints_size = points.size(1)
mean_vector_list = []
for i in range(keypoints_size):
x_coords, y_coords = torch.round(points[:,i].t()).to(torch.long)
        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]  # (B, C): one feature vector per image at this keypoint
        # mean_vector = vectors[0:10].mean(0)  # mean vector over the first 10 feature maps
        mean_vector = vectors.mean(0).detach().cpu().numpy()
mean_vector_list.append(mean_vector)
return mean_vector_list
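
# Example (sketch): `points` is a (B, K, 2) tensor of (x, y) keypoint
# coordinates in feature-map pixels.
#   mean_vectors = get_mean_vector(feature_map, points)  # list of K numpy arrays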
def get_keypoint_heatmaps(feature_map, mean_vector_list, keypoints_size, imgs):
if len(feature_map.size()) == 3:
feature_map = feature_map.unsqueeze(0)
device = feature_map.device
batch_size = feature_map.size(0)
feature_dim = feature_map.size(1)
size = feature_map.size(2)
norm_batch_distance_map = torch.zeros(batch_size,size,size,device=device)
for i in range(keypoints_size):
        # get_mean_vector returns numpy arrays, so convert back to a tensor
        vector = torch.as_tensor(mean_vector_list[i], device=device)
reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), size, size)
        batch_distance_map = 1 / torch.cosh(40 * (batch_distance_map - batch_distance_map.min())
                                            / (batch_distance_map.max() - batch_distance_map.min())) ** 2
        # Normalize each map to [0, 1] by its per-image maximum
        m = batch_distance_map / batch_distance_map.amax(dim=(1, 2), keepdim=True)
norm_batch_distance_map += m
    # Clip accumulated values above 1
    norm_batch_distance_map = norm_batch_distance_map.clamp(max=1.0)
keypoint_maps = norm_batch_distance_map.detach().cpu()
alpha = 0.8 # Transparency factor for the heatmap overlay
blended_tensors = (1 - alpha) * imgs + alpha * torch.cat(
(norm_batch_distance_map.unsqueeze(1), torch.zeros(batch_size,2,size,size,device=device)),
dim=1
)
blended_tensors = norm_img(blended_tensors).detach().cpu()
    return keypoint_maps, blended_tensors
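
if __name__ == "__main__":
    # Smoke test with random data (a sketch only; real usage would pass a
    # trained encoder's feature maps and the matching input images).
    B, C, S, K = 4, 16, 112, 5
    feature_map = torch.randn(B, C, S, S)
    imgs = torch.rand(B, 3, S, S)
    points = torch.rand(B, K, 2) * (S - 1)  # random (x, y) keypoints
    mean_vectors = get_mean_vector(feature_map, points)
    keypoint_maps, blended = get_keypoint_heatmaps(feature_map, mean_vectors, K, imgs)
    print(keypoint_maps.shape, blended.shape)  # (B, S, S), (B, 3, S, S)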