yeq6x committed on
Commit 19d010a · 1 Parent(s): 2e99060
Files changed (4)
  1. app.py +128 -86
  2. dataset.py +78 -1
  3. resources/DataList.json +0 -0
  4. utils.py +90 -0
app.py CHANGED
@@ -1,74 +1,92 @@
 import gradio as gr
-import spaces
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 import matplotlib.pyplot as plt
 from model_module import AutoencoderModule
-from dataset import MyDataset, load_filenames
 import numpy as np
 from PIL import Image
 import base64
 from io import BytesIO
 
+import dataset
+from dataset import MyDataset, ImageKeypointDataset, load_filenames, load_keypoints
+import utils
+
+try:
+    import spaces
+except ImportError:
+    print("Spaces is not installed.")
+
+image_size = 112
+batch_size = 32
+
+
 # Load the model and data
-def load_model():
-    model_path = "checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt"
-    feature_dim = 64
+def load_model(model_path="checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt", feature_dim=64):
     model = AutoencoderModule(feature_dim=feature_dim)
     state_dict = torch.load(model_path)
 
-    # # Fix the state_dict keys
-    # new_state_dict = {}
-    # for key in state_dict:
-    #     new_key = "model." + key
-    #     new_state_dict[new_key] = state_dict[key]
-    model.load_state_dict(state_dict['state_dict'])
-    model.eval()
+    if "state_dict" in state_dict:
+        model.load_state_dict(state_dict['state_dict'])
+        model.eval()
+    else:
+        # Fix the state_dict keys
+        new_state_dict = {}
+        for key in state_dict:
+            new_key = "model." + key
+            new_state_dict[new_key] = state_dict[key]
+        model.load_state_dict(new_state_dict)
+        model.eval()
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
     print("Model loaded successfully.")
     return model, device
 
-def load_data(device, img_dir="resources/trainB/", image_size=112, batch_size=32):
+def load_data(device, img_dir="resources/trainB/", image_size=112, batch_size=256):
     filenames = load_filenames(img_dir)
     train_X = filenames[:1000]
+
     train_ds = MyDataset(train_X, img_dir=img_dir, img_size=image_size)
 
-    train_loader = DataLoader(
-        train_ds,
-        batch_size=batch_size,
-        shuffle=True,
-        num_workers=0,
-    )
+    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
 
     iterator = iter(train_loader)
     x, _, _ = next(iterator)
     x = x.to(device)
     x = x[:,0].to(device)
+
     print("Data loaded successfully.")
     return x
 
-model, device = load_model()
-image_size = 112
-batch_size = 32
-x = load_data(device)
-
-# Preprocess the uploaded image
-def preprocess_uploaded_image(uploaded_image, image_size):
-    # Convert an ndarray to a PIL image
-    if type(uploaded_image) == np.ndarray:
-        uploaded_image = Image.fromarray(uploaded_image)
-    uploaded_image = uploaded_image.convert("RGB")
-    uploaded_image = uploaded_image.resize((image_size, image_size))
-    uploaded_image = np.array(uploaded_image).transpose(2, 0, 1) / 255.0
-    uploaded_image = torch.tensor(uploaded_image, dtype=torch.float32).unsqueeze(0).to(device)
-    return uploaded_image
+def load_keypoints(device, img_dir="resources/trainB/", image_size=112, batch_size=32):
+    filenames = load_filenames(img_dir)
+    train_X = filenames[:1000]
+    keypoints = dataset.load_keypoints('resources/DataList.json')
+
+    image_points_ds = ImageKeypointDataset(train_X, keypoints, img_dir='resources/trainB/', img_size=image_size)
+
+    image_points_loader = DataLoader(image_points_ds, batch_size=batch_size, shuffle=False)
+
+    iterator = iter(image_points_loader)
+    test_imgs, points = next(iterator)
+    test_imgs = test_imgs.to(device)
+    points = points.to(device)*(image_size)
+
+    print("Keypoints loaded successfully.")
+    return test_imgs, points
 
 # Heatmap generation function
-@spaces.GPU
-def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
+try:
+    @spaces.GPU
+    def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
+        return _get_heatmaps(source_num, x_coords, y_coords, uploaded_image)
+except:
+    def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
+        return _get_heatmaps(source_num, x_coords, y_coords, uploaded_image)
+
+def _get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
     if type(uploaded_image) == str:
         uploaded_image = Image.open(uploaded_image)
     if type(source_num) == str:
@@ -77,60 +95,68 @@ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
     x_coords = int(x_coords)
     if type(y_coords) == str:
         y_coords = int(y_coords)
-
-    with torch.no_grad():
-        dec5, _ = model(x)
-    img = x
-    feature_map = dec5
-    batch_size = feature_map.size(0)
-    feature_dim = feature_map.size(1)
-
-    # Preprocess the uploaded image
-    if uploaded_image is not None:
-        uploaded_image = preprocess_uploaded_image(uploaded_image['composite'], image_size)
-        target_feature_map, _ = model(uploaded_image)
-        img = torch.cat((img, uploaded_image))
-        feature_map = torch.cat((feature_map, target_feature_map))
-        batch_size += 1
-    else:
-        uploaded_image = torch.zeros(1, 3, image_size, image_size, device=device)
-
-    target_num = batch_size - 1
 
-    x_coords = [x_coords] * batch_size
-    y_coords = [y_coords] * batch_size
-
-    vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]
-    vector = vectors[source_num]
-
-    reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
-    batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
-
-    norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2
-
-    source_map = norm_batch_distance_map[source_num]
-    target_map = norm_batch_distance_map[target_num]
-
-    alpha = 0.7
-    blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num] / norm_batch_distance_map[source_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
-    blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num] / norm_batch_distance_map[target_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
-
-    # Plot with Matplotlib and save as an image
-    fig, axs = plt.subplots(2, 2, figsize=(10, 10))
-    axs[0, 0].imshow(source_map.cpu(), cmap='hot')
-    axs[0, 0].set_title("Source Map")
-    axs[0, 1].imshow(target_map.cpu(), cmap='hot')
-    axs[0, 1].set_title("Target Map")
-    axs[1, 0].imshow(blended_source.permute(1, 2, 0).cpu())
-    axs[1, 0].set_title("Blended Source")
-    axs[1, 1].imshow(blended_target.permute(1, 2, 0).cpu())
-    axs[1, 1].set_title("Blended Target")
-    for ax in axs.flat:
-        ax.axis('off')
-
-    plt.tight_layout()
-    plt.close(fig)
-    return fig
+    dec5, _ = model(x)
+    feature_map = dec5
+    # Preprocess the uploaded image
+    if uploaded_image is not None:
+        uploaded_image = utils.preprocess_uploaded_image(uploaded_image['composite'], image_size)
+    else:
+        uploaded_image = torch.zeros(1, 3, image_size, image_size, device=device)
+    target_feature_map, _ = model(uploaded_image)
+    img = torch.cat((x, uploaded_image))
+    feature_map = torch.cat((feature_map, target_feature_map))
+
+    source_map, target_map, blended_source, blended_target = utils.get_heatmaps(img, feature_map, source_num, x_coords, y_coords, uploaded_image)
+    keypoint_maps, blended_tensors = utils.get_keypoint_heatmaps(target_feature_map, mean_vector_list, points.size(1), uploaded_image)
+
+    # Plot with Matplotlib and return the figure
+    fig, axs = plt.subplots(2, 3, figsize=(10, 6))
+    axs[0, 0].imshow(source_map, cmap='hot')
+    axs[0, 0].set_title("Source Map")
+    axs[0, 1].imshow(target_map, cmap='hot')
+    axs[0, 1].set_title("Target Map")
+    axs[0, 2].imshow(keypoint_maps[0], cmap='hot')
+    axs[0, 2].set_title("Keypoint Map")
+    axs[1, 0].imshow(blended_source.permute(1, 2, 0))
+    axs[1, 0].set_title("Blended Source")
+    axs[1, 1].imshow(blended_target.permute(1, 2, 0))
+    axs[1, 1].set_title("Blended Target")
+    axs[1, 2].imshow(blended_tensors[0].permute(1, 2, 0))
+    axs[1, 2].set_title("Blended Keypoint")
+    for ax in axs.flat:
+        ax.axis('off')
+
+    plt.tight_layout()
+    plt.close(fig)
+    return fig
+
+def setup(model_dict, input_image=None):
+    global model, device, x, test_imgs, points, mean_vector_list
+    # Convert str -> dict
+    if type(model_dict) == str:
+        model_dict = eval(model_dict)
+    model_name = model_dict["name"]
+    feature_dim = model_dict["feature_dim"]
+    model_path = f"checkpoints/{model_name}"
+    model, device = load_model(model_path, feature_dim)
+    x = load_data(device)
+    test_imgs, points = load_keypoints(device)
+    feature_map, _ = model(test_imgs)
+    mean_vector_list = utils.get_mean_vector(feature_map, points)
+
+    if input_image is not None:
+        fig = get_heatmaps(0, image_size // 2, image_size // 2, input_image)
+        return fig
+
+models = [{"name": "ae_model_tf_2024-03-05_00-35-21.pth", "feature_dim": 32},
+          {"name": "autoencoder-epoch=09-train_loss=1.00.ckpt", "feature_dim": 64},
+          {"name": "autoencoder-epoch=29-train_loss=1.01.ckpt", "feature_dim": 64},
+          {"name": "autoencoder-epoch=49-train_loss=1.01.ckpt", "feature_dim": 64}]
+
+setup(models[0])
 
 with gr.Blocks() as demo:
     # title
@@ -142,9 +168,24 @@ with gr.Blocks() as demo:
     "The blended source and target images show the source and target images with the source and target maps overlaid, respectively. "
 
     "For further information, please contact me on X (formerly Twitter): @Yeq6X.")
+
+    gr.Markdown("## Heatmap Visualization")
 
     input_image = gr.ImageEditor(label="Cropped Image", elem_id="input_image", crop_size=(112, 112), show_fullscreen_button=True)
-    gr.Interface(
+    output_plot = gr.Plot(value=None, elem_id="output_plot", show_label=False)
+    with gr.Row():
+        with gr.Column():
+            with gr.Row():
+                model_name = gr.Dropdown(
+                    choices=[str(model) for model in models],
+                    container=False
+                )
+                load_button = gr.Button("Load Model")
+                load_button.click(setup, inputs=[model_name, input_image], outputs=[output_plot])
+        with gr.Row():
+            pass
+
+    inference = gr.Interface(
         get_heatmaps,
         inputs=[
            gr.Slider(0, batch_size - 1, step=1, label="Source Image Index"),
@@ -152,8 +193,9 @@ with gr.Blocks() as demo:
            gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="Y Coordinate"),
            input_image
        ],
-        outputs="plot",
+        outputs=output_plot,
         live=True,
+        flagging_mode="never"
     )
     # examples
     gr.Markdown("# Examples")
dataset.py CHANGED
@@ -4,15 +4,48 @@ from torchvision import transforms
 import random
 from PIL import Image
 import os
+import pandas as pd
 
 from utils import RandomAffineAndRetMat
 
 def load_filenames(data_dir):
+    # Image file extensions only
     img_exts = ['.jpg', '.jpeg', '.png', '.bmp', '.ppm', '.pgm', '.tif', '.tiff']
     filenames = [f for f in os.listdir(data_dir) if os.path.splitext(f)[1].lower() in img_exts]
 
     return filenames
 
+def load_keypoints(label_path):
+    label_data = pd.read_json(label_path)
+    label_data = label_data.sort_index()
+    tmp_points = []
+
+    for o in label_data.data[0:1000]:
+        tmps = []
+        for i in range(60):
+            tmps.append(o['points'][str(i)]['x'])
+            tmps.append(o['points'][str(i)]['y'])
+        tmp_points.append(tmps)  # datanum
+
+    df_points = pd.DataFrame(tmp_points)
+    df_points = df_points.iloc[:,[
+        *list(range(0,16*2+1,4)), *list(range(1,16*2+2,4)),
+        *list(range(27*2,36*2+1,4)), *list(range(27*2+1,36*2+2,4)),
+        *list(range(37*2,46*2+1,4)), *list(range(37*2+1,46*2+2,4)),
+        # 49*2, 49*2+1,
+        # *list(range(50*2,55*2+1,4)), *list(range(50*2+1,55*2+2,4)),
+        28*2, 28*2+1,
+        30*2, 30*2+1,
+        34*2, 34*2+1,
+        38*2, 38*2+1,
+        40*2, 40*2+1,
+        44*2, 44*2+1,
+    ]]
+    df_points = df_points.sort_index(axis=1)
+    df_points.columns = list(range(len(df_points.columns)))
+    # df_points[0:500].iloc[0]
+
+    return df_points
 
 class MyDataset:
     def __init__(self, X, valid=False, img_dir='resources/trainB/', img_size=256):
@@ -64,4 +97,48 @@ class MyDataset:
 
         X = torch.stack(xlist)
         mat = torch.stack(matlist)
-        return X, mat, f
+        return X, mat, f
+
+class ImageKeypointDataset:
+    def __init__(self, X, y, valid=False, img_dir='resources/trainB/', img_size=256):
+        self.X = X
+        self.y = y
+        self.valid = valid
+        self.img_dir = img_dir
+        self.img_size = img_size
+        # if not valid:
+        trans = [
+            transforms.Resize(self.img_size),
+            transforms.ToTensor(),
+            # transforms.Normalize(mean=means, std=stds),
+            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+            # transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
+        ]
+        self.trans = transforms.Compose(trans)
+
+    def __len__(self):
+        return len(self.X)
+
+    def __getitem__(self, index):
+        if type(index) is slice:
+            if index.step is None:
+                return (torch.stack([self.get_one_X(i) for i in range(index.start, index.stop)]),
+                        torch.stack([self.get_one_y(i) for i in range(index.start, index.stop)]))
+            else:
+                return (torch.stack([self.get_one_X(i) for i in range(index.start, index.stop, index.step)]),
+                        torch.stack([self.get_one_y(i) for i in range(index.start, index.stop, index.step)]))
+        if type(index) is int:
+            return self.get_one_X(index), self.get_one_y(index)
+
+    def get_one_X(self, index):
+        f = self.img_dir + self.X[index]
+        X = Image.open(f)
+        X = self.trans(X)
+        return X
+
+    def get_one_y(self, index):
+        y = self.y.iloc[index].copy()
+        y = torch.tensor(y)
+        y = y.float()
+        y = y.reshape(25,2)
+        return y
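
Note on the new keypoint pipeline: load_keypoints() flattens each annotation into [x0, y0, x1, y1, ...] and the long iloc column list keeps 50 of the 120 columns, i.e. 25 of the 60 annotated points, which is why ImageKeypointDataset.get_one_y() can reshape each row to (25, 2). A minimal sketch of that column arithmetic (the 3-point frame is fabricated for illustration):

import pandas as pd

# One fake annotation row with 3 points: layout is [x0, y0, x1, y1, x2, y2].
df = pd.DataFrame([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]])

# Point i occupies columns 2*i (x) and 2*i+1 (y); load_keypoints() builds
# its iloc index list out of exactly such pairs.
i = 1
print(df.iloc[:, [2 * i, 2 * i + 1]].values)  # [[0.3 0.4]]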
resources/DataList.json ADDED
The diff for this file is too large to render. See raw diff
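
The file itself is not rendered, but the accessors in dataset.load_keypoints() (label_data.data, o['points'][str(i)]['x']) imply roughly the following per-record shape. The sketch below is an inference from the code with made-up coordinate values, not an excerpt:

# Assumed structure of one element of the top-level "data" list in
# resources/DataList.json; 60 points per record, values illustrative only.
record = {
    "points": {
        "0": {"x": 0.31, "y": 0.42},
        # ... keys "1" through "58" ...
        "59": {"x": 0.58, "y": 0.73},
    }
}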
 
utils.py CHANGED
@@ -77,3 +77,93 @@ class RandomAffineAndRetMat(torch.nn.Module):
         affine_matrix = translation_matrix.mm(rotation_matrix).mm(scaling_matrix).mm(shearing_matrix)
 
         return affine_matrix
+
+def norm_img(img):
+    return (img-img.min())/(img.max()-img.min())
+
+def preprocess_uploaded_image(uploaded_image, image_size):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # Convert an ndarray to a PIL image
+    if type(uploaded_image) == np.ndarray:
+        uploaded_image = Image.fromarray(uploaded_image)
+    uploaded_image = uploaded_image.convert("RGB")
+    uploaded_image = uploaded_image.resize((image_size, image_size))
+    uploaded_image = np.array(uploaded_image).transpose(2, 0, 1) / 255.0
+    uploaded_image = torch.tensor(uploaded_image, dtype=torch.float32).unsqueeze(0).to(device)
+    return uploaded_image
+
+def get_heatmaps(img, feature_map, source_num, x_coords, y_coords, uploaded_image):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    image_size = img.size(2)
+
+    batch_size = feature_map.size(0)
+    feature_dim = feature_map.size(1)
+
+    target_num = batch_size - 1
+
+    x_coords = [x_coords] * batch_size
+    y_coords = [y_coords] * batch_size
+
+    vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]
+    vector = vectors[source_num]
+
+    reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
+    batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), image_size, image_size)
+
+    norm_batch_distance_map = 1 / torch.cosh(20 * (batch_distance_map - batch_distance_map.min()) / (batch_distance_map.max() - batch_distance_map.min())) ** 2
+
+    source_map = norm_batch_distance_map[source_num].detach().cpu()
+    target_map = norm_batch_distance_map[target_num].detach().cpu()
+
+    alpha = 0.7
+    blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num] / norm_batch_distance_map[source_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
+    blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num] / norm_batch_distance_map[target_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
+
+    blended_source = blended_source.detach().cpu()
+    blended_target = blended_target.detach().cpu()
+
+    return source_map, target_map, blended_source, blended_target
+
+def get_mean_vector(feature_map, points):
+    keypoints_size = points.size(1)
+
+    mean_vector_list = []
+    for i in range(keypoints_size):
+        x_coords, y_coords = torch.round(points[:,i].t()).to(torch.long)
+        vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords]  # index one feature vector per image
+        # mean_vector = vectors[0:10].mean(0)  # mean vector over 10 feature maps
+        mean_vector = vectors.mean(0)
+        mean_vector_list.append(mean_vector)
+    return mean_vector_list
+
+def get_keypoint_heatmaps(feature_map, mean_vector_list, keypoints_size, imgs):
+    if len(feature_map.size()) == 3:
+        feature_map = feature_map.unsqueeze(0)
+    device = feature_map.device
+    batch_size = feature_map.size(0)
+    feature_dim = feature_map.size(1)
+    size = feature_map.size(2)
+
+    norm_batch_distance_map = torch.zeros(batch_size,size,size,device=device)
+    for i in range(keypoints_size):
+        vector = mean_vector_list[i]
+        reshaped_feature_map = feature_map.permute(0, 2, 3, 1).view(feature_map.size(0), -1, feature_dim)
+
+        batch_distance_map = F.pairwise_distance(reshaped_feature_map, vector).view(feature_map.size(0), size, size)
+        batch_distance_map = 1/torch.cosh( 40*(batch_distance_map-batch_distance_map.min())
+                                          /(batch_distance_map.max()-batch_distance_map.min()) )**2
+        # Normalize
+        m = batch_distance_map/batch_distance_map.max(1).values.max(1).values.unsqueeze(0).unsqueeze(0).repeat(112,112,1).permute(2,0,1)
+        norm_batch_distance_map += m
+    # Clip values above 1
+    norm_batch_distance_map = (-F.relu(-norm_batch_distance_map+1)+1)
+    keypoint_maps = norm_batch_distance_map.detach().cpu()
+
+    alpha = 0.8  # Transparency factor for the heatmap overlay
+    blended_tensors = (1 - alpha) * imgs + alpha * torch.cat(
+        (norm_batch_distance_map.unsqueeze(1), torch.zeros(batch_size,2,size,size,device=device)),
+        dim=1
+    )
+    blended_tensors = norm_img(blended_tensors).detach().cpu()
+
+    return keypoint_maps, blended_tensors
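
The two heatmap helpers share one normalization trick: Euclidean distances between feature vectors are min-max scaled to [0, 1] and passed through 1/cosh(k*d)^2 = sech^2(k*d), a bump that equals 1 exactly where the distance is minimal and decays steeply toward 0 (k = 20 in get_heatmaps, 40 in get_keypoint_heatmaps). A standalone sketch of the mapping:

import torch

# Min-max-normalized distances from a query feature vector.
d = torch.tensor([0.0, 0.05, 0.1, 0.5, 1.0])

# sech^2 bump: 1 at zero distance, effectively 0 well before d reaches 1.
heat = 1 / torch.cosh(20 * d) ** 2
print(heat)  # approx. [1.0, 0.42, 0.071, 8.2e-09, 1.7e-17]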