refactoring

Files changed:
- app.py            +11  -16
- dataset.py         +0  -21
- model_module.py    +1 -135
- utils.py           +0 -173
app.py CHANGED
@@ -6,7 +6,6 @@ from torch.utils.data import DataLoader
 import matplotlib.pyplot as plt
 from model_module import AutoencoderModule
 from dataset import MyDataset, load_filenames
-from utils import DistanceMapLogger
 import numpy as np
 from PIL import Image
 import base64
@@ -159,7 +158,6 @@ async () => {
   document.querySelector("#input_file").click();
 };

-// Load the image from the Gradio File component
 document.querySelector("#input_file").addEventListener("change", function(e) {
   const files = e.target.files;
   console.log(files);
@@ -182,7 +180,6 @@ async () => {
   }
 });

-// Add JavaScript functionality to the Gradio button
 document.getElementById("crop_button").onclick = function() {
   if (cropper) {
     const canvas = cropper.getCroppedCanvas();
@@ -227,13 +224,9 @@ with gr.Blocks() as demo:
     gr.HTML('<input type="file" id="input_file" style="display:none;">')
     input_file_button = gr.Button("Upload and Crop Image", elem_id="input_file_button", variant="primary")
     crop_button = gr.Button("Crop", elem_id="crop_button", variant="primary")
-    # Display the image with an HTML <img> tag in Gradio
     gr.HTML('<img id="crop_view" style="max-width:100%;">')
-    # Add a Gradio button component and assign it an ID
-    # Textbox for the cropped image data (Base64)
     cropped_image_data = gr.Textbox(visible=False, elem_id="cropped_image_data")
-    input_image = gr.Image(label="Cropped Image", elem_id="input_image")
-    # Call process_image when cropped_image_data is updated
+    input_image = gr.Image(visible=False, label="Cropped Image", elem_id="input_image")
     cropped_image_data.change(process_image, inputs=cropped_image_data, outputs=input_image)

     # examples
@@ -248,15 +241,17 @@ with gr.Blocks() as demo:
         with gr.Column():
             output_plot = gr.Plot()

-
-
-
-
-        input_image.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
+    # Update the heatmaps
+    source_num.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
+    x_coords.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
+    y_coords.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)

-
-
-
+    def change_input_image_hanlder(source_num, x_coords, y_coords, input_image):
+        visible = False if input_image is None else True
+        return get_heatmaps(source_num, x_coords, y_coords, input_image), gr.Image(visible=visible, label="Cropped Image", elem_id="input_image")
+    input_image.change(change_input_image_hanlder, inputs=[source_num, x_coords, y_coords, input_image], outputs=[output_plot, input_image])
+
+    # Load the JavaScript code
+    demo.load(None, None, None, js=scripts)
 demo.launch()
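The app.py changes above boil down to two things: re-running get_heatmaps whenever any of its inputs changes, and loading the Cropper.js glue once at startup via demo.load(js=...). A minimal, self-contained sketch of that wiring pattern follows; the gr.Number controls, the stub get_heatmaps, and the placeholder scripts string are illustrative stand-ins, not the app's real implementation.

import gradio as gr

# Hypothetical stand-in for the app's real heatmap function.
def get_heatmaps(source_num, x, y, image):
    return None  # the real app returns a matplotlib figure

scripts = "async () => { console.log('cropper glue loaded'); }"  # placeholder JS

with gr.Blocks() as demo:
    source_num = gr.Number(value=0, label="source_num")
    x_coords = gr.Number(value=0, label="x")
    y_coords = gr.Number(value=0, label="y")
    input_image = gr.Image(visible=False, label="Cropped Image")
    output_plot = gr.Plot()

    # Re-run the heatmap whenever any of its inputs changes.
    for comp in (source_num, x_coords, y_coords, input_image):
        comp.change(get_heatmaps,
                    inputs=[source_num, x_coords, y_coords, input_image],
                    outputs=output_plot)

    # Run the custom JavaScript once when the page loads.
    demo.load(None, None, None, js=scripts)

demo.launch()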
dataset.py CHANGED
@@ -8,27 +8,6 @@ import os
 from utils import RandomAffineAndRetMat

 def load_filenames(data_dir):
-    # label_data = pd.read_json(INPUT_DIR+'DataList.json')
-    # label_data = label_data.sort_index()
-    # tmp_points = []
-    # filenames = []
-
-    # for o in tqdm(label_data.data[0:1000]):
-    #     filenames.append(o['filename'])
-    #     a = o['filename']
-
-    #     tmps = []
-    #     for i in range(60):
-    #         tmps.append(o['points'][str(i)]['x'])
-    #         tmps.append(o['points'][str(i)]['y'])
-    #     tmp_points.append(tmps) # datanum
-
-    # filenames = pd.Series(filenames)
-    # filenames = [str(i).zfill(4)+'.jpg' for i in range(3400)]
-    # df_points = pd.DataFrame(tmp_points)
-
-    # load from data_dir
-    # image file extensions only
     img_exts = ['.jpg', '.jpeg', '.png', '.bmp', '.ppm', '.pgm', '.tif', '.tiff']
     filenames = [f for f in os.listdir(data_dir) if os.path.splitext(f)[1].lower() in img_exts]

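With the commented-out JSON loading gone, load_filenames reduces to listing the image files in data_dir. A standalone sketch of that same extension filter; the ./data directory is only an example, and the function's return statement falls outside the hunk shown above.

import os

# List files whose extension marks them as images, as in load_filenames above.
img_exts = ['.jpg', '.jpeg', '.png', '.bmp', '.ppm', '.pgm', '.tif', '.tiff']
filenames = [f for f in os.listdir("./data")
             if os.path.splitext(f)[1].lower() in img_exts]
print(f"{len(filenames)} image files found")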
model_module.py CHANGED
@@ -1,145 +1,11 @@
 import pytorch_lightning as pl
-import torch
-from torch import nn
-from torch.optim import SGD
-from torchvision.utils import save_image
-import os
-from utils import TripletLossBatch, pairwise_distance_squared, GetTransformedCoords, DistanceMapLogger
 from model import Autoencoder

 class AutoencoderModule(pl.LightningModule):
-    def __init__(self, feature_dim=64
+    def __init__(self, feature_dim=64):
         super(AutoencoderModule, self).__init__()
         self.feature_dim = feature_dim
-        self.learning_rate = learning_rate
-        self.lambda_c = lambda_c
-        self.margin_img = initial_margin
-        self.margin_img_init = initial_margin
-        self.threshold = initial_threshold
         self.model = Autoencoder(self.feature_dim)
-        self.criterion = nn.MSELoss()
-        self.triplet_loss = TripletLossBatch()
-        self.losses = []
-        self.save_interval = save_interval  # output interval in batches
-        self.output_dir = output_dir
-        os.makedirs(self.output_dir, exist_ok=True)

     def forward(self, x):
         return self.model(x)
-
-    def training_step(self, batch, batch_idx):
-        img, mat, _ = batch
-        batch_size, _, _, size, size = img.shape
-        img = img.view(batch_size*2, 3, size, size)
-        mat = mat.view(batch_size*2, 3, 3)
-
-        dec5_output, output = self.model(img)
-        mse_loss = self.criterion(output, img)
-
-        # within-image direction processing
-        num_anchor_sets = 2**12
-        trip_loss = 0
-        std_list = [2.5*1.025**self.current_epoch, 5*1.025**self.current_epoch]
-        for c in std_list:
-            std = size / c
-            anchors = torch.randint(0, size, (batch_size*2, num_anchor_sets, 1, 2))
-            coords = anchors + torch.normal(0, std, (batch_size*2, num_anchor_sets, 2, 2)).long()
-            valid_coords_idx = (((coords >= 0) & (coords < size)).sum(3) == 2).sum(2) != 2
-            coords[valid_coords_idx] = 0
-            anchors[valid_coords_idx] = 0
-
-            # select the nearest coordinates
-            d = pairwise_distance_squared(anchors.float(), coords.float())
-            idx = torch.argmin(d, dim=2)
-            anchors, positives, negatives = self._get_triplet_coordinates(anchors, coords, idx)
-
-            # extract feature vectors from dec5_output
-            anchor_vectors, positive_vectors, negative_vectors = self._extract_feature_vectors(dec5_output, batch_size, anchors, positives, negatives)
-            trip_loss += self.triplet_loss(anchor_vectors, positive_vectors, negative_vectors, self.margin_img)
-
-        trip_loss /= len(std_list)
-        self.margin_img = self.margin_img_init + self.margin_img - trip_loss.detach()
-
-        # learning the transformation
-        num_samples = 2**20
-        tf_loss = self._compute_transformation_loss(dec5_output, mat, batch_size, size, num_samples)
-
-        # batch-direction processing
-        bat_dist_loss = self._compute_batch_direction_loss(dec5_output, batch_size, size)
-
-        # total loss
-        loss = mse_loss + trip_loss + 0.001 * bat_dist_loss + (0.001 * 1.**self.current_epoch) * tf_loss
-        self.log("train_loss", loss)
-
-        # VRAM management
-        del img, output
-        torch.cuda.empty_cache()
-
-        return loss
-
-
-    def _get_triplet_coordinates(self, anchors, coords, idx):
-        anchors = anchors.squeeze(2)
-        positives = coords[torch.arange(coords.size(0))[:, None, None], torch.arange(coords.size(1))[None, :, None], idx[:, :, None]].squeeze(2)
-        negatives = coords[torch.arange(coords.size(0))[:, None, None], torch.arange(coords.size(1))[None, :, None], (1 - idx)[:, :, None]].squeeze(2)
-        return anchors, positives, negatives
-
-    def _extract_feature_vectors(self, dec5_output, batch_size, anchors, positives, negatives):
-        y_anchors = anchors[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
-        x_anchors = anchors[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
-        y_positives = positives[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
-        x_positives = positives[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
-        y_negatives = negatives[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
-        x_negatives = negatives[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
-
-        anchor_vectors = dec5_output[torch.arange(batch_size*2)[:, None, None], torch.arange(self.feature_dim), y_anchors, x_anchors]
-        positive_vectors = dec5_output[torch.arange(batch_size*2)[:, None, None], torch.arange(self.feature_dim), y_positives, x_positives]
-        negative_vectors = dec5_output[torch.arange(batch_size*2)[:, None, None], torch.arange(self.feature_dim), y_negatives, x_negatives]
-        return anchor_vectors, positive_vectors, negative_vectors
-
-    def _compute_transformation_loss(self, dec5_output, mat, batch_size, size, num_samples=2**12):
-        anchor_indices = torch.randint(batch_size, (num_samples, 1), device=self.device).repeat(1, 2).reshape(num_samples*2)
-        coords_x = torch.randint(0, size, (num_samples, 1), dtype=torch.float32, device=self.device).repeat(1, 2).reshape(num_samples*2, 1)
-        coords_y = torch.randint(0, size, (num_samples, 1), dtype=torch.float32, device=self.device).repeat(1, 2).reshape(num_samples*2, 1)
-        anchor_coords = torch.cat((coords_x, coords_y), 1)
-        anchor_mat = mat[anchor_indices]
-        tf_anchor_coords = GetTransformedCoords(anchor_mat, [size/2, size/2])(anchor_coords)
-
-        anchor_vectors = torch.zeros([num_samples*2, self.feature_dim], device=self.device)
-        inner_idx_flat = ((0 <= tf_anchor_coords[:,0]) & (tf_anchor_coords[:,0] < size)) & ((0 <= tf_anchor_coords[:,1]) & (tf_anchor_coords[:,1] < size))
-        anchor_vectors[inner_idx_flat] = dec5_output[anchor_indices[inner_idx_flat], :, tf_anchor_coords[inner_idx_flat, 0], tf_anchor_coords[inner_idx_flat, 1]]
-
-        inner_idx_and = inner_idx_flat.view(num_samples, 2).t()[0] & inner_idx_flat.view(num_samples, 2).t()[1]
-        anchor_vectors = anchor_vectors.view(num_samples, 2, self.feature_dim)[inner_idx_and]
-        return pairwise_distance_squared(anchor_vectors[:,0], anchor_vectors[:,1]).mean()
-
-    def _compute_batch_direction_loss(self, dec5_output, batch_size, size):
-        N = 2**12
-        anchor_indices = torch.randint(0, batch_size, (N,)) * 2 + torch.randint(0, 2, (N,))
-        anchor_coords = torch.randint(0, size, (N, 2))
-        other_indices = torch.randint(0, batch_size-1, (N, 2)) * 2 + torch.randint(0, 2, (N, 2))
-        other_indices += (other_indices >= anchor_indices.unsqueeze(1)).long() * 2
-        other_coords = torch.randint(0, size, (N, 2, 2))
-
-        anchor_vectors = dec5_output[anchor_indices, :, anchor_coords[:, 0], anchor_coords[:, 1]]
-        other_vectors = dec5_output[other_indices, :, other_coords[:, :, 0], other_coords[:, :, 1]]
-        distances = pairwise_distance_squared(anchor_vectors.unsqueeze(1), other_vectors)
-        return distances[distances < self.threshold].sum() / ((distances < self.threshold).sum() + 1e-10)
-
-    def configure_optimizers(self):
-        optimizer = SGD(self.parameters(), lr=self.learning_rate)
-        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: self.lambda_c**epoch)
-        return [optimizer], [scheduler]
-
-    # def save_intermediate_image(self, output, epoch):
-    #     save_image(output[:4], os.path.join(self.output_dir, f"epoch_{epoch}_output.png"), nrow=1)
-    #     print(f"Saved intermediate image at epoch {epoch}")
-
-    # def distance_map(self, _input, feature_map, epoch, x_coords=None, y_coords=None):
-    #     save_path = os.path.join(self.output_dir, f"epoch_{epoch}_distance_map.png")
-    #     DistanceMapLogger()(_input, feature_map, save_path, x_coords, y_coords)
-
-    def configure_optimizers(self):
-        optimizer = SGD(self.parameters(), lr=self.learning_rate)
-        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: self.lambda_c**epoch)
-        return [optimizer], [scheduler]
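After this change AutoencoderModule only wraps Autoencoder for inference; all training logic is gone. A hedged sketch of how it could be used at inference time: the checkpoint path, the 128x128 input size, and the direct load_state_dict call are assumptions (a Lightning checkpoint would need its "state_dict" entry instead), while the two-tensor return value mirrors how the removed training_step unpacked self.model(img).

import torch
from model_module import AutoencoderModule

# Hypothetical checkpoint path; the real Space may load weights differently.
module = AutoencoderModule(feature_dim=64)
state = torch.load("checkpoint.pth", map_location="cpu")
module.load_state_dict(state)
module.eval()

# The removed training_step unpacked the model output as (dec5_output, output),
# i.e. a feature map plus a reconstruction, so forward() is assumed to return both.
with torch.no_grad():
    dec5_output, output = module(torch.randn(1, 3, 128, 128))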
utils.py CHANGED
@@ -77,176 +77,3 @@ class RandomAffineAndRetMat(torch.nn.Module):
         affine_matrix = translation_matrix.mm(rotation_matrix).mm(scaling_matrix).mm(shearing_matrix)

         return affine_matrix
-
-class GetTransformedCoords(nn.Module):
-    def __init__(self, affine_matrix, center):
-        super().__init__()
-        self.affine_matrix = affine_matrix
-        self.center = center
-
-    def forward(self, _coords):
-        # coords: like tensor([[43, 26], [44, 27], [45, 28]])
-        center_x, center_y = self.center
-        # shift the original coordinates so the center becomes the origin
-        coords = _coords.clone()
-        coords[:, 0] -= center_x
-        coords[:, 1] -= center_y
-
-        # apply the transform to each element of the batch
-        homogeneous_coordinates = torch.cat([coords, torch.ones(coords.shape[0], 1, dtype=torch.float32, device=coords.device)], dim=1)
-        transformed_coordinates = torch.bmm(self.affine_matrix, homogeneous_coordinates.unsqueeze(-1)).squeeze(-1)
-
-        # keep the result inside the image bounds
-        # transformed_x = max(0, min(width - 1, transformed_coordinates[:, 0]))
-        # transformed_y = max(0, min(height - 1, transformed_coordinates[:, 1]))
-        transformed_x = transformed_coordinates[:, 0]
-        transformed_y = transformed_coordinates[:, 1]
-
-        transformed_x += center_x
-        transformed_y += center_y
-        return torch.stack([transformed_x, transformed_y]).t().to(torch.long)
-
-# version of pairwise_distance that does not take the square root
-def pairwise_distance_squared(a, b):
-    return torch.sum((a - b) ** 2, dim=-1)
-
-def cosine_similarity(a, b):
-    # compute the dot product of a and b
-    dot_product = torch.matmul(a, b)
-    # compute the norms (magnitudes) of a and b
-    norm_a = torch.sqrt(torch.sum(a ** 2, dim=-1))
-    norm_b = torch.sqrt(torch.sum(b ** 2, dim=-1))
-    # cosine similarity: dot product divided by the product of the norms
-    return dot_product / (norm_a * norm_b)
-
-def batch_cosine_similarity(a, b):
-    # compute the dot product of a and b
-    dot_product = torch.einsum('bnd,bnd->bn', a, b)
-    # compute the norms (magnitudes) of a and b
-    norm_a = torch.sqrt(torch.sum(a ** 2, dim=-1))
-    norm_b = torch.sqrt(torch.sum(b ** 2, dim=-1))
-    # cosine similarity: dot product divided by the product of the norms
-    return dot_product / (norm_a * norm_b)
-
-class TripletLossBatch(nn.Module):
-    def __init__(self):
-        super(TripletLossBatch, self).__init__()
-
-    def forward(self, anchor, positive, negative, margin=1.0):
-        distance_positive = F.pairwise_distance(anchor, positive, p=2)
-        distance_negative = F.pairwise_distance(anchor, negative, p=2)
-        losses = torch.relu(distance_positive - distance_negative + margin)
-        return losses.mean()
-
-class TripletLossCosineSimilarity(nn.Module):
-    def __init__(self):
-        super(TripletLossCosineSimilarity, self).__init__()
-
-    def forward(self, anchor, positive, negative, margin=1.0):
-        distance_positive = 1 - batch_cosine_similarity(anchor, positive)
-        distance_negative = 1 - batch_cosine_similarity(anchor, negative)
-        losses = torch.relu(distance_positive - distance_negative + margin)
-        return losses.mean()
-
-def imsave(img):
-    img = torchvision.utils.make_grid(img)
-    img = img / 2 + 0.5
-    npimg = img.detach().cpu().numpy()
-    # plt.imshow(np.transpose(npimg, (1, 2, 0)))
-    # plt.show()
-    # save image
-    npimg = np.transpose(npimg, (1, 2, 0))
-    npimg = npimg * 255
-    npimg = npimg.astype(np.uint8)
-    Image.fromarray(npimg).save('sample.png')
-
-def norm_img(img):
-    return (img-img.min())/(img.max()-img.min())
-
-def norm_img2(img):
-    return (img-img.min())/(img.max()-img.min())*255
-
-class DistanceMapLogger:
-    def __call__(self, img, feature_map, save_path, x_coords=None, y_coords=None):
-        device = feature_map.device
-        batch_size = feature_map.size(0)
-        feature_dim = feature_map.size(1)
-        image_size = feature_map.size(2)
-
-        if x_coords is None:
-            x_coords = [69]*batch_size
-        if y_coords is None:
-            y_coords = [42]*batch_size
-
-        # extract a 3-dimensional map with PCA
-        pca = PCA(n_components=3)
-        pca_result = pca.fit_transform(feature_map.permute(0,2,3,1).reshape(-1,feature_dim).detach().cpu().numpy())  # run PCA
-        reshaped_pca_result = pca_result.reshape(batch_size,image_size,image_size,3)  # reshape to 3 dimensions (it was 1-dimensional)
-
-        sample_num = 0
-        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]  # adjust the size to match the 1-D vectors
-        vector = vectors[sample_num]
-
-        # compute the inner product against every feature map in the batch
-        # permute feature_map and flatten batch, height and width
-        reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
-        batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
-        # batch_distance_map = F.cosine_similarity(reshaped_feature_map, vector.unsqueeze(0).unsqueeze(0).expand(65,size*size,32), dim=2).permute(1, 0).reshape(feature_map.size(0), size, size)
-        norm_batch_distance_map = 1/torch.cosh( 20*(batch_distance_map-batch_distance_map.min())/(batch_distance_map.max()-batch_distance_map.min()) )**2
-        # norm_batch_distance_map[:,0,0] = 0.001
-        # visualize and save
-        fig, axes = plt.subplots(5, 4, figsize=(20, 25))
-        for ax in axes.flatten():
-            ax.axis('off')
-        # remove the margins between subplots
-        plt.subplots_adjust(wspace=0, hspace=0)
-        # remove the outer margins as well
-        plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
-
-        # visualize the distance maps
-        for i in range(5):
-            axes[i, 0].imshow(norm_batch_distance_map[i].detach().cpu(), cmap='hot')
-            if i == sample_num:
-                axes[i, 0].scatter(x_coords[i], y_coords[i], c='b', s=7)
-
-            distance_map = torch.cat(((norm_batch_distance_map[i]/norm_batch_distance_map[i].max()).unsqueeze(0),torch.zeros(2,image_size,image_size,device=device)))
-            alpha = 0.9  # Transparency factor for the heatmap overlay
-            blended_tensor = (1 - alpha) * img[i] + alpha * distance_map
-            axes[i, 1].imshow(norm_img(blended_tensor.permute(1,2,0).detach().cpu()))
-
-            axes[i, 2].imshow(norm_img(img[i].permute(1,2,0).detach().cpu()))
-
-            axes[i, 3].imshow(norm_img(reshaped_pca_result[i]))
-
-        plt.savefig(save_path)
-
-
-
-    def get_heatmaps(self, img, feature_map, source_num=0, target_num=1, x_coords=69, y_coords=42):
-        device = feature_map.device
-        batch_size = feature_map.size(0)
-        feature_dim = feature_map.size(1)
-        image_size = feature_map.size(2)
-
-        x_coords = [x_coords]*batch_size
-        y_coords = [y_coords]*batch_size
-
-        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]  # adjust the size to match the 1-D vectors
-        vector = vectors[source_num]
-
-        # compute the inner product against every feature map in the batch
-        # permute feature_map and flatten batch, height and width
-        reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
-        batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
-        # batch_distance_map = F.cosine_similarity(reshaped_feature_map, vector.unsqueeze(0).unsqueeze(0).expand(65,size*size,32), dim=2).permute(1, 0).reshape(feature_map.size(0), size, size)
-        norm_batch_distance_map = 1/torch.cosh( 20*(batch_distance_map-batch_distance_map.min())/(batch_distance_map.max()-batch_distance_map.min()) )**2
-        # norm_batch_distance_map[:,0,0] = 0.001
-
-        source_map = norm_batch_distance_map[source_num]
-        target_map = norm_batch_distance_map[target_num]
-
-        alpha = 0.9
-        blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num]/norm_batch_distance_map[source_num].max()).unsqueeze(0),torch.zeros(2,image_size,image_size,device=device)))
-        blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num]/norm_batch_distance_map[target_num].max()).unsqueeze(0),torch.zeros(2,image_size,image_size,device=device)))
-
-        return source_map, target_map, blended_source, blended_target
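Among the helpers removed here, DistanceMapLogger turned a per-pixel L2 distance map into a heatmap via a sech^2 bump, norm = 1 / cosh(20 * (d - d_min) / (d_max - d_min))^2, so pixels close to the selected feature vector light up near 1. A minimal sketch of just that normalization on a dummy tensor; the shape and values are arbitrary, and only the scale factor 20 is taken from the removed code.

import torch

# Dummy per-pixel distance map of shape (batch, H, W); values are arbitrary.
batch_distance_map = torch.rand(4, 64, 64) * 10

# Same normalization as the removed get_heatmaps: rescale to [0, 20] using the
# global min/max, then apply a sech^2 bump so small distances map to values near 1.
scaled = 20 * (batch_distance_map - batch_distance_map.min()) / (
    batch_distance_map.max() - batch_distance_map.min())
norm_batch_distance_map = 1 / torch.cosh(scaled) ** 2

print(norm_batch_distance_map.min().item(), norm_batch_distance_map.max().item())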