import torch
from torch import Tensor, nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
class RandomAffineAndRetMat(torch.nn.Module):
    """Random affine transform that also returns the 3x3 affine matrix that was applied."""

    def __init__(
        self,
        degrees,
        translate=None,
        scale=None,
        shear=None,
        interpolation=torchvision.transforms.InterpolationMode.NEAREST,
        fill=0,
        center=None,
    ):
        super().__init__()
        self.degrees = degrees
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.interpolation = interpolation
        self.fill = fill
        self.center = center
    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be transformed.
        Returns:
            tuple: (affine-transformed PIL Image or Tensor, 3x3 affine matrix as a Tensor).
        """
        fill = self.fill
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * transforms.functional.get_image_num_channels(img)
            else:
                fill = [float(f) for f in fill]

        img_size = transforms.functional.get_image_size(img)
        ret = transforms.RandomAffine.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
        transformed_image = transforms.functional.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)
        affine_matrix = self.get_affine_matrix_from_params(ret)

        return transformed_image, affine_matrix
    def get_affine_matrix_from_params(self, params):
        degrees, translate, scale, shear = params
        degrees = torch.tensor(degrees)
        shear = torch.tensor(shear)

        # Convert the sampled parameters into individual 3x3 transformation matrices.
        rotation_matrix = torch.tensor([[torch.cos(torch.deg2rad(degrees)), -torch.sin(torch.deg2rad(degrees)), 0],
                                        [torch.sin(torch.deg2rad(degrees)), torch.cos(torch.deg2rad(degrees)), 0],
                                        [0, 0, 1]])
        translation_matrix = torch.tensor([[1, 0, translate[0]],
                                           [0, 1, translate[1]],
                                           [0, 0, 1]]).to(torch.float32)
        scaling_matrix = torch.tensor([[scale, 0, 0],
                                       [0, scale, 0],
                                       [0, 0, 1]])
        shearing_matrix = torch.tensor([[1, -torch.tan(torch.deg2rad(shear[0])), 0],
                                        [-torch.tan(torch.deg2rad(shear[1])), 1, 0],
                                        [0, 0, 1]])

        # Compose the individual matrices into a single affine matrix.
        affine_matrix = translation_matrix.mm(rotation_matrix).mm(scaling_matrix).mm(shearing_matrix)

        return affine_matrix
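
# Illustrative sketch (not part of the original module): one possible way to apply
# RandomAffineAndRetMat to an image tensor and inspect the returned matrix. The parameter
# values and image size below are arbitrary examples, not values used elsewhere in this repo.
def _example_random_affine():
    transform = RandomAffineAndRetMat(
        degrees=(-30, 30), translate=(0.1, 0.1), scale=(0.8, 1.2), shear=(-10, 10, -10, 10)
    )
    img = torch.rand(3, 128, 128)  # dummy CHW image in [0, 1]
    transformed, affine_matrix = transform(img)
    print(transformed.shape, affine_matrix.shape)  # torch.Size([3, 128, 128]) torch.Size([3, 3])
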
class GetTransformedCoords(nn.Module):
    def __init__(self, affine_matrix, center):
        super().__init__()
        self.affine_matrix = affine_matrix
        self.center = center

    def forward(self, _coords):
        # _coords: pixel coordinates, e.g. tensor([[43, 26], [44, 27], [45, 28]])
        center_x, center_y = self.center

        # Shift the original coordinates so that the center becomes the origin.
        coords = _coords.clone()
        coords[:, 0] -= center_x
        coords[:, 1] -= center_y

        # Apply the affine transform to every coordinate (one 3x3 matrix per coordinate).
        homogeneous_coordinates = torch.cat([coords, torch.ones(coords.shape[0], 1, dtype=torch.float32, device=coords.device)], dim=1)
        transformed_coordinates = torch.bmm(self.affine_matrix, homogeneous_coordinates.unsqueeze(-1)).squeeze(-1)

        # Optionally clamp to the image bounds:
        # transformed_x = max(0, min(width - 1, transformed_coordinates[:, 0]))
        # transformed_y = max(0, min(height - 1, transformed_coordinates[:, 1]))
        transformed_x = transformed_coordinates[:, 0]
        transformed_y = transformed_coordinates[:, 1]

        # Shift back to the original coordinate frame.
        transformed_x += center_x
        transformed_y += center_y

        return torch.stack([transformed_x, transformed_y]).t().to(torch.long)
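
# Illustrative sketch (not part of the original module): mapping keypoint coordinates through
# the matrix returned by RandomAffineAndRetMat. Since forward() uses torch.bmm, the matrix is
# expanded to one 3x3 matrix per coordinate here; coordinates, center, and image size are
# arbitrary examples.
def _example_transform_coords():
    transform = RandomAffineAndRetMat(degrees=(-30, 30), translate=(0.1, 0.1))
    img = torch.rand(3, 128, 128)
    transformed, affine_matrix = transform(img)

    coords = torch.tensor([[43, 26], [44, 27], [45, 28]], dtype=torch.float32)
    matrices = affine_matrix.unsqueeze(0).expand(coords.size(0), 3, 3)  # (N, 3, 3)
    new_coords = GetTransformedCoords(matrices, center=(64, 64))(coords)
    print(new_coords)  # (N, 2) tensor of transformed pixel coordinates
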
# Version of pairwise_distance that skips the square root.
def pairwise_distance_squared(a, b):
    return torch.sum((a - b) ** 2, dim=-1)
def cosine_similarity(a, b):
    # Dot product of vectors a and b.
    dot_product = torch.matmul(a, b)
    # Norms (magnitudes) of a and b.
    norm_a = torch.sqrt(torch.sum(a ** 2, dim=-1))
    norm_b = torch.sqrt(torch.sum(b ** 2, dim=-1))
    # Cosine similarity: dot product divided by the product of the norms.
    return dot_product / (norm_a * norm_b)
def batch_cosine_similarity(a, b):
    # Dot product over the feature dimension for each (batch, n) pair.
    dot_product = torch.einsum('bnd,bnd->bn', a, b)
    # Norms (magnitudes) of a and b.
    norm_a = torch.sqrt(torch.sum(a ** 2, dim=-1))
    norm_b = torch.sqrt(torch.sum(b ** 2, dim=-1))
    # Cosine similarity: dot product divided by the product of the norms.
    return dot_product / (norm_a * norm_b)
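
# Illustrative sketch (not part of the original module): how the distance/similarity helpers
# above are called. Shapes are arbitrary examples; batch_cosine_similarity expects
# (batch, n, dim) inputs and returns a (batch, n) similarity map.
def _example_similarity_helpers():
    a = torch.randn(8)
    b = torch.randn(8)
    print(pairwise_distance_squared(a, b))  # scalar: squared L2 distance
    print(cosine_similarity(a, b))          # scalar: cosine similarity of two 1-D vectors

    batch_a = torch.randn(4, 16, 8)
    batch_b = torch.randn(4, 16, 8)
    print(batch_cosine_similarity(batch_a, batch_b).shape)  # torch.Size([4, 16])
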
class TripletLossBatch(nn.Module):
    def __init__(self):
        super(TripletLossBatch, self).__init__()

    def forward(self, anchor, positive, negative, margin=1.0):
        distance_positive = F.pairwise_distance(anchor, positive, p=2)
        distance_negative = F.pairwise_distance(anchor, negative, p=2)
        losses = torch.relu(distance_positive - distance_negative + margin)
        return losses.mean()


class TripletLossCosineSimilarity(nn.Module):
    def __init__(self):
        super(TripletLossCosineSimilarity, self).__init__()

    def forward(self, anchor, positive, negative, margin=1.0):
        distance_positive = 1 - batch_cosine_similarity(anchor, positive)
        distance_negative = 1 - batch_cosine_similarity(anchor, negative)
        losses = torch.relu(distance_positive - distance_negative + margin)
        return losses.mean()
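
# Illustrative sketch (not part of the original module): calling the two triplet losses.
# TripletLossBatch compares (batch, dim) embeddings with an L2 margin, while
# TripletLossCosineSimilarity expects (batch, n, dim) inputs because it relies on
# batch_cosine_similarity. All shapes and margins below are arbitrary examples.
def _example_triplet_losses():
    anchor, positive, negative = torch.randn(3, 16, 32).unbind(0)
    loss_l2 = TripletLossBatch()(anchor, positive, negative, margin=1.0)

    a, p, n = torch.randn(3, 4, 16, 32).unbind(0)
    loss_cos = TripletLossCosineSimilarity()(a, p, n, margin=0.5)
    print(loss_l2.item(), loss_cos.item())
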
def imsave(img):
    img = torchvision.utils.make_grid(img)
    img = img / 2 + 0.5  # undo normalization, assuming the input was roughly in [-1, 1]
    npimg = img.detach().cpu().numpy()
    # plt.imshow(np.transpose(npimg, (1, 2, 0)))
    # plt.show()
    # Save the grid as an 8-bit PNG.
    npimg = np.transpose(npimg, (1, 2, 0))
    npimg = npimg * 255
    npimg = npimg.astype(np.uint8)
    Image.fromarray(npimg).save('sample.png')


def norm_img(img):
    # Rescale to [0, 1].
    return (img - img.min()) / (img.max() - img.min())


def norm_img2(img):
    # Rescale to [0, 255].
    return (img - img.min()) / (img.max() - img.min()) * 255
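
# Illustrative sketch (not part of the original module): norm_img rescales any tensor to
# [0, 1] and norm_img2 to [0, 255]; imsave writes a grid of images to 'sample.png',
# assuming the batch is normalized to roughly [-1, 1]. The input below is a dummy batch.
def _example_image_utils():
    batch = torch.randn(4, 3, 64, 64).clamp(-1, 1)
    print(norm_img(batch).min().item(), norm_img(batch).max().item())  # 0.0 1.0
    imsave(batch)  # writes sample.png in the working directory
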
class DistanceMapLogger:
    def __call__(self, img, feature_map, save_path, x_coords=None, y_coords=None):
        device = feature_map.device
        batch_size = feature_map.size(0)
        feature_dim = feature_map.size(1)
        image_size = feature_map.size(2)
        if x_coords is None:
            x_coords = [69] * batch_size
        if y_coords is None:
            y_coords = [42] * batch_size

        # Project the feature maps to 3 channels with PCA for visualization.
        pca = PCA(n_components=3)
        pca_result = pca.fit_transform(feature_map.permute(0, 2, 3, 1).reshape(-1, feature_dim).detach().cpu().numpy())  # run PCA on the flattened pixels
        reshaped_pca_result = pca_result.reshape(batch_size, image_size, image_size, 3)  # reshape back to (batch, H, W, 3)

        sample_num = 0
        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]  # feature vector at (y, x) for each sample
        vector = vectors[sample_num]

        # Compute distances between the query vector and every location of each feature map.
        # Permute feature_map and flatten height/width into a single dimension.
        reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
        batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
        # batch_distance_map = F.cosine_similarity(reshaped_feature_map, vector.unsqueeze(0).unsqueeze(0).expand(65,size*size,32), dim=2).permute(1, 0).reshape(feature_map.size(0), size, size)
        norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2
        # norm_batch_distance_map[:,0,0] = 0.001

        # Visualization and saving.
        fig, axes = plt.subplots(5, 4, figsize=(20, 25))
        for ax in axes.flatten():
            ax.axis('off')
        # Remove spacing between subplots.
        plt.subplots_adjust(wspace=0, hspace=0)
        # Remove the outer margins as well.
        plt.subplots_adjust(left=0, right=1, bottom=0, top=1)

        # Visualize the distance maps.
        for i in range(5):
            axes[i, 0].imshow(norm_batch_distance_map[i].detach().cpu(), cmap='hot')
            if i == sample_num:
                axes[i, 0].scatter(x_coords[i], y_coords[i], c='b', s=7)
            distance_map = torch.cat(((norm_batch_distance_map[i] / norm_batch_distance_map[i].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
            alpha = 0.9  # Transparency factor for the heatmap overlay
            blended_tensor = (1 - alpha) * img[i] + alpha * distance_map
            axes[i, 1].imshow(norm_img(blended_tensor.permute(1, 2, 0).detach().cpu()))
            axes[i, 2].imshow(norm_img(img[i].permute(1, 2, 0).detach().cpu()))
            axes[i, 3].imshow(norm_img(reshaped_pca_result[i]))
        plt.savefig(save_path)
    def get_heatmaps(self, img, feature_map, source_num=0, target_num=1, x_coords=69, y_coords=42):
        device = feature_map.device
        batch_size = feature_map.size(0)
        feature_dim = feature_map.size(1)
        image_size = feature_map.size(2)
        x_coords = [x_coords] * batch_size
        y_coords = [y_coords] * batch_size

        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]  # feature vector at (y, x) for each sample
        vector = vectors[source_num]

        # Compute distances between the query vector and every location of each feature map.
        # Permute feature_map and flatten height/width into a single dimension.
        reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
        batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
        # batch_distance_map = F.cosine_similarity(reshaped_feature_map, vector.unsqueeze(0).unsqueeze(0).expand(65,size*size,32), dim=2).permute(1, 0).reshape(feature_map.size(0), size, size)
        norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2
        # norm_batch_distance_map[:,0,0] = 0.001

        source_map = norm_batch_distance_map[source_num]
        target_map = norm_batch_distance_map[target_num]

        # Overlay each heatmap on its image (heatmap goes into the red channel).
        alpha = 0.9
        blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num] / norm_batch_distance_map[source_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
        blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num] / norm_batch_distance_map[target_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
        return source_map, target_map, blended_source, blended_target
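
# Illustrative sketch (not part of the original module): how DistanceMapLogger might be driven.
# The feature map is a dummy (batch, channels, H, W) tensor; __call__ visualizes the first five
# samples, so the batch size must be at least 5, and get_heatmaps compares a source and a target
# sample at a query pixel. The file name, sizes, and coordinates are arbitrary examples.
def _example_distance_map_logger():
    logger = DistanceMapLogger()
    img = torch.rand(5, 3, 96, 96)
    feature_map = torch.randn(5, 32, 96, 96)
    logger(img, feature_map, "distance_maps.png", x_coords=[48] * 5, y_coords=[48] * 5)
    source_map, target_map, blended_source, blended_target = logger.get_heatmaps(
        img, feature_map, source_num=0, target_num=1, x_coords=48, y_coords=48
    )
    print(source_map.shape, blended_source.shape)  # torch.Size([96, 96]) torch.Size([3, 96, 96])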