import torch
from torch import Tensor, nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

class RandomAffineAndRetMat(torch.nn.Module):
  def __init__(
      self,
      degrees,
      translate=None,
      scale=None,
      shear=None,
      interpolation=torchvision.transforms.InterpolationMode.NEAREST,
      fill=0,
      center=None,
  ):
    super().__init__()
    self.degrees = degrees
    self.translate = translate
    self.scale = scale
    self.shear = shear
    self.interpolation = interpolation
    self.fill = fill
    self.center = center

  def forward(self, img):
    """
        img (PIL Image or Tensor): Image to be transformed.

    Returns:
        PIL Image or Tensor: Affine transformed image.
    """
    fill = self.fill
    if isinstance(img, Tensor):
        if isinstance(fill, (int, float)):
            fill = [float(fill)] * transforms.functional.get_image_num_channels(img)
        else:
            fill = [float(f) for f in fill]

    img_size = transforms.functional.get_image_size(img)

    ret = transforms.RandomAffine.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
    transformed_image = transforms.functional.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)

    affine_matrix = self.get_affine_matrix_from_params(ret)

    return transformed_image, affine_matrix

  def get_affine_matrix_from_params(self, params):
    degrees, translate, scale, shear = params
    degrees = torch.tensor(degrees)
    shear = torch.tensor(shear)

    # Convert each sampled parameter into its own 3x3 transformation matrix
    rotation_matrix = torch.tensor([[torch.cos(torch.deg2rad(degrees)), -torch.sin(torch.deg2rad(degrees)), 0],
                                    [torch.sin(torch.deg2rad(degrees)), torch.cos(torch.deg2rad(degrees)), 0],
                                    [0, 0, 1]])

    translation_matrix = torch.tensor([[1, 0, translate[0]],
                                       [0, 1, translate[1]],
                                       [0, 0, 1]]).to(torch.float32)

    scaling_matrix = torch.tensor([[scale, 0, 0],
                                   [0, scale, 0],
                                   [0, 0, 1]])

    shearing_matrix = torch.tensor([[1, -torch.tan(torch.deg2rad(shear[0])), 0],
                                    [-torch.tan(torch.deg2rad(shear[1])), 1, 0],
                                    [0, 0, 1]])

    # Compose the matrices into a single affine transform
    affine_matrix = translation_matrix.mm(rotation_matrix).mm(scaling_matrix).mm(shearing_matrix)

    return affine_matrix
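
# --- Usage sketch for RandomAffineAndRetMat (illustrative only: the input
# shape and parameter ranges below are assumptions, not part of the module) ---
def _demo_random_affine():
  img = torch.rand(3, 224, 224)  # a random CHW tensor stands in for a real image
  transform = RandomAffineAndRetMat(
      degrees=(-30, 30), translate=(0.1, 0.1), scale=(0.8, 1.2), shear=(-10, 10))
  transformed, affine_matrix = transform(img)
  # transformed has the same shape as img; affine_matrix is the 3x3 matrix
  print(transformed.shape, affine_matrix)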

def norm_img(img):
  return (img-img.min())/(img.max()-img.min())

def preprocess_uploaded_image(uploaded_image, image_size):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Convert an ndarray input to a PIL image
    if isinstance(uploaded_image, np.ndarray):
        uploaded_image = Image.fromarray(uploaded_image)
    uploaded_image = uploaded_image.convert("RGB")
    uploaded_image = uploaded_image.resize((image_size, image_size))
    uploaded_image = np.array(uploaded_image).transpose(2, 0, 1) / 255.0
    uploaded_image = torch.tensor(uploaded_image, dtype=torch.float32).unsqueeze(0).to(device)
    return uploaded_image
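
# --- Usage sketch for preprocess_uploaded_image (illustrative; the random
# uint8 array below stands in for a user-uploaded picture) ---
def _demo_preprocess():
  fake_upload = (np.random.rand(300, 400, 3) * 255).astype(np.uint8)
  batch = preprocess_uploaded_image(fake_upload, image_size=224)
  print(batch.shape)  # torch.Size([1, 3, 224, 224])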

def get_heatmaps(img, feature_map, source_num, x_coords, y_coords, uploaded_image):
    # Note: uploaded_image is accepted but not used in this function
    device = feature_map.device  # keep everything on the feature map's device
    image_size = img.size(2)  # the feature map is assumed to share this spatial size
    
    batch_size = feature_map.size(0)
    feature_dim = feature_map.size(1)
    
    # The last image in the batch is treated as the target
    target_num = batch_size - 1

    # Query the same (x, y) location in every image of the batch
    x_coords = [x_coords] * batch_size
    y_coords = [y_coords] * batch_size

    vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]
    vector = vectors[source_num]  # feature vector at the query point of the source image

    # Distance from every spatial location to the query vector, squashed with a
    # sech^2 kernel so that small distances map to values near 1
    reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
    batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)

    norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2

    source_map = norm_batch_distance_map[source_num].detach().cpu()
    target_map = norm_batch_distance_map[target_num].detach().cpu()

    alpha = 0.7
    # Overlay each normalized heatmap on its image as the red channel
    blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num] / norm_batch_distance_map[source_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
    blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num] / norm_batch_distance_map[target_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
    
    blended_source = blended_source.detach().cpu()
    blended_target = blended_target.detach().cpu()
    
    return source_map, target_map, blended_source, blended_target
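
# --- Usage sketch for get_heatmaps (illustrative: random tensors stand in for
# real images and for an encoder whose output keeps the input resolution; the
# batch size, feature_dim and query point are assumptions) ---
def _demo_get_heatmaps():
  imgs = torch.rand(4, 3, 112, 112)
  feature_map = torch.rand(4, 64, 112, 112)  # (batch, feature_dim, H, W) with H == W == image size
  source_map, target_map, blended_source, blended_target = get_heatmaps(
      imgs, feature_map, source_num=0, x_coords=56, y_coords=56, uploaded_image=None)
  print(source_map.shape, blended_source.shape)  # (112, 112) and (3, 112, 112)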

def get_mean_vector(feature_map, points):
  # points: (batch, num_keypoints, 2) tensor of (x, y) keypoint coordinates
  keypoints_size = points.size(1)

  mean_vector_list = []
  for i in range(keypoints_size):
      x_coords, y_coords = torch.round(points[:,i].t()).to(torch.long)
      vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]  # one feature vector per image at this keypoint
      # mean_vector = vectors[0:10].mean(0)  # mean vector over the first 10 feature maps
      mean_vector = vectors.mean(0).detach().cpu().numpy()
      mean_vector_list.append(mean_vector)
  return mean_vector_list
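
# --- Usage sketch for get_mean_vector (illustrative; the random coordinates
# below stand in for real keypoint annotations) ---
def _demo_mean_vectors():
  feature_map = torch.rand(8, 64, 112, 112)
  points = torch.rand(8, 5, 2) * 111  # (batch, num_keypoints, 2) in (x, y) order
  mean_vectors = get_mean_vector(feature_map, points)
  print(len(mean_vectors), mean_vectors[0].shape)  # 5 keypoints, each a (64,) vector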

def get_keypoint_heatmaps(feature_map, mean_vector_list, keypoints_size, imgs):
  # Accept an unbatched (C, H, W) feature map as well
  if len(feature_map.size()) == 3:
    feature_map = feature_map.unsqueeze(0)
  device = feature_map.device
  batch_size = feature_map.size(0)
  feature_dim = feature_map.size(1)
  size = feature_map.size(2)

  norm_batch_distance_map = torch.zeros(batch_size,size,size,device=device)
  for i in range(keypoints_size):
      # Mean vectors come back from get_mean_vector as numpy arrays, so move
      # them onto the feature map's device before computing distances
      vector = torch.as_tensor(mean_vector_list[i], device=device)
      reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)

      batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), size, size)
      batch_distance_map = 1 / torch.cosh(40 * (batch_distance_map - batch_distance_map.min())
                                               / (batch_distance_map.max() - batch_distance_map.min())) ** 2
      # Normalize each map by its per-image maximum
      m = batch_distance_map / batch_distance_map.max(1).values.max(1).values.unsqueeze(0).unsqueeze(0).repeat(size, size, 1).permute(2, 0, 1)
      norm_batch_distance_map += m
  # Clamp accumulated values above 1 back to 1 (equivalent to min(x, 1))
  norm_batch_distance_map = (-F.relu(-norm_batch_distance_map + 1) + 1)
  keypoint_maps = norm_batch_distance_map.detach().cpu()

  alpha = 0.8  # Transparency factor for the heatmap overlay
  blended_tensors = (1 - alpha) * imgs + alpha * torch.cat(
    (norm_batch_distance_map.unsqueeze(1), torch.zeros(batch_size,2,size,size,device=device)),
    dim=1
    )
  blended_tensors = norm_img(blended_tensors).detach().cpu()
  
  return keypoint_maps, blended_tensors
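
# --- Usage sketch tying get_mean_vector and get_keypoint_heatmaps together
# (illustrative; random tensors stand in for real features, images and
# keypoints) ---
def _demo_keypoint_heatmaps():
  feature_map = torch.rand(8, 64, 112, 112)
  imgs = torch.rand(8, 3, 112, 112)
  points = torch.rand(8, 5, 2) * 111
  mean_vectors = get_mean_vector(feature_map, points)
  keypoint_maps, blended = get_keypoint_heatmaps(
      feature_map, mean_vectors, keypoints_size=len(mean_vectors), imgs=imgs)
  print(keypoint_maps.shape, blended.shape)  # (8, 112, 112) and (8, 3, 112, 112)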