yeq6x commited on
Commit
c6793c3
·
1 Parent(s): cc0703f
app.py CHANGED
@@ -8,6 +8,7 @@ import numpy as np
8
  from PIL import Image
9
  import base64
10
  from io import BytesIO
 
11
 
12
  import dataset
13
  from dataset import MyDataset, ImageKeypointDataset, load_filenames, load_keypoints
@@ -92,8 +93,7 @@ model_index = 0
92
 
93
  # ヒートマップの生成関数
94
  @spaces.GPU
95
- def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
96
- global model_index, mean_vector_list
97
  if type(uploaded_image) == str:
98
  uploaded_image = Image.open(uploaded_image)
99
  if type(source_num) == str:
@@ -102,6 +102,13 @@ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
102
  x_coords = int(x_coords)
103
  if type(y_coords) == str:
104
  y_coords = int(y_coords)
 
 
 
 
 
 
 
105
 
106
  dec5, _ = models[model_index](x)
107
  feature_map = dec5
@@ -138,24 +145,6 @@ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
138
  plt.close(fig)
139
  return fig
140
 
141
- @spaces.GPU
142
- def setup(model_info, input_image=None):
143
- global model_index, mean_vector_list
144
- # str -> dictに変換
145
- if type(model_info) == str:
146
- model_info = eval(model_info)
147
-
148
- model_index = models_info.index(model_info)
149
-
150
- feature_map, _ = models[model_index](test_imgs)
151
- mean_vector_list = utils.get_mean_vector(feature_map, points)
152
-
153
- if input_image is not None:
154
- fig = get_heatmaps(0, image_size // 2, image_size // 2, input_image)
155
- return fig
156
-
157
- print("setup done.")
158
-
159
  with gr.Blocks() as demo:
160
  # title
161
  gr.Markdown("# TripletGeoEncoder Feature Map Visualization")
@@ -168,24 +157,17 @@ with gr.Blocks() as demo:
168
  "For further information, please contact me on X (formerly Twitter): @Yeq6X.")
169
 
170
  gr.Markdown("## Heatmap Visualization")
171
-
 
 
 
 
172
  input_image = gr.ImageEditor(label="Cropped Image", elem_id="input_image", crop_size=(112, 112), show_fullscreen_button=True)
173
- output_plot = gr.Plot(value=None, elem_id="output_plot", show_label=False)
174
- with gr.Row():
175
- with gr.Column():
176
- with gr.Row():
177
- model_name = gr.Dropdown(
178
- choices=[str(model_info) for model_info in models_info],
179
- container=False
180
- )
181
- load_button = gr.Button("Load Model")
182
- load_button.click(setup, inputs=[model_name, input_image], outputs=[output_plot])
183
- with gr.Row():
184
- pass
185
-
186
  inference = gr.Interface(
187
  get_heatmaps,
188
  inputs=[
 
189
  gr.Slider(0, batch_size - 1, step=1, label="Source Image Index"),
190
  gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="X Coordinate"),
191
  gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="Y Coordinate"),
@@ -205,8 +187,5 @@ with gr.Blocks() as demo:
205
  inputs=[input_image],
206
  )
207
 
208
- setup(models_info[0])
209
- print(mean_vector_list)
210
-
211
  demo.launch()
212
 
 
8
  from PIL import Image
9
  import base64
10
  from io import BytesIO
11
+ import os
12
 
13
  import dataset
14
  from dataset import MyDataset, ImageKeypointDataset, load_filenames, load_keypoints
 
93
 
94
  # ヒートマップの生成関数
95
  @spaces.GPU
96
+ def get_heatmaps(model_info, source_num, x_coords, y_coords, uploaded_image):
 
97
  if type(uploaded_image) == str:
98
  uploaded_image = Image.open(uploaded_image)
99
  if type(source_num) == str:
 
102
  x_coords = int(x_coords)
103
  if type(y_coords) == str:
104
  y_coords = int(y_coords)
105
+
106
+ if type(model_info) == str:
107
+ model_info = eval(model_info)
108
+ model_index = models_info.index(model_info)
109
+
110
+ mean_vector_list = np.load(f"resources/mean_vector_list_{model_info['name']}.npy", allow_pickle=True)
111
+ mean_vector_list = torch.tensor(mean_vector_list).to(device)
112
 
113
  dec5, _ = models[model_index](x)
114
  feature_map = dec5
 
145
  plt.close(fig)
146
  return fig
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  with gr.Blocks() as demo:
149
  # title
150
  gr.Markdown("# TripletGeoEncoder Feature Map Visualization")
 
157
  "For further information, please contact me on X (formerly Twitter): @Yeq6X.")
158
 
159
  gr.Markdown("## Heatmap Visualization")
160
+
161
+ model_info = gr.Dropdown(
162
+ choices=[str(model_info) for model_info in models_info],
163
+ container=False
164
+ )
165
  input_image = gr.ImageEditor(label="Cropped Image", elem_id="input_image", crop_size=(112, 112), show_fullscreen_button=True)
166
+ output_plot = gr.Plot(value=None, elem_id="output_plot", show_label=False)
 
 
 
 
 
 
 
 
 
 
 
 
167
  inference = gr.Interface(
168
  get_heatmaps,
169
  inputs=[
170
+ model_info,
171
  gr.Slider(0, batch_size - 1, step=1, label="Source Image Index"),
172
  gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="X Coordinate"),
173
  gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="Y Coordinate"),
 
187
  inputs=[input_image],
188
  )
189
 
 
 
 
190
  demo.launch()
191
 
resources/mean_vector_list_ae_model_tf_2024-03-05_00-35-21.pth.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a8f81b924edda413139a5743408c6f38ebd7930b5d39cc98a6b4dd49bd42dae
3
+ size 3328
resources/mean_vector_list_autoencoder-epoch=09-train_loss=1.00.ckpt.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21c0d7cd81ee6e7fac9f0333209daec9e74a7c1e72e358b7732e8ecb3efea5f2
3
+ size 6528
resources/mean_vector_list_autoencoder-epoch=29-train_loss=1.01.ckpt.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f9edb6ad2b6a0b9121905ea04ac1a39618d709329045c9cf00673f6281fc412c
3
+ size 6528
resources/mean_vector_list_autoencoder-epoch=49-train_loss=1.01.ckpt.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4141deceac789a26c3d6cf02f20068e8caf37fb14b8af3f63ae365b32462d76f
3
+ size 6528
utils.py CHANGED
@@ -132,7 +132,7 @@ def get_mean_vector(feature_map, points):
132
  x_coords, y_coords = torch.round(points[:,i].t()).to(torch.long)
133
  vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords] # 1次元ベクトルに合わせてサイズを調整
134
  # mean_vector = vectors[0:10].mean(0) # 10個の特徴マップの平均ベクトルを取得
135
- mean_vector = vectors.mean(0)
136
  mean_vector_list.append(mean_vector)
137
  return mean_vector_list
138
 
 
132
  x_coords, y_coords = torch.round(points[:,i].t()).to(torch.long)
133
  vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords] # 1次元ベクトルに合わせてサイズを調整
134
  # mean_vector = vectors[0:10].mean(0) # 10個の特徴マップの平均ベクトルを取得
135
+ mean_vector = vectors.mean(0).detach().cpu().numpy()
136
  mean_vector_list.append(mean_vector)
137
  return mean_vector_list
138