import torch
from torch import Tensor, nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
class RandomAffineAndRetMat(torch.nn.Module):
def __init__(
self,
degrees,
translate=None,
scale=None,
shear=None,
interpolation=torchvision.transforms.InterpolationMode.NEAREST,
fill=0,
center=None,
):
super().__init__()
self.degrees = degrees
self.translate = translate
self.scale = scale
self.shear = shear
self.interpolation = interpolation
self.fill = fill
self.center = center
def forward(self, img):
"""
img (PIL Image or Tensor): Image to be transformed.
Returns:
PIL Image or Tensor: Affine transformed image.
"""
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * transforms.functional.get_image_num_channels(img)
else:
fill = [float(f) for f in fill]
img_size = transforms.functional.get_image_size(img)
ret = transforms.RandomAffine.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
transformed_image = transforms.functional.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)
affine_matrix = self.get_affine_matrix_from_params(ret)
return transformed_image, affine_matrix
def get_affine_matrix_from_params(self, params):
degrees, translate, scale, shear = params
degrees = torch.tensor(degrees)
shear = torch.tensor(shear)
        # Convert each sampled parameter into a 3x3 homogeneous transformation matrix
rotation_matrix = torch.tensor([[torch.cos(torch.deg2rad(degrees)), -torch.sin(torch.deg2rad(degrees)), 0],
[torch.sin(torch.deg2rad(degrees)), torch.cos(torch.deg2rad(degrees)), 0],
[0, 0, 1]])
translation_matrix = torch.tensor([[1, 0, translate[0]],
[0, 1, translate[1]],
[0, 0, 1]]).to(torch.float32)
scaling_matrix = torch.tensor([[scale, 0, 0],
[0, scale, 0],
[0, 0, 1]])
shearing_matrix = torch.tensor([[1, -torch.tan(torch.deg2rad(shear[0])), 0],
[-torch.tan(torch.deg2rad(shear[1])), 1, 0],
[0, 0, 1]])
        # Compose the individual matrices into a single affine matrix
affine_matrix = translation_matrix.mm(rotation_matrix).mm(scaling_matrix).mm(shearing_matrix)
return affine_matrix
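# Illustrative usage sketch (not part of the original module): one way RandomAffineAndRetMat
# might be applied to a tensor image. The image size and parameter ranges below are assumptions.
def _example_random_affine_usage():
    img = torch.rand(3, 128, 128)  # dummy RGB image in [0, 1]
    transform = RandomAffineAndRetMat(
        degrees=[-30.0, 30.0],
        translate=[0.1, 0.1],
        scale=[0.9, 1.1],
        shear=[-5.0, 5.0, -5.0, 5.0],
    )
    transformed_img, affine_matrix = transform(img)
    # transformed_img keeps the input shape; affine_matrix is the 3x3 homogeneous matrix
    return transformed_img.shape, affine_matrix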
class GetTransformedCoords(nn.Module):
def __init__(self, affine_matrix, center):
super().__init__()
self.affine_matrix = affine_matrix
self.center = center
def forward(self, _coords):
# coords: like tensor([[43, 26], [44, 27], [45, 28]])
center_x, center_y = self.center
        # Shift the original coordinates so that the center becomes the origin
        coords = _coords.clone().to(torch.float32)  # work in float so the matrix multiply below is well-defined
coords[:, 0] -= center_x
coords[:, 1] -= center_y
        # Apply the affine transform to every coordinate in homogeneous form
homogeneous_coordinates = torch.cat([coords, torch.ones(coords.shape[0], 1, dtype=torch.float32, device=coords.device)], dim=1)
transformed_coordinates = torch.bmm(self.affine_matrix, homogeneous_coordinates.unsqueeze(-1)).squeeze(-1)
        # Clamp to the image bounds (currently disabled)
# transformed_x = max(0, min(width - 1, transformed_coordinates[:, 0]))
# transformed_y = max(0, min(height - 1, transformed_coordinates[:, 1]))
transformed_x = transformed_coordinates[:, 0]
transformed_y = transformed_coordinates[:, 1]
transformed_x += center_x
transformed_y += center_y
return torch.stack([transformed_x, transformed_y]).t().to(torch.long)
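# Illustrative sketch (assumption): GetTransformedCoords appears to expect one 3x3 affine matrix
# per coordinate (torch.bmm over the first dimension), so a single matrix is repeated here.
def _example_transform_coords():
    coords = torch.tensor([[43, 26], [44, 27], [45, 28]])
    affine = torch.eye(3).unsqueeze(0).repeat(coords.size(0), 1, 1)  # identity transform per point
    transformer = GetTransformedCoords(affine, center=(64, 64))
    # With the identity matrix the coordinates come back unchanged (as long integers)
    return transformer(coords)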
# Version of pairwise_distance that skips the square root
def pairwise_distance_squared(a, b):
return torch.sum((a - b) ** 2, dim=-1)
def cosine_similarity(a, b):
    # Compute the dot product of vectors a and b
dot_product = torch.matmul(a, b)
    # Compute the norms (magnitudes) of a and b
norm_a = torch.sqrt(torch.sum(a ** 2, dim=-1))
norm_b = torch.sqrt(torch.sum(b ** 2, dim=-1))
    # Cosine similarity: the dot product divided by the product of the norms
return dot_product / (norm_a * norm_b)
def batch_cosine_similarity(a, b):
    # Compute the element-wise dot product of a and b over the last dimension
dot_product = torch.einsum('bnd,bnd->bn', a, b)
    # Compute the norms (magnitudes) of a and b
norm_a = torch.sqrt(torch.sum(a ** 2, dim=-1))
norm_b = torch.sqrt(torch.sum(b ** 2, dim=-1))
    # Cosine similarity: the dot product divided by the product of the norms
return dot_product / (norm_a * norm_b)
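# Illustrative sketch: batch_cosine_similarity operates on (batch, n, dim) tensors and should
# agree with F.cosine_similarity along the last dimension. The shapes below are assumptions.
def _example_batch_cosine_similarity():
    a = torch.randn(2, 5, 32)
    b = torch.randn(2, 5, 32)
    ours = batch_cosine_similarity(a, b)           # shape (2, 5)
    reference = F.cosine_similarity(a, b, dim=-1)  # shape (2, 5)
    return torch.allclose(ours, reference, atol=1e-6)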
class TripletLossBatch(nn.Module):
def __init__(self):
super(TripletLossBatch, self).__init__()
def forward(self, anchor, positive, negative, margin=1.0):
distance_positive = F.pairwise_distance(anchor, positive, p=2)
distance_negative = F.pairwise_distance(anchor, negative, p=2)
losses = torch.relu(distance_positive - distance_negative + margin)
return losses.mean()
class TripletLossCosineSimilarity(nn.Module):
def __init__(self):
super(TripletLossCosineSimilarity, self).__init__()
def forward(self, anchor, positive, negative, margin=1.0):
distance_positive = 1 - batch_cosine_similarity(anchor, positive)
distance_negative = 1 - batch_cosine_similarity(anchor, negative)
losses = torch.relu(distance_positive - distance_negative + margin)
return losses.mean()
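# Illustrative sketch: both triplet losses take batches of embedding sets shaped (batch, n, dim).
# The sizes and margins below are assumptions.
def _example_triplet_losses():
    anchor = torch.randn(4, 10, 32)
    positive = torch.randn(4, 10, 32)
    negative = torch.randn(4, 10, 32)
    l2_loss = TripletLossBatch()(anchor, positive, negative, margin=1.0)
    cos_loss = TripletLossCosineSimilarity()(anchor, positive, negative, margin=0.2)
    return l2_loss, cos_loss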
def imsave(img):
img = torchvision.utils.make_grid(img)
img = img / 2 + 0.5
npimg = img.detach().cpu().numpy()
# plt.imshow(np.transpose(npimg, (1, 2, 0)))
# plt.show()
# save image
npimg = np.transpose(npimg, (1, 2, 0))
npimg = npimg * 255
npimg = npimg.astype(np.uint8)
Image.fromarray(npimg).save('sample.png')
def norm_img(img):
return (img-img.min())/(img.max()-img.min())
def norm_img2(img):
return (img-img.min())/(img.max()-img.min())*255
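# Illustrative sketch: norm_img rescales a tensor to [0, 1] and norm_img2 to [0, 255].
# The random input below is an assumption.
def _example_norm_img():
    x = torch.randn(3, 8, 8)
    return norm_img(x).min().item(), norm_img(x).max().item(), norm_img2(x).max().item()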
class DistanceMapLogger:
def __call__(self, img, feature_map, save_path, x_coords=None, y_coords=None):
device = feature_map.device
batch_size = feature_map.size(0)
feature_dim = feature_map.size(1)
image_size = feature_map.size(2)
if x_coords is None:
x_coords = [69]*batch_size
if y_coords is None:
y_coords = [42]*batch_size
        # Extract a 3-channel map with PCA for visualization
pca = PCA(n_components=3)
        pca_result = pca.fit_transform(feature_map.permute(0,2,3,1).reshape(-1,feature_dim).detach().cpu().numpy())  # run PCA over all spatial positions
        reshaped_pca_result = pca_result.reshape(batch_size,image_size,image_size,3)  # reshape the flat PCA output back to (batch, H, W, 3)
sample_num = 0
        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]  # feature vector at (x, y) for each sample
vector = vectors[sample_num]
        # Compute, for every spatial position of every feature map in the batch, the distance to the selected vector
        # Permute feature_map dimensions and flatten height and width
reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
# batch_distance_map = F.cosine_similarity(reshaped_feature_map, vector.unsqueeze(0).unsqueeze(0).expand(65,size*size,32), dim=2).permute(1, 0).reshape(feature_map.size(0), size, size)
norm_batch_distance_map = 1/torch.cosh( 20*(batch_distance_map-batch_distance_map.min())/(batch_distance_map.max()-batch_distance_map.min()) )**2
# norm_batch_distance_map[:,0,0] = 0.001
        # Visualization and saving
fig, axes = plt.subplots(5, 4, figsize=(20, 25))
for ax in axes.flatten():
ax.axis('off')
        # Remove spacing between subplots
plt.subplots_adjust(wspace=0, hspace=0)
        # Remove the outer margins as well
plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
        # Visualize the distance maps
for i in range(5):
axes[i, 0].imshow(norm_batch_distance_map[i].detach().cpu(), cmap='hot')
if i == sample_num:
axes[i, 0].scatter(x_coords[i], y_coords[i], c='b', s=7)
distance_map = torch.cat(((norm_batch_distance_map[i]/norm_batch_distance_map[i].max()).unsqueeze(0),torch.zeros(2,image_size,image_size,device=device)))
alpha = 0.9 # Transparency factor for the heatmap overlay
blended_tensor = (1 - alpha) * img[i] + alpha * distance_map
axes[i, 1].imshow(norm_img(blended_tensor.permute(1,2,0).detach().cpu()))
axes[i, 2].imshow(norm_img(img[i].permute(1,2,0).detach().cpu()))
axes[i, 3].imshow(norm_img(reshaped_pca_result[i]))
plt.savefig(save_path)
def get_heatmaps(self, img, feature_map, source_num=0, target_num=1, x_coords=69, y_coords=42):
device = feature_map.device
batch_size = feature_map.size(0)
feature_dim = feature_map.size(1)
image_size = feature_map.size(2)
x_coords = [x_coords]*batch_size
y_coords = [y_coords]*batch_size
        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]  # feature vector at (x, y) for each sample
vector = vectors[source_num]
        # Compute, for every spatial position of every feature map in the batch, the distance to the selected vector
        # Permute feature_map dimensions and flatten height and width
reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
# batch_distance_map = F.cosine_similarity(reshaped_feature_map, vector.unsqueeze(0).unsqueeze(0).expand(65,size*size,32), dim=2).permute(1, 0).reshape(feature_map.size(0), size, size)
norm_batch_distance_map = 1/torch.cosh( 20*(batch_distance_map-batch_distance_map.min())/(batch_distance_map.max()-batch_distance_map.min()) )**2
# norm_batch_distance_map[:,0,0] = 0.001
source_map = norm_batch_distance_map[source_num]
target_map = norm_batch_distance_map[target_num]
alpha = 0.9
blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num]/norm_batch_distance_map[source_num].max()).unsqueeze(0),torch.zeros(2,image_size,image_size,device=device)))
blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num]/norm_batch_distance_map[target_num].max()).unsqueeze(0),torch.zeros(2,image_size,image_size,device=device)))
        return source_map, target_map, blended_source, blended_target
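# Illustrative sketch: DistanceMapLogger expects images (B, 3, S, S) and feature maps (B, D, S, S)
# with B >= 5 and S larger than the default query coordinates (69, 42). The tensor sizes and the
# output file name below are assumptions.
def _example_distance_map_logger():
    img = torch.rand(5, 3, 128, 128)
    feature_map = torch.rand(5, 32, 128, 128)
    logger = DistanceMapLogger()
    logger(img, feature_map, save_path='distance_map.png')  # writes a 5x4 grid of panels
    source_map, target_map, blended_source, blended_target = logger.get_heatmaps(img, feature_map)
    return source_map.shape, blended_source.shape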