Space: keypoints

Changed files:
- app.py +128 -86
- dataset.py +78 -1
- resources/DataList.json +0 -0
- utils.py +90 -0
app.py
CHANGED
@@ -1,74 +1,92 @@
 import gradio as gr
-import spaces
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 import matplotlib.pyplot as plt
 from model_module import AutoencoderModule
-from dataset import MyDataset, load_filenames
 import numpy as np
 from PIL import Image
 import base64
 from io import BytesIO
 
+import dataset
+from dataset import MyDataset, ImageKeypointDataset, load_filenames, load_keypoints
+import utils
+
+try:
+    import spaces
+except ImportError:
+    print("Spaces is not installed.")
+
+image_size = 112
+batch_size = 32
+
+
 # Load the model and data
-def load_model():
-    model_path = "checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt"
-    feature_dim = 64
+def load_model(model_path="checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt", feature_dim=64):
     model = AutoencoderModule(feature_dim=feature_dim)
     state_dict = torch.load(model_path)
 
+    if "state_dict" in state_dict:
+        model.load_state_dict(state_dict['state_dict'])
+        model.eval()
+    else:
+        # Fix the state_dict keys
+        new_state_dict = {}
+        for key in state_dict:
+            new_key = "model." + key
+            new_state_dict[new_key] = state_dict[key]
+        model.load_state_dict(new_state_dict)
+        model.eval()
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
     print("Model loaded successfully.")
     return model, device
 
-def load_data(device, img_dir="resources/trainB/", image_size=112, batch_size=
+def load_data(device, img_dir="resources/trainB/", image_size=112, batch_size=256):
     filenames = load_filenames(img_dir)
     train_X = filenames[:1000]
+
     train_ds = MyDataset(train_X, img_dir=img_dir, img_size=image_size)
 
-    train_loader = DataLoader(
-        train_ds,
-        batch_size=batch_size,
-        shuffle=True,
-        num_workers=0,
-    )
+    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
 
     iterator = iter(train_loader)
     x, _, _ = next(iterator)
     x = x.to(device)
     x = x[:,0].to(device)
+
     print("Data loaded successfully.")
     return x
 
+def load_keypoints(device, img_dir="resources/trainB/", image_size=112, batch_size=32):
+    filenames = load_filenames(img_dir)
+    train_X = filenames[:1000]
+    keypoints = dataset.load_keypoints('resources/DataList.json')
+
+    image_points_ds = ImageKeypointDataset(train_X, keypoints, img_dir='resources/trainB/', img_size=image_size)
+
+    image_points_loader = DataLoader(image_points_ds, batch_size=batch_size, shuffle=False)
+
+    iterator = iter(image_points_loader)
+    test_imgs, points = next(iterator)
+    test_imgs = test_imgs.to(device)
+    points = points.to(device)*(image_size)
+
+    print("Keypoints loaded successfully.")
+    return test_imgs, points
 
 # Heatmap generation function
+try:
+    @spaces.GPU
+    def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
+        return _get_heatmaps(source_num, x_coords, y_coords, uploaded_image)
+except:
+    def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
+        return _get_heatmaps(source_num, x_coords, y_coords, uploaded_image)
+
+def _get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
     if type(uploaded_image) == str:
         uploaded_image = Image.open(uploaded_image)
     if type(source_num) == str:
@@ -77,60 +95,68 @@ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
     x_coords = int(x_coords)
     if type(y_coords) == str:
         y_coords = int(y_coords)
-
-    with torch.no_grad():
-        dec5, _ = model(x)
-        img = x
-        feature_map = dec5
-        batch_size = feature_map.size(0)
-        feature_dim = feature_map.size(1)
-
-        # Preprocess the uploaded image
-        if uploaded_image is not None:
-            uploaded_image = preprocess_uploaded_image(uploaded_image['composite'], image_size)
-            target_feature_map, _ = model(uploaded_image)
-            img = torch.cat((img, uploaded_image))
-            feature_map = torch.cat((feature_map, target_feature_map))
-            batch_size += 1
-        else:
-            uploaded_image = torch.zeros(1, 3, image_size, image_size, device=device)
-
-        target_num = batch_size - 1
-
-        norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2
-
-        alpha = 0.7
-        blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num] / norm_batch_distance_map[source_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
-        blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num] / norm_batch_distance_map[target_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
-
-        # Plot with Matplotlib and save as an image
-        fig, axs = plt.subplots(2, 2, figsize=(10, 10))
-        axs[0, 0].imshow(source_map.cpu(), cmap='hot')
-        axs[0, 0].set_title("Source Map")
-        axs[0, 1].imshow(target_map.cpu(), cmap='hot')
-        axs[0, 1].set_title("Target Map")
-        axs[1, 0].imshow(blended_source.permute(1, 2, 0).cpu())
-        axs[1, 0].set_title("Blended Source")
-        axs[1, 1].imshow(blended_target.permute(1, 2, 0).cpu())
-        axs[1, 1].set_title("Blended Target")
-        for ax in axs.flat:
-            ax.axis('off')
-
-        plt.tight_layout()
-        plt.close(fig)
-        return fig
+
+    dec5, _ = model(x)
+    feature_map = dec5
+    # Preprocess the uploaded image
+    if uploaded_image is not None:
+        uploaded_image = utils.preprocess_uploaded_image(uploaded_image['composite'], image_size)
+    else:
+        uploaded_image = torch.zeros(1, 3, image_size, image_size, device=device)
+    target_feature_map, _ = model(uploaded_image)
+    img = torch.cat((x, uploaded_image))
+    feature_map = torch.cat((feature_map, target_feature_map))
+
+    source_map, target_map, blended_source, blended_target = utils.get_heatmaps(img, feature_map, source_num, x_coords, y_coords, uploaded_image)
+    keypoint_maps, blended_tensors = utils.get_keypoint_heatmaps(target_feature_map, mean_vector_list, points.size(1), uploaded_image)
 
+    # Plot with Matplotlib and save as an image
+    fig, axs = plt.subplots(2, 3, figsize=(10, 6))
+    axs[0, 0].imshow(source_map, cmap='hot')
+    axs[0, 0].set_title("Source Map")
+    axs[0, 1].imshow(target_map, cmap='hot')
+    axs[0, 1].set_title("Target Map")
+    axs[0, 2].imshow(keypoint_maps[0], cmap='hot')
+    axs[0, 2].set_title("Keypoint Map")
+    axs[1, 0].imshow(blended_source.permute(1, 2, 0))
+    axs[1, 0].set_title("Blended Source")
+    axs[1, 1].imshow(blended_target.permute(1, 2, 0))
+    axs[1, 1].set_title("Blended Target")
+    axs[1, 2].imshow(blended_tensors[0].permute(1, 2, 0))
+    axs[1, 2].set_title("Blended Keypoint")
+    for ax in axs.flat:
+        ax.axis('off')
+
+    plt.tight_layout()
+    plt.close(fig)
+    return fig
 
+def setup(model_dict, input_image=None):
+    global model, device, x, test_imgs, points, mean_vector_list
+    # Convert str -> dict
+    if type(model_dict) == str:
+        model_dict = eval(model_dict)
+    model_name = model_dict["name"]
+    feature_dim = model_dict["feature_dim"]
+    model_path = f"checkpoints/{model_name}"
+    model, device = load_model(model_path, feature_dim)
+    x = load_data(device)
+    test_imgs, points = load_keypoints(device)
+    feature_map, _ = model(test_imgs)
+    mean_vector_list = utils.get_mean_vector(feature_map, points)
+
+    if input_image is not None:
+        fig = get_heatmaps(0, image_size // 2, image_size // 2, input_image)
+        return fig
 
+models = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
+          {"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
+          {"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
+          {"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]
+
+setup(models[0])
 
 with gr.Blocks() as demo:
     # title
@@ -142,9 +168,24 @@ with gr.Blocks() as demo:
                 "The blended source and target images show the source and target images with the source and target maps overlaid, respectively. "
 
                 "For further information, please contact me on X (formerly Twitter): @Yeq6X.")
+
+    gr.Markdown("## Heatmap Visualization")
 
     input_image = gr.ImageEditor(label="Cropped Image", elem_id="input_image", crop_size=(112, 112), show_fullscreen_button=True)
-    gr.
+    output_plot = gr.Plot(value=None, elem_id="output_plot", show_label=False)
+    with gr.Row():
+        with gr.Column():
+            with gr.Row():
+                model_name = gr.Dropdown(
+                    choices=[str(model) for model in models],
+                    container=False
+                )
+                load_button = gr.Button("Load Model")
+                load_button.click(setup, inputs=[model_name, input_image], outputs=[output_plot])
+        with gr.Row():
+            pass
+
+    inference = gr.Interface(
         get_heatmaps,
         inputs=[
             gr.Slider(0, batch_size - 1, step=1, label="Source Image Index"),
@@ -152,8 +193,9 @@ with gr.Blocks() as demo:
             gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="Y Coordinate"),
             input_image
         ],
-        outputs=
+        outputs=output_plot,
         live=True,
+        flagging_mode="never"
    )
    # examples
    gr.Markdown("# Examples")
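A minimal sketch (not part of this commit, using illustrative stand-in classes) of why the new load_model() remaps checkpoint keys: a wrapper module that holds the raw autoencoder as self.model stores its parameters under "model.<name>", so a checkpoint saved from the bare network only loads after its keys are prefixed accordingly.

import torch
import torch.nn as nn

class Bare(nn.Module):            # stand-in for the raw autoencoder network
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

class Wrapper(nn.Module):         # stand-in for AutoencoderModule, which wraps it as self.model
    def __init__(self):
        super().__init__()
        self.model = Bare()

raw_ckpt = Bare().state_dict()                              # keys like "fc.weight"
remapped = {"model." + k: v for k, v in raw_ckpt.items()}   # keys like "model.fc.weight"
Wrapper().load_state_dict(remapped)                         # now accepted by the wrapper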
dataset.py
CHANGED
@@ -4,15 +4,48 @@ from torchvision import transforms
 import random
 from PIL import Image
 import os
+import pandas as pd
 
 from utils import RandomAffineAndRetMat
 
 def load_filenames(data_dir):
+    # image extensions only
     img_exts = ['.jpg', '.jpeg', '.png', '.bmp', '.ppm', '.pgm', '.tif', '.tiff']
     filenames = [f for f in os.listdir(data_dir) if os.path.splitext(f)[1].lower() in img_exts]
 
     return filenames
 
+def load_keypoints(label_path):
+    label_data = pd.read_json(label_path)
+    label_data = label_data.sort_index()
+    tmp_points = []
+
+    for o in label_data.data[0:1000]:
+        tmps = []
+        for i in range(60):
+            tmps.append(o['points'][str(i)]['x'])
+            tmps.append(o['points'][str(i)]['y'])
+        tmp_points.append(tmps) # datanum
+
+    df_points = pd.DataFrame(tmp_points)
+    df_points = df_points.iloc[:,[
+        *list(range(0,16*2+1,4)), *list(range(1,16*2+2,4)),
+        *list(range(27*2,36*2+1,4)), *list(range(27*2+1,36*2+2,4)),
+        *list(range(37*2,46*2+1,4)), *list(range(37*2+1,46*2+2,4)),
+        # 49*2, 49*2+1,
+        # *list(range(50*2,55*2+1,4)), *list(range(50*2+1,55*2+2,4)),
+        28*2, 28*2+1,
+        30*2, 30*2+1,
+        34*2, 34*2+1,
+        38*2, 38*2+1,
+        40*2, 40*2+1,
+        44*2, 44*2+1,
+    ]]
+    df_points = df_points.sort_index(axis=1)
+    df_points.columns = list(range(len(df_points.columns)))
+    # df_points[0:500].iloc[0]
+
+    return df_points
 
 class MyDataset:
     def __init__(self, X, valid=False, img_dir='resources/trainB/', img_size=256):
@@ -64,4 +97,48 @@ class MyDataset:
 
         X = torch.stack(xlist)
         mat = torch.stack(matlist)
-        return X, mat, f
+        return X, mat, f
+
+class ImageKeypointDataset:
+    def __init__(self, X, y, valid=False, img_dir='resources/trainB/', img_size=256):
+        self.X = X
+        self.y = y
+        self.valid = valid
+        self.img_dir = img_dir
+        self.img_size = img_size
+        # if not valid:
+        trans = [
+            transforms.Resize(self.img_size),
+            transforms.ToTensor(),
+            # transforms.Normalize(mean=means, std=stds),
+            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+            # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
+        ]
+        self.trans = transforms.Compose(trans)
+
+    def __len__(self):
+        return len(self.X)
+
+    def __getitem__(self, index):
+        if type(index) is slice:
+            if index.step is None:
+                return (torch.stack([self.get_one_X(i) for i in range(index.start, index.stop)]),
+                        torch.stack([self.get_one_y(i) for i in range(index.start, index.stop)]))
+            else:
+                return (torch.stack([self.get_one_X(i) for i in range(index.start, index.stop, index.step)]),
+                        torch.stack([self.get_one_y(i) for i in range(index.start, index.stop, index.step)]))
+        if type(index) is int:
+            return self.get_one_X(index), self.get_one_y(index)
+
+    def get_one_X(self, index):
+        f = self.img_dir + self.X[index]
+        X = Image.open(f)
+        X = self.trans(X)
+        return X
+
+    def get_one_y(self, index):
+        y = self.y.iloc[index].copy()
+        y = torch.tensor(y)
+        y = y.float()
+        y = y.reshape(25,2)
+        return y
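A rough usage sketch of the new dataset pieces (not part of the commit; it assumes resources/DataList.json and the resources/trainB/ images are available locally and mirrors load_keypoints() in app.py). load_keypoints() keeps 50 of the 120 annotated coordinate columns, i.e. 25 (x, y) pairs, which get_one_y() reshapes to (25, 2):

from torch.utils.data import DataLoader
from dataset import ImageKeypointDataset, load_filenames, load_keypoints

files = load_filenames("resources/trainB/")[:1000]
points_df = load_keypoints("resources/DataList.json")    # DataFrame with 50 columns = 25 (x, y) pairs
ds = ImageKeypointDataset(files, points_df, img_dir="resources/trainB/", img_size=112)
imgs, pts = next(iter(DataLoader(ds, batch_size=32, shuffle=False)))
print(imgs.shape, pts.shape)    # expected [32, 3, 112, 112] and [32, 25, 2] for square source images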
resources/DataList.json
ADDED
The diff for this file is too large to render.
See raw diff
utils.py
CHANGED
@@ -77,3 +77,93 @@ class RandomAffineAndRetMat(torch.nn.Module):
         affine_matrix = translation_matrix.mm(rotation_matrix).mm(scaling_matrix).mm(shearing_matrix)
 
         return affine_matrix
+
+def norm_img(img):
+    return (img-img.min())/(img.max()-img.min())
+
+def preprocess_uploaded_image(uploaded_image, image_size):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # Convert an ndarray to a PIL image
+    if type(uploaded_image) == np.ndarray:
+        uploaded_image = Image.fromarray(uploaded_image)
+    uploaded_image = uploaded_image.convert("RGB")
+    uploaded_image = uploaded_image.resize((image_size, image_size))
+    uploaded_image = np.array(uploaded_image).transpose(2, 0, 1) / 255.0
+    uploaded_image = torch.tensor(uploaded_image, dtype=torch.float32).unsqueeze(0).to(device)
+    return uploaded_image
+
+def get_heatmaps(img, feature_map, source_num, x_coords, y_coords, uploaded_image):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    image_size = img.size(2)
+
+    batch_size = feature_map.size(0)
+    feature_dim = feature_map.size(1)
+
+    target_num = batch_size - 1
+
+    x_coords = [x_coords] * batch_size
+    y_coords = [y_coords] * batch_size
+
+    vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]
+    vector = vectors[source_num]
+
+    reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
+    batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
+
+    norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2
+
+    source_map = norm_batch_distance_map[source_num].detach().cpu()
+    target_map = norm_batch_distance_map[target_num].detach().cpu()
+
+    alpha = 0.7
+    blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num] / norm_batch_distance_map[source_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
+    blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num] / norm_batch_distance_map[target_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
+
+    blended_source = blended_source.detach().cpu()
+    blended_target = blended_target.detach().cpu()
+
+    return source_map, target_map, blended_source, blended_target
+
+def get_mean_vector(feature_map, points):
+    keypoints_size = points.size(1)
+
+    mean_vector_list = []
+    for i in range(keypoints_size):
+        x_coords, y_coords = torch.round(points[:,i].t()).to(torch.long)
+        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords] # adjust the size to match the 1-D vectors
+        # mean_vector = vectors[0:10].mean(0) # take the mean vector over 10 feature maps
+        mean_vector = vectors.mean(0)
+        mean_vector_list.append(mean_vector)
+    return mean_vector_list
+
+def get_keypoint_heatmaps(feature_map, mean_vector_list, keypoints_size, imgs):
+    if len(feature_map.size()) == 3:
+        feature_map = feature_map.unsqueeze(0)
+    device = feature_map.device
+    batch_size = feature_map.size(0)
+    feature_dim = feature_map.size(1)
+    size = feature_map.size(2)
+
+    norm_batch_distance_map = torch.zeros(batch_size,size,size,device=device)
+    for i in range(keypoints_size):
+        vector = mean_vector_list[i]
+        reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
+
+        batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), size, size)
+        batch_distance_map = 1/torch.cosh( 40*(batch_distance_map-batch_distance_map.min())
+                                          /(batch_distance_map.max()-batch_distance_map.min()) )**2
+        # Normalise
+        m = batch_distance_map/batch_distance_map.max(1).values.max(1).values.unsqueeze(0).unsqueeze(0).repeat(112,112,1).permute(2,0,1)
+        norm_batch_distance_map += m
+    # Clip values above 1
+    norm_batch_distance_map = (-F.relu(-norm_batch_distance_map+1)+1)
+    keypoint_maps = norm_batch_distance_map.detach().cpu()
+
+    alpha = 0.8 # Transparency factor for the heatmap overlay
+    blended_tensors = (1 - alpha) * imgs + alpha * torch.cat(
+        (norm_batch_distance_map.unsqueeze(1), torch.zeros(batch_size,2,size,size,device=device)),
+        dim=1
+    )
+    blended_tensors = norm_img(blended_tensors).detach().cpu()
+
+    return keypoint_maps, blended_tensors
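A small standalone check (not from the commit) of the response curve shared by get_heatmaps() and get_keypoint_heatmaps(): pixel-wise feature distances are min-max normalised and passed through 1/cosh(k*d)**2, a sech^2 bump that equals 1 where the distance is smallest and decays almost to zero elsewhere, so only the best-matching locations light up in the heatmaps.

import torch

d = torch.tensor([0.0, 0.1, 0.5, 1.0])   # min-max normalised distances
print(1 / torch.cosh(20 * d) ** 2)        # roughly [1.0, 0.07, 8e-9, 2e-17]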