yeq6x committed
Commit 9d42859 · Parent: d7a7c6c

refactoring

Files changed (4)
  1. app.py +11 -16
  2. dataset.py +0 -21
  3. model_module.py +1 -135
  4. utils.py +0 -173
app.py CHANGED
@@ -6,7 +6,6 @@ from torch.utils.data import DataLoader
 import matplotlib.pyplot as plt
 from model_module import AutoencoderModule
 from dataset import MyDataset, load_filenames
-from utils import DistanceMapLogger
 import numpy as np
 from PIL import Image
 import base64
@@ -159,7 +158,6 @@ async () => {
    document.querySelector("#input_file").click();
 };
 
-// load the image from the Gradio File component
 document.querySelector("#input_file").addEventListener("change", function(e) {
    const files = e.target.files;
    console.log(files);
@@ -182,7 +180,6 @@ async () => {
    }
 });
 
-// attach the JavaScript behavior to the Gradio button
 document.getElementById("crop_button").onclick = function() {
    if (cropper) {
        const canvas = cropper.getCroppedCanvas();
@@ -227,13 +224,9 @@ with gr.Blocks() as demo:
     gr.HTML('<input type="file" id="input_file" style="display:none;">')
     input_file_button = gr.Button("Upload and Crop Image", elem_id="input_file_button", variant="primary")
     crop_button = gr.Button("Crop", elem_id="crop_button", variant="primary")
-    # show an HTML img tag in Gradio for displaying the image
     gr.HTML('<img id="crop_view" style="max-width:100%;">')
-    # add a Gradio button component and give it an ID
-    # textbox holding the cropped image data (Base64)
     cropped_image_data = gr.Textbox(visible=False, elem_id="cropped_image_data")
-    input_image = gr.Image(label="Cropped Image", elem_id="input_image")
-    # call process_image whenever cropped_image_data is updated
+    input_image = gr.Image(visible=False, label="Cropped Image", elem_id="input_image")
     cropped_image_data.change(process_image, inputs=cropped_image_data, outputs=input_image)
 
     # examples
@@ -248,15 +241,17 @@ with gr.Blocks() as demo:
     with gr.Column():
         output_plot = gr.Plot()
 
-    # in place of a Gradio Interface
-    source_num.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
-    x_coords.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
-    y_coords.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
-    input_image.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
-
-    # load the JavaScript code
-    demo.load(None, None, None, js=scripts)
-
+    # update the heatmaps
+    source_num.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
+    x_coords.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
+    y_coords.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
+
+    def change_input_image_handler(source_num, x_coords, y_coords, input_image):
+        visible = False if input_image is None else True
+        return get_heatmaps(source_num, x_coords, y_coords, input_image), gr.Image(visible=visible, label="Cropped Image", elem_id="input_image")
+    input_image.change(change_input_image_handler, inputs=[source_num, x_coords, y_coords, input_image], outputs=[output_plot, input_image])
+
+    # load the JavaScript code
+    demo.load(None, None, None, js=scripts)
 demo.launch()
 
 
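Note on the new handler: returning a component instance from a Gradio event handler updates the live component, which is how this commit hides input_image until a crop exists. A minimal self-contained sketch of the same pattern, assuming Gradio 4.x (the component and function names below are illustrative, not from this repo):

import gradio as gr

def on_image_change(img):
    # A handler may return a component to update it in place;
    # hide the image component again whenever its value is cleared.
    return gr.Image(visible=img is not None)

with gr.Blocks() as demo:
    image = gr.Image(visible=False)
    image.change(on_image_change, inputs=image, outputs=image)

demo.launch()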
 
dataset.py CHANGED
@@ -8,27 +8,6 @@ import os
 from utils import RandomAffineAndRetMat
 
 def load_filenames(data_dir):
-    # label_data = pd.read_json(INPUT_DIR+'DataList.json')
-    # label_data = label_data.sort_index()
-    # tmp_points = []
-    # filenames = []
-
-    # for o in tqdm(label_data.data[0:1000]):
-    #     filenames.append(o['filename'])
-    #     a = o['filename']
-
-    #     tmps = []
-    #     for i in range(60):
-    #         tmps.append(o['points'][str(i)]['x'])
-    #         tmps.append(o['points'][str(i)]['y'])
-    #     tmp_points.append(tmps) # datanum
-
-    # filenames = pd.Series(filenames)
-    # filenames = [str(i).zfill(4)+'.jpg' for i in range(3400)]
-    # df_points = pd.DataFrame(tmp_points)
-
-    # load from data_dir
-    # image extensions only
     img_exts = ['.jpg', '.jpeg', '.png', '.bmp', '.ppm', '.pgm', '.tif', '.tiff']
     filenames = [f for f in os.listdir(data_dir) if os.path.splitext(f)[1].lower() in img_exts]
 
 
 
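After this cleanup, load_filenames simply filters a directory listing by image extension. A usage sketch (the data/images path is hypothetical):

from dataset import load_filenames

# Returns bare filenames (not full paths) whose extension is in img_exts,
# e.g. ['0001.jpg', '0002.jpg', ...]
filenames = load_filenames("data/images")
print(len(filenames), "images found")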
 
model_module.py CHANGED
@@ -1,145 +1,11 @@
 import pytorch_lightning as pl
-import torch
-from torch import nn
-from torch.optim import SGD
-from torchvision.utils import save_image
-import os
-from utils import TripletLossBatch, pairwise_distance_squared, GetTransformedCoords, DistanceMapLogger
 from model import Autoencoder
 
 class AutoencoderModule(pl.LightningModule):
-    def __init__(self, feature_dim=64, learning_rate=0.1, lambda_c=0.97, initial_margin=1.0, initial_threshold=2.0, save_interval=100, output_dir="output_images"):
+    def __init__(self, feature_dim=64):
         super(AutoencoderModule, self).__init__()
         self.feature_dim = feature_dim
-        self.learning_rate = learning_rate
-        self.lambda_c = lambda_c
-        self.margin_img = initial_margin
-        self.margin_img_init = initial_margin
-        self.threshold = initial_threshold
         self.model = Autoencoder(self.feature_dim)
-        self.criterion = nn.MSELoss()
-        self.triplet_loss = TripletLossBatch()
-        self.losses = []
-        self.save_interval = save_interval  # output interval, in batches
-        self.output_dir = output_dir
-        os.makedirs(self.output_dir, exist_ok=True)
 
     def forward(self, x):
         return self.model(x)
-
-    def training_step(self, batch, batch_idx):
-        img, mat, _ = batch
-        batch_size, _, _, size, size = img.shape
-        img = img.view(batch_size*2, 3, size, size)
-        mat = mat.view(batch_size*2, 3, 3)
-
-        dec5_output, output = self.model(img)
-        mse_loss = self.criterion(output, img)
-
-        # intra-image processing
-        num_anchor_sets = 2**12
-        trip_loss = 0
-        std_list = [2.5*1.025**self.current_epoch, 5*1.025**self.current_epoch]
-        for c in std_list:
-            std = size / c
-            anchors = torch.randint(0, size, (batch_size*2, num_anchor_sets, 1, 2))
-            coords = anchors + torch.normal(0, std, (batch_size*2, num_anchor_sets, 2, 2)).long()
-            valid_coords_idx = (((coords >= 0) & (coords < size)).sum(3) == 2).sum(2) != 2
-            coords[valid_coords_idx] = 0
-            anchors[valid_coords_idx] = 0
-
-            # pick the nearest coordinate
-            d = pairwise_distance_squared(anchors.float(), coords.float())
-            idx = torch.argmin(d, dim=2)
-            anchors, positives, negatives = self._get_triplet_coordinates(anchors, coords, idx)
-
-            # extract feature vectors from dec5_output
-            anchor_vectors, positive_vectors, negative_vectors = self._extract_feature_vectors(dec5_output, batch_size, anchors, positives, negatives)
-            trip_loss += self.triplet_loss(anchor_vectors, positive_vectors, negative_vectors, self.margin_img)
-
-        trip_loss /= len(std_list)
-        self.margin_img = self.margin_img_init + self.margin_img - trip_loss.detach()
-
-        # learn the transformation
-        num_samples = 2**20
-        tf_loss = self._compute_transformation_loss(dec5_output, mat, batch_size, size, num_samples)
-
-        # batch-direction processing
-        bat_dist_loss = self._compute_batch_direction_loss(dec5_output, batch_size, size)
-
-        # total loss
-        loss = mse_loss + trip_loss + 0.001 * bat_dist_loss + (0.001 * 1.**self.current_epoch) * tf_loss
-        self.log("train_loss", loss)
-
-        # VRAM management
-        del img, output
-        torch.cuda.empty_cache()
-
-        return loss
-
-
-    def _get_triplet_coordinates(self, anchors, coords, idx):
-        anchors = anchors.squeeze(2)
-        positives = coords[torch.arange(coords.size(0))[:, None, None], torch.arange(coords.size(1))[None, :, None], idx[:, :, None]].squeeze(2)
-        negatives = coords[torch.arange(coords.size(0))[:, None, None], torch.arange(coords.size(1))[None, :, None], (1 - idx)[:, :, None]].squeeze(2)
-        return anchors, positives, negatives
-
-    def _extract_feature_vectors(self, dec5_output, batch_size, anchors, positives, negatives):
-        y_anchors = anchors[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
-        x_anchors = anchors[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
-        y_positives = positives[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
-        x_positives = positives[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
-        y_negatives = negatives[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
-        x_negatives = negatives[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
-
-        anchor_vectors = dec5_output[torch.arange(batch_size*2)[:, None, None], torch.arange(self.feature_dim), y_anchors, x_anchors]
-        positive_vectors = dec5_output[torch.arange(batch_size*2)[:, None, None], torch.arange(self.feature_dim), y_positives, x_positives]
-        negative_vectors = dec5_output[torch.arange(batch_size*2)[:, None, None], torch.arange(self.feature_dim), y_negatives, x_negatives]
-        return anchor_vectors, positive_vectors, negative_vectors
-
-    def _compute_transformation_loss(self, dec5_output, mat, batch_size, size, num_samples=2**12):
-        anchor_indices = torch.randint(batch_size, (num_samples, 1), device=self.device).repeat(1, 2).reshape(num_samples*2)
-        coords_x = torch.randint(0, size, (num_samples, 1), dtype=torch.float32, device=self.device).repeat(1, 2).reshape(num_samples*2, 1)
-        coords_y = torch.randint(0, size, (num_samples, 1), dtype=torch.float32, device=self.device).repeat(1, 2).reshape(num_samples*2, 1)
-        anchor_coords = torch.cat((coords_x, coords_y), 1)
-        anchor_mat = mat[anchor_indices]
-        tf_anchor_coords = GetTransformedCoords(anchor_mat, [size/2, size/2])(anchor_coords)
-
-        anchor_vectors = torch.zeros([num_samples*2, self.feature_dim], device=self.device)
-        inner_idx_flat = ((0 <= tf_anchor_coords[:,0]) & (tf_anchor_coords[:,0] < size)) & ((0 <= tf_anchor_coords[:,1]) & (tf_anchor_coords[:,1] < size))
-        anchor_vectors[inner_idx_flat] = dec5_output[anchor_indices[inner_idx_flat], :, tf_anchor_coords[inner_idx_flat, 0], tf_anchor_coords[inner_idx_flat, 1]]
-
-        inner_idx_and = inner_idx_flat.view(num_samples, 2).t()[0] & inner_idx_flat.view(num_samples, 2).t()[1]
-        anchor_vectors = anchor_vectors.view(num_samples, 2, self.feature_dim)[inner_idx_and]
-        return pairwise_distance_squared(anchor_vectors[:,0], anchor_vectors[:,1]).mean()
-
-    def _compute_batch_direction_loss(self, dec5_output, batch_size, size):
-        N = 2**12
-        anchor_indices = torch.randint(0, batch_size, (N,)) * 2 + torch.randint(0, 2, (N,))
-        anchor_coords = torch.randint(0, size, (N, 2))
-        other_indices = torch.randint(0, batch_size-1, (N, 2)) * 2 + torch.randint(0, 2, (N, 2))
-        other_indices += (other_indices >= anchor_indices.unsqueeze(1)).long() * 2
-        other_coords = torch.randint(0, size, (N, 2, 2))
-
-        anchor_vectors = dec5_output[anchor_indices, :, anchor_coords[:, 0], anchor_coords[:, 1]]
-        other_vectors = dec5_output[other_indices, :, other_coords[:, :, 0], other_coords[:, :, 1]]
-        distances = pairwise_distance_squared(anchor_vectors.unsqueeze(1), other_vectors)
-        return distances[distances < self.threshold].sum() / ((distances < self.threshold).sum() + 1e-10)
-
-    def configure_optimizers(self):
-        optimizer = SGD(self.parameters(), lr=self.learning_rate)
-        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: self.lambda_c**epoch)
-        return [optimizer], [scheduler]
-
-    # def save_intermediate_image(self, output, epoch):
-    #     save_image(output[:4], os.path.join(self.output_dir, f"epoch_{epoch}_output.png"), nrow=1)
-    #     print(f"Saved intermediate image at epoch {epoch}")
-
-    # def distance_map(self, _input, feature_map, epoch, x_coords=None, y_coords=None):
-    #     save_path = os.path.join(self.output_dir, f"epoch_{epoch}_distance_map.png")
-    #     DistanceMapLogger()(_input, feature_map, save_path, x_coords, y_coords)
-
-    def configure_optimizers(self):
-        optimizer = SGD(self.parameters(), lr=self.learning_rate)
-        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: self.lambda_c**epoch)
-        return [optimizer], [scheduler]
 
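With the training loop, losses, and optimizers removed, AutoencoderModule is now a thin inference wrapper around Autoencoder. A minimal sketch of loading trained weights and extracting the per-pixel feature map plus the reconstruction (the checkpoint path is hypothetical; the two-tensor return follows the removed training_step, which unpacked dec5_output, output = self.model(img)):

import torch
from model_module import AutoencoderModule

model = AutoencoderModule(feature_dim=64)
# Weights trained before the refactor; the path is illustrative only.
state = torch.load("checkpoints/autoencoder.pth", map_location="cpu")
model.load_state_dict(state, strict=False)
model.eval()

with torch.no_grad():
    img = torch.rand(1, 3, 112, 112)   # dummy RGB batch; the real size depends on the dataset
    dec5_output, output = model(img)   # feature map, reconstructed image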
utils.py CHANGED
@@ -77,176 +77,3 @@ class RandomAffineAndRetMat(torch.nn.Module):
         affine_matrix = translation_matrix.mm(rotation_matrix).mm(scaling_matrix).mm(shearing_matrix)
 
         return affine_matrix
-
-class GetTransformedCoords(nn.Module):
-    def __init__(self, affine_matrix, center):
-        super().__init__()
-        self.affine_matrix = affine_matrix
-        self.center = center
-
-    def forward(self, _coords):
-        # coords: like tensor([[43, 26], [44, 27], [45, 28]])
-        center_x, center_y = self.center
-        # shift the coordinates so the image center is the origin
-        coords = _coords.clone()
-        coords[:, 0] -= center_x
-        coords[:, 1] -= center_y
-
-        # apply the transform to each batch element
-        homogeneous_coordinates = torch.cat([coords, torch.ones(coords.shape[0], 1, dtype=torch.float32, device=coords.device)], dim=1)
-        transformed_coordinates = torch.bmm(self.affine_matrix, homogeneous_coordinates.unsqueeze(-1)).squeeze(-1)
-
-        # keep within the image bounds
-        # transformed_x = max(0, min(width - 1, transformed_coordinates[:, 0]))
-        # transformed_y = max(0, min(height - 1, transformed_coordinates[:, 1]))
-        transformed_x = transformed_coordinates[:, 0]
-        transformed_y = transformed_coordinates[:, 1]
-
-        transformed_x += center_x
-        transformed_y += center_y
-        return torch.stack([transformed_x, transformed_y]).t().to(torch.long)
-
-# pairwise_distance variant that skips the square root
-def pairwise_distance_squared(a, b):
-    return torch.sum((a - b) ** 2, dim=-1)
-
-def cosine_similarity(a, b):
-    # dot product of a and b
-    dot_product = torch.matmul(a, b)
-    # norms (magnitudes) of a and b
-    norm_a = torch.sqrt(torch.sum(a ** 2, dim=-1))
-    norm_b = torch.sqrt(torch.sum(b ** 2, dim=-1))
-    # cosine similarity: dot product divided by the product of the norms
-    return dot_product / (norm_a * norm_b)
-
-def batch_cosine_similarity(a, b):
-    # dot product of a and b
-    dot_product = torch.einsum('bnd,bnd->bn', a, b)
-    # norms (magnitudes) of a and b
-    norm_a = torch.sqrt(torch.sum(a ** 2, dim=-1))
-    norm_b = torch.sqrt(torch.sum(b ** 2, dim=-1))
-    # cosine similarity: dot product divided by the product of the norms
-    return dot_product / (norm_a * norm_b)
-
-class TripletLossBatch(nn.Module):
-    def __init__(self):
-        super(TripletLossBatch, self).__init__()
-
-    def forward(self, anchor, positive, negative, margin=1.0):
-        distance_positive = F.pairwise_distance(anchor, positive, p=2)
-        distance_negative = F.pairwise_distance(anchor, negative, p=2)
-        losses = torch.relu(distance_positive - distance_negative + margin)
-        return losses.mean()
-
-class TripletLossCosineSimilarity(nn.Module):
-    def __init__(self):
-        super(TripletLossCosineSimilarity, self).__init__()
-
-    def forward(self, anchor, positive, negative, margin=1.0):
-        distance_positive = 1 - batch_cosine_similarity(anchor, positive)
-        distance_negative = 1 - batch_cosine_similarity(anchor, negative)
-        losses = torch.relu(distance_positive - distance_negative + margin)
-        return losses.mean()
-
-def imsave(img):
-    img = torchvision.utils.make_grid(img)
-    img = img / 2 + 0.5
-    npimg = img.detach().cpu().numpy()
-    # plt.imshow(np.transpose(npimg, (1, 2, 0)))
-    # plt.show()
-    # save image
-    npimg = np.transpose(npimg, (1, 2, 0))
-    npimg = npimg * 255
-    npimg = npimg.astype(np.uint8)
-    Image.fromarray(npimg).save('sample.png')
-
-def norm_img(img):
-    return (img-img.min())/(img.max()-img.min())
-
-def norm_img2(img):
-    return (img-img.min())/(img.max()-img.min())*255
-
-class DistanceMapLogger:
-    def __call__(self, img, feature_map, save_path, x_coords=None, y_coords=None):
-        device = feature_map.device
-        batch_size = feature_map.size(0)
-        feature_dim = feature_map.size(1)
-        image_size = feature_map.size(2)
-
-        if x_coords is None:
-            x_coords = [69]*batch_size
-        if y_coords is None:
-            y_coords = [42]*batch_size
-
-        # extract a 3-channel map with PCA
-        pca = PCA(n_components=3)
-        pca_result = pca.fit_transform(feature_map.permute(0,2,3,1).reshape(-1,feature_dim).detach().cpu().numpy())  # run PCA
-        reshaped_pca_result = pca_result.reshape(batch_size,image_size,image_size,3)  # back to 3 channels (was flattened)
-
-        sample_num = 0
-        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]  # one feature vector per sample
-        vector = vectors[sample_num]
-
-        # compute distances against every feature map in the batch
-        # permute feature_map and flatten batch, height and width
-        reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
-        batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
-        # batch_distance_map = F.cosine_similarity(reshaped_feature_map, vector.unsqueeze(0).unsqueeze(0).expand(65,size*size,32), dim=2).permute(1, 0).reshape(feature_map.size(0), size, size)
-        norm_batch_distance_map = 1/torch.cosh( 20*(batch_distance_map-batch_distance_map.min())/(batch_distance_map.max()-batch_distance_map.min()) )**2
-        # norm_batch_distance_map[:,0,0] = 0.001
-        # visualize and save
-        fig, axes = plt.subplots(5, 4, figsize=(20, 25))
-        for ax in axes.flatten():
-            ax.axis('off')
-        # remove spacing between subplots
-        plt.subplots_adjust(wspace=0, hspace=0)
-        # remove the outer margins as well
-        plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
-
-        # visualize the distance maps
-        for i in range(5):
-            axes[i, 0].imshow(norm_batch_distance_map[i].detach().cpu(), cmap='hot')
-            if i == sample_num:
-                axes[i, 0].scatter(x_coords[i], y_coords[i], c='b', s=7)
-
-            distance_map = torch.cat(((norm_batch_distance_map[i]/norm_batch_distance_map[i].max()).unsqueeze(0),torch.zeros(2,image_size,image_size,device=device)))
-            alpha = 0.9  # transparency factor for the heatmap overlay
-            blended_tensor = (1 - alpha) * img[i] + alpha * distance_map
-            axes[i, 1].imshow(norm_img(blended_tensor.permute(1,2,0).detach().cpu()))
-
-            axes[i, 2].imshow(norm_img(img[i].permute(1,2,0).detach().cpu()))
-
-            axes[i, 3].imshow(norm_img(reshaped_pca_result[i]))
-
-        plt.savefig(save_path)
-
-
-
-    def get_heatmaps(self, img, feature_map, source_num=0, target_num=1, x_coords=69, y_coords=42):
-        device = feature_map.device
-        batch_size = feature_map.size(0)
-        feature_dim = feature_map.size(1)
-        image_size = feature_map.size(2)
-
-        x_coords = [x_coords]*batch_size
-        y_coords = [y_coords]*batch_size
-
-        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]  # one feature vector per sample
-        vector = vectors[source_num]
-
-        # compute distances against every feature map in the batch
-        # permute feature_map and flatten batch, height and width
-        reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
-        batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
-        # batch_distance_map = F.cosine_similarity(reshaped_feature_map, vector.unsqueeze(0).unsqueeze(0).expand(65,size*size,32), dim=2).permute(1, 0).reshape(feature_map.size(0), size, size)
-        norm_batch_distance_map = 1/torch.cosh( 20*(batch_distance_map-batch_distance_map.min())/(batch_distance_map.max()-batch_distance_map.min()) )**2
-        # norm_batch_distance_map[:,0,0] = 0.001
-
-        source_map = norm_batch_distance_map[source_num]
-        target_map = norm_batch_distance_map[target_num]
-
-        alpha = 0.9
-        blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num]/norm_batch_distance_map[source_num].max()).unsqueeze(0),torch.zeros(2,image_size,image_size,device=device)))
-        blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num]/norm_batch_distance_map[target_num].max()).unsqueeze(0),torch.zeros(2,image_size,image_size,device=device)))
-
-        return source_map, target_map, blended_source, blended_target
- return source_map, target_map, blended_source, blended_target
 
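The heatmap math removed here is what the app's get_heatmaps still relies on; for reference, a self-contained sketch of the core distance-map computation from the deleted DistanceMapLogger (shape values are illustrative):

import torch
import torch.nn.functional as F

def distance_heatmap(feature_map, x, y, source_num=0):
    # Distance from the feature vector at (x, y) of one sample to every
    # spatial position in the batch, squashed into (0, 1] with 1/cosh(.)**2
    # so that close matches light up.
    b, c, h, w = feature_map.shape
    vector = feature_map[source_num, :, y, x]                  # (c,)
    flat = feature_map.permute(0, 2, 3, 1).reshape(b, -1, c)   # (b, h*w, c)
    dist = F.pairwise_distance(flat, vector).reshape(b, h, w)
    norm = (dist - dist.min()) / (dist.max() - dist.min())
    return 1 / torch.cosh(20 * norm) ** 2

heatmaps = distance_heatmap(torch.rand(2, 64, 112, 112), x=69, y=42)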