Spaces: Running on Zero
init
- Dockerfile +20 -0
- app.py +229 -0
- checkpoints/ae_model_tf_2024-03-05_00-35-21.pth +3 -0
- checkpoints/autoencoder-epoch=09-train_loss=1.00.ckpt +3 -0
- checkpoints/autoencoder-epoch=29-train_loss=1.01.ckpt +3 -0
- checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt +3 -0
- datamodule.py +25 -0
- dataset.py +88 -0
- model.py +62 -0
- model_module.py +145 -0
- requirements.txt +9 -0
- utils.py +252 -0
Dockerfile
ADDED
@@ -0,0 +1,20 @@
# Use Python 3.9 as the base image
FROM python:3.9-slim

# Set the working directory
WORKDIR /app

# Copy the dependency file used to install the required Python libraries
COPY requirements.txt /app/requirements.txt

# Install the required Python packages
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application code into the container
COPY . /app

# Expose the port (Gradio's default port 7860)
EXPOSE 7860

# Start the application
CMD ["python", "app.py"]
app.py
ADDED
@@ -0,0 +1,229 @@
import gradio as gr
import spaces
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model_module import AutoencoderModule
from dataset import MyDataset, load_filenames
from utils import DistanceMapLogger
import numpy as np
from PIL import Image
import base64
from io import BytesIO

# Load the model and the data
def load_model():
    model_path = "checkpoints/ae_model_tf_2024-03-05_00-35-21.pth"
    feature_dim = 32
    model = AutoencoderModule(feature_dim=feature_dim)
    state_dict = torch.load(model_path)

    # Fix the state_dict keys
    new_state_dict = {}
    for key in state_dict:
        new_key = "model." + key
        new_state_dict[new_key] = state_dict[key]
    model.load_state_dict(new_state_dict)
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print("Model loaded successfully.")
    return model, device

def load_data(device, img_dir="resources/trainB/", image_size=112, batch_size=32):
    filenames = load_filenames(img_dir)
    train_X = filenames[:1000]
    train_ds = MyDataset(train_X, img_dir=img_dir, img_size=image_size)

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )

    iterator = iter(train_loader)
    x, _, _ = next(iterator)
    x = x.to(device)
    x = x[:, 0].to(device)
    print("Data loaded successfully.")
    return x

model, device = load_model()
image_size = 112
batch_size = 32
x = load_data(device)

# Preprocess the uploaded image
def preprocess_uploaded_image(uploaded_image, image_size):
    uploaded_image = Image.fromarray(uploaded_image)
    uploaded_image = uploaded_image.convert("RGB")
    uploaded_image = uploaded_image.resize((image_size, image_size))
    uploaded_image = np.array(uploaded_image).transpose(2, 0, 1) / 255.0
    uploaded_image = torch.tensor(uploaded_image, dtype=torch.float32).unsqueeze(0).to(device)
    return uploaded_image

# Heatmap generation function
@spaces.GPU
def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
    with torch.no_grad():
        dec5, _ = model(x)
        img = x
        feature_map = dec5
        batch_size = feature_map.size(0)
        feature_dim = feature_map.size(1)

        # Preprocess the uploaded image
        if uploaded_image is not None:
            uploaded_image = preprocess_uploaded_image(uploaded_image, image_size)
            target_feature_map, _ = model(uploaded_image)
            img = torch.cat((img, uploaded_image))
            feature_map = torch.cat((feature_map, target_feature_map))
            batch_size += 1
        else:
            uploaded_image = torch.zeros(1, 3, image_size, image_size, device=device)

        target_num = batch_size - 1

        x_coords = [x_coords] * batch_size
        y_coords = [y_coords] * batch_size

        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]
        vector = vectors[source_num]

        reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
        batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)

        norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2

        source_map = norm_batch_distance_map[source_num]
        target_map = norm_batch_distance_map[target_num]

        alpha = 0.8
        blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num] / norm_batch_distance_map[source_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
        blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num] / norm_batch_distance_map[target_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))

        # Plot with Matplotlib and return the figure
        fig, axs = plt.subplots(2, 2, figsize=(10, 10))
        axs[0, 0].imshow(source_map.cpu(), cmap='hot')
        axs[0, 0].set_title("Source Map")
        axs[0, 1].imshow(target_map.cpu(), cmap='hot')
        axs[0, 1].set_title("Target Map")
        axs[1, 0].imshow(blended_source.permute(1, 2, 0).cpu())
        axs[1, 0].set_title("Blended Source")
        axs[1, 1].imshow(blended_target.permute(1, 2, 0).cpu())
        axs[1, 1].set_title("Blended Target")
        for ax in axs.flat:
            ax.axis('off')

        plt.tight_layout()
        plt.close(fig)
        return fig

def process_image(cropped_image_data):
    # Convert the Base64 string to a PIL image
    header, base64_data = cropped_image_data.split(',', 1)
    image_data = base64.b64decode(base64_data)
    image = Image.open(BytesIO(image_data))
    return image

# JavaScript code
scripts = """
async () => {
    const script = document.createElement("script");
    script.src = "https://cdnjs.cloudflare.com/ajax/libs/cropperjs/1.5.13/cropper.min.js";
    document.head.appendChild(script);

    const style = document.createElement("link");
    style.rel = "stylesheet";
    style.href = "https://cdnjs.cloudflare.com/ajax/libs/cropperjs/1.5.13/cropper.min.css";
    document.head.appendChild(style);

    script.onload = () => {
        let cropper;

        document.getElementById("input_file_button").onclick = function() {
            document.querySelector("#input_file").click();
        };

        // Load the image from the hidden file input
        document.querySelector("#input_file").addEventListener("change", function(e) {
            const files = e.target.files;
            console.log(files);
            if (files && files.length > 0) {
                console.log("File selected");
                document.querySelector("#crop_view").style.display = "block";
                document.querySelector("#crop_button").style.display = "block";
                const url = URL.createObjectURL(files[0]);
                const crop_view = document.getElementById("crop_view");
                crop_view.src = url;

                if (cropper) {
                    cropper.destroy();
                }
                cropper = new Cropper(crop_view, {
                    aspectRatio: 1,
                    viewMode: 1,
                });
            }
        });

        // Attach the cropping behaviour to the Gradio button
        document.getElementById("crop_button").onclick = function() {
            if (cropper) {
                const canvas = cropper.getCroppedCanvas();
                const croppedImageData = canvas.toDataURL();

                // Send the cropped image to Gradio
                const textbox = document.querySelector("#cropped_image_data textarea");
                textbox.value = croppedImageData;
                textbox.dispatchEvent(new Event("input", { bubbles: true }));

                document.getElementById("crop_view").style.display = "none";
                document.getElementById("crop_button").style.display = "none";

                cropper.destroy();
            }
        };
        document.getElementById("crop_view").style.display = "none";
        document.getElementById("crop_button").style.display = "none";
    };
}
"""

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            source_num = gr.Slider(0, batch_size - 1, step=1, label="Source Image Index")
            x_coords = gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="X Coordinate")
            y_coords = gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="Y Coordinate")

            # Hidden HTML file input used as the file picker
            gr.HTML('<input type="file" id="input_file" style="display:none;">')
            input_file_button = gr.Button("Select Image", elem_id="input_file_button")
            # HTML image tag used to display the image being cropped
            gr.HTML('<img id="crop_view" style="max-width:100%;">')
            # Gradio button component with an element ID so the JS can find it
            crop_button = gr.Button("Crop", elem_id="crop_button", variant="primary")
            # Textbox holding the cropped image data (Base64)
            cropped_image_data = gr.Textbox(visible=False, elem_id="cropped_image_data")
            input_image = gr.Image(label="Cropped Image", interactive=False)
            # Call process_image whenever cropped_image_data is updated
            cropped_image_data.change(process_image, inputs=cropped_image_data, outputs=input_image)

        with gr.Column():
            output_plot = gr.Plot()

    # Wire the change events manually instead of using gr.Interface
    source_num.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
    x_coords.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
    y_coords.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
    input_image.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)

    # Load the JavaScript code
    demo.load(None, None, None, js=scripts)

demo.launch()
checkpoints/ae_model_tf_2024-03-05_00-35-21.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:77b020cb89ad2ccf7a7bf654d86fb975793cbe168bf73cd011e93cf22f63204c
size 2629576
checkpoints/autoencoder-epoch=09-train_loss=1.00.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:af08cd5fbb1c832824b7466be4da59d2a40e6e9eef864097514a2806a24bb92b
size 3046514
checkpoints/autoencoder-epoch=29-train_loss=1.01.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2dc5dcf07a6a66cd4f0773af5fff29903d5fb9fa340221cd59083462e1ae77b7
size 3046959
checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a4f218b652ba63e9891b73efca5faffaabe4692f1b78755860ff46b113d09ecd
size 3046959
datamodule.py
ADDED
@@ -0,0 +1,25 @@
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from dataset import MyDataset, load_filenames  # defined in dataset.py

class DataModule(pl.LightningDataModule):
    def __init__(self, img_dir, batch_size, img_size=112, num_workers=0):
        super().__init__()
        self.img_dir = img_dir
        self.batch_size = batch_size
        self.img_size = img_size
        self.num_workers = num_workers
        self.file_num = 1000  # or 3400

    def setup(self, stage=None):
        filenames = load_filenames(self.img_dir)
        self.train_dataset = MyDataset(filenames[:self.file_num], img_dir=self.img_dir, img_size=self.img_size)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            persistent_workers=self.num_workers > 0,  # persistent workers are only valid when num_workers > 0
        )
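
For orientation, here is a minimal training sketch (not part of this commit) showing how this DataModule is meant to connect to AutoencoderModule from model_module.py; the image directory, batch size, epoch count, and worker count are illustrative assumptions.

# Hypothetical training driver (illustrative only, not part of this commit).
import pytorch_lightning as pl
from datamodule import DataModule
from model_module import AutoencoderModule

if __name__ == "__main__":
    dm = DataModule(img_dir="resources/trainB/", batch_size=32, img_size=112, num_workers=2)
    model = AutoencoderModule(feature_dim=32)
    trainer = pl.Trainer(max_epochs=10, accelerator="auto")
    trainer.fit(model, datamodule=dm)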
dataset.py
ADDED
@@ -0,0 +1,88 @@
import torch
import torchvision
from torchvision import transforms
import random
from PIL import Image
import os

from utils import RandomAffineAndRetMat

def load_filenames(data_dir):
    # label_data = pd.read_json(INPUT_DIR+'DataList.json')
    # label_data = label_data.sort_index()
    # tmp_points = []
    # filenames = []

    # for o in tqdm(label_data.data[0:1000]):
    #     filenames.append(o['filename'])
    #     a = o['filename']

    #     tmps = []
    #     for i in range(60):
    #         tmps.append(o['points'][str(i)]['x'])
    #         tmps.append(o['points'][str(i)]['y'])
    #     tmp_points.append(tmps) # datanum

    # filenames = pd.Series(filenames)
    # filenames = [str(i).zfill(4)+'.jpg' for i in range(3400)]
    # df_points = pd.DataFrame(tmp_points)

    # load from data_dir
    # image file extensions only
    img_exts = ['.jpg', '.jpeg', '.png', '.bmp', '.ppm', '.pgm', '.tif', '.tiff']
    filenames = [f for f in os.listdir(data_dir) if os.path.splitext(f)[1].lower() in img_exts]

    return filenames


class MyDataset:
    def __init__(self, X, valid=False, img_dir='resources/trainB/', img_size=256):
        self.X = X
        self.valid = valid
        self.img_dir = img_dir
        self.img_size = img_size

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        # Load the image and apply the transforms
        f = self.img_dir + self.X[index]
        original_X = Image.open(f)
        trans = [
            transforms.ToTensor(),
            # transforms.Normalize(mean=means, std=stds),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.15),
            transforms.RandomGrayscale(0.3),
        ]
        transform = transforms.Compose(trans)
        xlist = []
        matlist = []
        is_flip = random.randint(0, 1)  # apply the same flip to both views of one image
        for i in range(2):
            af = RandomAffineAndRetMat(
                degrees=[-30, 30],
                translate=(0.1, 0.1), scale=(0.8, 1.2),
                # fill=(random.random(), random.random(), random.random()),
                fill=(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
                shear=[-10, 10],
                interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
            )
            X, affine_matrix = af(transforms.Resize(self.img_size)(original_X))

            # random flip
            if is_flip == 1:
                X = transforms.RandomHorizontalFlip(1.)(X)
                flip_matrix = torch.tensor([[-1., 0., 0.],
                                            [0., 1., 0.],
                                            [0., 0., 1.]])
                affine_matrix = torch.matmul(flip_matrix, affine_matrix)

            xlist.append(transform(X))
            matlist.append(affine_matrix)

        X = torch.stack(xlist)
        mat = torch.stack(matlist)
        return X, mat, f
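
Each item yields two augmented views of one image together with their affine matrices. A quick shape walkthrough (a sketch, not part of this commit; it assumes images exist under resources/trainB/):

# Hypothetical shape check (illustrative only, not part of this commit).
from torch.utils.data import DataLoader
from dataset import MyDataset, load_filenames

filenames = load_filenames("resources/trainB/")
ds = MyDataset(filenames, img_dir="resources/trainB/", img_size=112)
loader = DataLoader(ds, batch_size=4, shuffle=True)

X, mat, paths = next(iter(loader))
print(X.shape)    # torch.Size([4, 2, 3, 112, 112]) - two augmented views per image
print(mat.shape)  # torch.Size([4, 2, 3, 3])        - affine matrix for each view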
model.py
ADDED
@@ -0,0 +1,62 @@
import torch
from torch import nn

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.batchnorm(self.conv(x)))

class DeconvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding):
        super(DeconvBlock, self).__init__()
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.batchnorm(self.deconv(x)))

class Autoencoder(nn.Module):
    def __init__(self, feature_dim=32):
        super(Autoencoder, self).__init__()
        self.feature_dim = feature_dim

        # Encoder
        self.enc1 = ConvBlock(3, 16, 10, 1, 0)
        self.enc2 = ConvBlock(16, 32, 10, 1, 0)
        self.enc3 = ConvBlock(32, 64, 2, 2, 0)
        self.enc4 = ConvBlock(64, 128, 2, 2, 0)
        self.enc5 = ConvBlock(128, 256, 2, 2, 0)

        # Decoder
        self.dec1 = DeconvBlock(256, 128, 2, 2, 0, 1)
        self.dec2 = DeconvBlock(256, 64, 2, 2, 0, 1)  # 128 + 128
        self.dec3 = DeconvBlock(128, 32, 2, 2, 0, 0)  # 64 + 64
        self.dec4 = DeconvBlock(64, 16, 10, 1, 0, 0)  # 32 + 32
        self.dec5 = DeconvBlock(32, self.feature_dim, 10, 1, 0, 0)
        self.dec6 = nn.Conv2d(self.feature_dim, 32, 1, 1, 0)
        self.dec7 = nn.Conv2d(32, 3, 1, 1, 0)

    def forward(self, x):
        # Encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        enc5 = self.enc5(enc4)

        # Decoder with U-Net style skip connections
        dec1 = self.dec1(enc5)
        dec2 = self.dec2(torch.cat((dec1, enc4), 1))
        dec3 = self.dec3(torch.cat((dec2, enc3), 1))
        dec4 = self.dec4(torch.cat((dec3, enc2), 1))
        dec5 = self.dec5(torch.cat((dec4, enc1), 1))
        dec6 = self.dec6(dec5)
        dec7 = self.dec7(dec6)

        return dec5, dec7
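
As a quick sanity check (a sketch, not part of the commit): with the output_padding values above, a 112x112 input comes back at 112x112, so dec5 is a per-pixel feature map and dec7 a reconstruction of the same resolution.

# Hypothetical shape check (illustrative only, not part of this commit).
import torch
from model import Autoencoder

net = Autoencoder(feature_dim=32)
x = torch.randn(2, 3, 112, 112)      # two RGB images at the app's working resolution
dec5, dec7 = net(x)
print(dec5.shape)  # torch.Size([2, 32, 112, 112]) - per-pixel feature map
print(dec7.shape)  # torch.Size([2, 3, 112, 112])  - reconstruction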
model_module.py
ADDED
@@ -0,0 +1,145 @@
import pytorch_lightning as pl
import torch
from torch import nn
from torch.optim import SGD
from torchvision.utils import save_image
import os
from utils import TripletLossBatch, pairwise_distance_squared, GetTransformedCoords, DistanceMapLogger
from model import Autoencoder

class AutoencoderModule(pl.LightningModule):
    def __init__(self, feature_dim=64, learning_rate=0.1, lambda_c=0.97, initial_margin=1.0, initial_threshold=2.0, save_interval=100, output_dir="output_images"):
        super(AutoencoderModule, self).__init__()
        self.feature_dim = feature_dim
        self.learning_rate = learning_rate
        self.lambda_c = lambda_c
        self.margin_img = initial_margin
        self.margin_img_init = initial_margin
        self.threshold = initial_threshold
        self.model = Autoencoder(self.feature_dim)
        self.criterion = nn.MSELoss()
        self.triplet_loss = TripletLossBatch()
        self.losses = []
        self.save_interval = save_interval  # output interval, in batches
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        img, mat, _ = batch
        batch_size, _, _, size, size = img.shape
        img = img.view(batch_size*2, 3, size, size)
        mat = mat.view(batch_size*2, 3, 3)

        dec5_output, output = self.model(img)
        mse_loss = self.criterion(output, img)

        # Intra-image triplet loss
        num_anchor_sets = 2**12
        trip_loss = 0
        std_list = [2.5*1.025**self.current_epoch, 5*1.025**self.current_epoch]
        for c in std_list:
            std = size / c
            anchors = torch.randint(0, size, (batch_size*2, num_anchor_sets, 1, 2))
            coords = anchors + torch.normal(0, std, (batch_size*2, num_anchor_sets, 2, 2)).long()
            valid_coords_idx = (((coords >= 0) & (coords < size)).sum(3) == 2).sum(2) != 2
            coords[valid_coords_idx] = 0
            anchors[valid_coords_idx] = 0

            # Select the closer of the sampled coordinates
            d = pairwise_distance_squared(anchors.float(), coords.float())
            idx = torch.argmin(d, dim=2)
            anchors, positives, negatives = self._get_triplet_coordinates(anchors, coords, idx)

            # Extract the feature vectors from dec5_output
            anchor_vectors, positive_vectors, negative_vectors = self._extract_feature_vectors(dec5_output, batch_size, anchors, positives, negatives)
            trip_loss += self.triplet_loss(anchor_vectors, positive_vectors, negative_vectors, self.margin_img)

        trip_loss /= len(std_list)
        self.margin_img = self.margin_img_init + self.margin_img - trip_loss.detach()

        # Transformation (equivariance) loss
        num_samples = 2**20
        tf_loss = self._compute_transformation_loss(dec5_output, mat, batch_size, size, num_samples)

        # Batch-direction loss
        bat_dist_loss = self._compute_batch_direction_loss(dec5_output, batch_size, size)

        # Total loss
        loss = mse_loss + trip_loss + 0.001 * bat_dist_loss + (0.001 * 1.**self.current_epoch) * tf_loss
        self.log("train_loss", loss)

        # VRAM management
        del img, output
        torch.cuda.empty_cache()

        return loss


    def _get_triplet_coordinates(self, anchors, coords, idx):
        anchors = anchors.squeeze(2)
        positives = coords[torch.arange(coords.size(0))[:, None, None], torch.arange(coords.size(1))[None, :, None], idx[:, :, None]].squeeze(2)
        negatives = coords[torch.arange(coords.size(0))[:, None, None], torch.arange(coords.size(1))[None, :, None], (1 - idx)[:, :, None]].squeeze(2)
        return anchors, positives, negatives

    def _extract_feature_vectors(self, dec5_output, batch_size, anchors, positives, negatives):
        y_anchors = anchors[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
        x_anchors = anchors[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
        y_positives = positives[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
        x_positives = positives[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
        y_negatives = negatives[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
        x_negatives = negatives[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)

        anchor_vectors = dec5_output[torch.arange(batch_size*2)[:, None, None], torch.arange(self.feature_dim), y_anchors, x_anchors]
        positive_vectors = dec5_output[torch.arange(batch_size*2)[:, None, None], torch.arange(self.feature_dim), y_positives, x_positives]
        negative_vectors = dec5_output[torch.arange(batch_size*2)[:, None, None], torch.arange(self.feature_dim), y_negatives, x_negatives]
        return anchor_vectors, positive_vectors, negative_vectors

    def _compute_transformation_loss(self, dec5_output, mat, batch_size, size, num_samples=2**12):
        anchor_indices = torch.randint(batch_size, (num_samples, 1), device=self.device).repeat(1, 2).reshape(num_samples*2)
        coords_x = torch.randint(0, size, (num_samples, 1), dtype=torch.float32, device=self.device).repeat(1, 2).reshape(num_samples*2, 1)
        coords_y = torch.randint(0, size, (num_samples, 1), dtype=torch.float32, device=self.device).repeat(1, 2).reshape(num_samples*2, 1)
        anchor_coords = torch.cat((coords_x, coords_y), 1)
        anchor_mat = mat[anchor_indices]
        tf_anchor_coords = GetTransformedCoords(anchor_mat, [size/2, size/2])(anchor_coords)

        anchor_vectors = torch.zeros([num_samples*2, self.feature_dim], device=self.device)
        inner_idx_flat = ((0 <= tf_anchor_coords[:,0]) & (tf_anchor_coords[:,0] < size)) & ((0 <= tf_anchor_coords[:,1]) & (tf_anchor_coords[:,1] < size))
        anchor_vectors[inner_idx_flat] = dec5_output[anchor_indices[inner_idx_flat], :, tf_anchor_coords[inner_idx_flat, 0], tf_anchor_coords[inner_idx_flat, 1]]

        inner_idx_and = inner_idx_flat.view(num_samples, 2).t()[0] & inner_idx_flat.view(num_samples, 2).t()[1]
        anchor_vectors = anchor_vectors.view(num_samples, 2, self.feature_dim)[inner_idx_and]
        return pairwise_distance_squared(anchor_vectors[:,0], anchor_vectors[:,1]).mean()

    def _compute_batch_direction_loss(self, dec5_output, batch_size, size):
        N = 2**12
        anchor_indices = torch.randint(0, batch_size, (N,)) * 2 + torch.randint(0, 2, (N,))
        anchor_coords = torch.randint(0, size, (N, 2))
        other_indices = torch.randint(0, batch_size-1, (N, 2)) * 2 + torch.randint(0, 2, (N, 2))
        other_indices += (other_indices >= anchor_indices.unsqueeze(1)).long() * 2
        other_coords = torch.randint(0, size, (N, 2, 2))

        anchor_vectors = dec5_output[anchor_indices, :, anchor_coords[:, 0], anchor_coords[:, 1]]
        other_vectors = dec5_output[other_indices, :, other_coords[:, :, 0], other_coords[:, :, 1]]
        distances = pairwise_distance_squared(anchor_vectors.unsqueeze(1), other_vectors)
        return distances[distances < self.threshold].sum() / ((distances < self.threshold).sum() + 1e-10)

    def configure_optimizers(self):
        optimizer = SGD(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: self.lambda_c**epoch)
        return [optimizer], [scheduler]

    # def save_intermediate_image(self, output, epoch):
    #     save_image(output[:4], os.path.join(self.output_dir, f"epoch_{epoch}_output.png"), nrow=1)
    #     print(f"Saved intermediate image at epoch {epoch}")

    # def distance_map(self, _input, feature_map, epoch, x_coords=None, y_coords=None):
    #     save_path = os.path.join(self.output_dir, f"epoch_{epoch}_distance_map.png")
    #     DistanceMapLogger()(_input, feature_map, save_path, x_coords, y_coords)
requirements.txt
ADDED
@@ -0,0 +1,9 @@
--extra-index-url https://download.pytorch.org/whl/cu121
torch==2.2.0
torchvision==0.17.0
torchaudio==2.2.0
matplotlib==3.9.2
numpy==1.26.4
pytorch-lightning==2.4.0
scikit-learn==1.0.2
gradio==5.5.0
utils.py
ADDED
@@ -0,0 +1,252 @@
import torch
from torch import Tensor, nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

class RandomAffineAndRetMat(torch.nn.Module):
    def __init__(
        self,
        degrees,
        translate=None,
        scale=None,
        shear=None,
        interpolation=torchvision.transforms.InterpolationMode.NEAREST,
        fill=0,
        center=None,
    ):
        super().__init__()
        self.degrees = degrees
        self.translate = translate
        self.scale = scale
        self.shear = shear
        self.interpolation = interpolation
        self.fill = fill
        self.center = center

    def forward(self, img):
        """
        img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Affine transformed image.
        """
        fill = self.fill
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * transforms.functional.get_image_num_channels(img)
            else:
                fill = [float(f) for f in fill]

        img_size = transforms.functional.get_image_size(img)

        ret = transforms.RandomAffine.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
        transformed_image = transforms.functional.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)

        affine_matrix = self.get_affine_matrix_from_params(ret)

        return transformed_image, affine_matrix

    def get_affine_matrix_from_params(self, params):
        degrees, translate, scale, shear = params
        degrees = torch.tensor(degrees)
        shear = torch.tensor(shear)

        # Convert the sampled parameters into transformation matrices
        rotation_matrix = torch.tensor([[torch.cos(torch.deg2rad(degrees)), -torch.sin(torch.deg2rad(degrees)), 0],
                                        [torch.sin(torch.deg2rad(degrees)), torch.cos(torch.deg2rad(degrees)), 0],
                                        [0, 0, 1]])

        translation_matrix = torch.tensor([[1, 0, translate[0]],
                                           [0, 1, translate[1]],
                                           [0, 0, 1]]).to(torch.float32)

        scaling_matrix = torch.tensor([[scale, 0, 0],
                                       [0, scale, 0],
                                       [0, 0, 1]])

        shearing_matrix = torch.tensor([[1, -torch.tan(torch.deg2rad(shear[0])), 0],
                                        [-torch.tan(torch.deg2rad(shear[1])), 1, 0],
                                        [0, 0, 1]])

        # Compose the transformation matrices
        affine_matrix = translation_matrix.mm(rotation_matrix).mm(scaling_matrix).mm(shearing_matrix)

        return affine_matrix

class GetTransformedCoords(nn.Module):
    def __init__(self, affine_matrix, center):
        super().__init__()
        self.affine_matrix = affine_matrix
        self.center = center

    def forward(self, _coords):
        # coords: like tensor([[43, 26], [44, 27], [45, 28]])
        center_x, center_y = self.center
        # Shift the coordinates so the image center becomes the origin
        coords = _coords.clone()
        coords[:, 0] -= center_x
        coords[:, 1] -= center_y

        # Apply the transformation to each batch element
        homogeneous_coordinates = torch.cat([coords, torch.ones(coords.shape[0], 1, dtype=torch.float32, device=coords.device)], dim=1)
        transformed_coordinates = torch.bmm(self.affine_matrix, homogeneous_coordinates.unsqueeze(-1)).squeeze(-1)

        # Clamp to the image bounds
        # transformed_x = max(0, min(width - 1, transformed_coordinates[:, 0]))
        # transformed_y = max(0, min(height - 1, transformed_coordinates[:, 1]))
        transformed_x = transformed_coordinates[:, 0]
        transformed_y = transformed_coordinates[:, 1]

        transformed_x += center_x
        transformed_y += center_y
        return torch.stack([transformed_x, transformed_y]).t().to(torch.long)

# Version of pairwise_distance that skips the square root
def pairwise_distance_squared(a, b):
    return torch.sum((a - b) ** 2, dim=-1)

def cosine_similarity(a, b):
    # Dot product of a and b
    dot_product = torch.matmul(a, b)
    # Norms (magnitudes) of a and b
    norm_a = torch.sqrt(torch.sum(a ** 2, dim=-1))
    norm_b = torch.sqrt(torch.sum(b ** 2, dim=-1))
    # Cosine similarity: dot product divided by the product of the norms
    return dot_product / (norm_a * norm_b)

def batch_cosine_similarity(a, b):
    # Dot product of a and b
    dot_product = torch.einsum('bnd,bnd->bn', a, b)
    # Norms (magnitudes) of a and b
    norm_a = torch.sqrt(torch.sum(a ** 2, dim=-1))
    norm_b = torch.sqrt(torch.sum(b ** 2, dim=-1))
    # Cosine similarity: dot product divided by the product of the norms
    return dot_product / (norm_a * norm_b)

class TripletLossBatch(nn.Module):
    def __init__(self):
        super(TripletLossBatch, self).__init__()

    def forward(self, anchor, positive, negative, margin=1.0):
        distance_positive = F.pairwise_distance(anchor, positive, p=2)
        distance_negative = F.pairwise_distance(anchor, negative, p=2)
        losses = torch.relu(distance_positive - distance_negative + margin)
        return losses.mean()

class TripletLossCosineSimilarity(nn.Module):
    def __init__(self):
        super(TripletLossCosineSimilarity, self).__init__()

    def forward(self, anchor, positive, negative, margin=1.0):
        distance_positive = 1 - batch_cosine_similarity(anchor, positive)
        distance_negative = 1 - batch_cosine_similarity(anchor, negative)
        losses = torch.relu(distance_positive - distance_negative + margin)
        return losses.mean()

def imsave(img):
    img = torchvision.utils.make_grid(img)
    img = img / 2 + 0.5
    npimg = img.detach().cpu().numpy()
    # plt.imshow(np.transpose(npimg, (1, 2, 0)))
    # plt.show()
    # save image
    npimg = np.transpose(npimg, (1, 2, 0))
    npimg = npimg * 255
    npimg = npimg.astype(np.uint8)
    Image.fromarray(npimg).save('sample.png')

def norm_img(img):
    return (img - img.min()) / (img.max() - img.min())

def norm_img2(img):
    return (img - img.min()) / (img.max() - img.min()) * 255

class DistanceMapLogger:
    def __call__(self, img, feature_map, save_path, x_coords=None, y_coords=None):
        device = feature_map.device
        batch_size = feature_map.size(0)
        feature_dim = feature_map.size(1)
        image_size = feature_map.size(2)

        if x_coords is None:
            x_coords = [69] * batch_size
        if y_coords is None:
            y_coords = [42] * batch_size

        # Extract a 3-channel map with PCA
        pca = PCA(n_components=3)
        pca_result = pca.fit_transform(feature_map.permute(0, 2, 3, 1).reshape(-1, feature_dim).detach().cpu().numpy())  # run PCA
        reshaped_pca_result = pca_result.reshape(batch_size, image_size, image_size, 3)  # back to (B, H, W, 3)

        sample_num = 0
        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]  # one feature vector per image at (y, x)
        vector = vectors[sample_num]

        # Compute the distance against every feature map in the batch
        # Permute feature_map and flatten height/width so the query vector broadcasts
        reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
        batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
        # batch_distance_map = F.cosine_similarity(reshaped_feature_map, vector.unsqueeze(0).unsqueeze(0).expand(65,size*size,32), dim=2).permute(1, 0).reshape(feature_map.size(0), size, size)
        norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2
        # norm_batch_distance_map[:,0,0] = 0.001

        # Visualize and save
        fig, axes = plt.subplots(5, 4, figsize=(20, 25))
        for ax in axes.flatten():
            ax.axis('off')
        # Remove spacing between subplots
        plt.subplots_adjust(wspace=0, hspace=0)
        # Remove the outer margins as well
        plt.subplots_adjust(left=0, right=1, bottom=0, top=1)

        # Visualize the distance maps
        for i in range(5):
            axes[i, 0].imshow(norm_batch_distance_map[i].detach().cpu(), cmap='hot')
            if i == sample_num:
                axes[i, 0].scatter(x_coords[i], y_coords[i], c='b', s=7)

            distance_map = torch.cat(((norm_batch_distance_map[i] / norm_batch_distance_map[i].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
            alpha = 0.9  # Transparency factor for the heatmap overlay
            blended_tensor = (1 - alpha) * img[i] + alpha * distance_map
            axes[i, 1].imshow(norm_img(blended_tensor.permute(1, 2, 0).detach().cpu()))

            axes[i, 2].imshow(norm_img(img[i].permute(1, 2, 0).detach().cpu()))

            axes[i, 3].imshow(norm_img(reshaped_pca_result[i]))

        plt.savefig(save_path)


    def get_heatmaps(self, img, feature_map, source_num=0, target_num=1, x_coords=69, y_coords=42):
        device = feature_map.device
        batch_size = feature_map.size(0)
        feature_dim = feature_map.size(1)
        image_size = feature_map.size(2)

        x_coords = [x_coords] * batch_size
        y_coords = [y_coords] * batch_size

        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]  # one feature vector per image at (y, x)
        vector = vectors[source_num]

        # Compute the distance against every feature map in the batch
        reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
        batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
        norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2

        source_map = norm_batch_distance_map[source_num]
        target_map = norm_batch_distance_map[target_num]

        alpha = 0.9
        blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num] / norm_batch_distance_map[source_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
        blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num] / norm_batch_distance_map[target_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))

        return source_map, target_map, blended_source, blended_target
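
A minimal usage sketch for the loss helpers above, on random feature vectors (illustrative only, not part of the commit):

# Hypothetical example (not part of this commit): exercise TripletLossBatch
# and pairwise_distance_squared on random 32-dimensional feature vectors.
import torch
from utils import TripletLossBatch, pairwise_distance_squared

anchor = torch.randn(8, 32)                     # 8 feature vectors of dimension 32
positive = anchor + 0.1 * torch.randn(8, 32)    # close to the anchors
negative = torch.randn(8, 32)                   # unrelated vectors

loss = TripletLossBatch()(anchor, positive, negative, margin=1.0)
print(loss.item())                                         # small when positives are closer than negatives
print(pairwise_distance_squared(anchor, positive).shape)   # torch.Size([8])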