Spaces:
Sleeping
Sleeping
gpu
Browse files- app.py +16 -37
- resources/mean_vector_list_ae_model_tf_2024-03-05_00-35-21.pth.npy +3 -0
- resources/mean_vector_list_autoencoder-epoch=09-train_loss=1.00.ckpt.npy +3 -0
- resources/mean_vector_list_autoencoder-epoch=29-train_loss=1.01.ckpt.npy +3 -0
- resources/mean_vector_list_autoencoder-epoch=49-train_loss=1.01.ckpt.npy +3 -0
- utils.py +1 -1
app.py
CHANGED
@@ -8,6 +8,7 @@ import numpy as np
|
|
8 |
from PIL import Image
|
9 |
import base64
|
10 |
from io import BytesIO
|
|
|
11 |
|
12 |
import dataset
|
13 |
from dataset import MyDataset, ImageKeypointDataset, load_filenames, load_keypoints
|
@@ -92,8 +93,7 @@ model_index = 0
|
|
92 |
|
93 |
# ヒートマップの生成関数
|
94 |
@spaces.GPU
|
95 |
-
def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
|
96 |
-
global model_index, mean_vector_list
|
97 |
if type(uploaded_image) == str:
|
98 |
uploaded_image = Image.open(uploaded_image)
|
99 |
if type(source_num) == str:
|
@@ -102,6 +102,13 @@ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
|
|
102 |
x_coords = int(x_coords)
|
103 |
if type(y_coords) == str:
|
104 |
y_coords = int(y_coords)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
dec5, _ = models[model_index](x)
|
107 |
feature_map = dec5
|
@@ -138,24 +145,6 @@ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
|
|
138 |
plt.close(fig)
|
139 |
return fig
|
140 |
|
141 |
-
@spaces.GPU
|
142 |
-
def setup(model_info, input_image=None):
|
143 |
-
global model_index, mean_vector_list
|
144 |
-
# str -> dictに変換
|
145 |
-
if type(model_info) == str:
|
146 |
-
model_info = eval(model_info)
|
147 |
-
|
148 |
-
model_index = models_info.index(model_info)
|
149 |
-
|
150 |
-
feature_map, _ = models[model_index](test_imgs)
|
151 |
-
mean_vector_list = utils.get_mean_vector(feature_map, points)
|
152 |
-
|
153 |
-
if input_image is not None:
|
154 |
-
fig = get_heatmaps(0, image_size // 2, image_size // 2, input_image)
|
155 |
-
return fig
|
156 |
-
|
157 |
-
print("setup done.")
|
158 |
-
|
159 |
with gr.Blocks() as demo:
|
160 |
# title
|
161 |
gr.Markdown("# TripletGeoEncoder Feature Map Visualization")
|
@@ -168,24 +157,17 @@ with gr.Blocks() as demo:
|
|
168 |
"For further information, please contact me on X (formerly Twitter): @Yeq6X.")
|
169 |
|
170 |
gr.Markdown("## Heatmap Visualization")
|
171 |
-
|
|
|
|
|
|
|
|
|
172 |
input_image = gr.ImageEditor(label="Cropped Image", elem_id="input_image", crop_size=(112, 112), show_fullscreen_button=True)
|
173 |
-
output_plot = gr.Plot(value=None, elem_id="output_plot", show_label=False)
|
174 |
-
with gr.Row():
|
175 |
-
with gr.Column():
|
176 |
-
with gr.Row():
|
177 |
-
model_name = gr.Dropdown(
|
178 |
-
choices=[str(model_info) for model_info in models_info],
|
179 |
-
container=False
|
180 |
-
)
|
181 |
-
load_button = gr.Button("Load Model")
|
182 |
-
load_button.click(setup, inputs=[model_name, input_image], outputs=[output_plot])
|
183 |
-
with gr.Row():
|
184 |
-
pass
|
185 |
-
|
186 |
inference = gr.Interface(
|
187 |
get_heatmaps,
|
188 |
inputs=[
|
|
|
189 |
gr.Slider(0, batch_size - 1, step=1, label="Source Image Index"),
|
190 |
gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="X Coordinate"),
|
191 |
gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="Y Coordinate"),
|
@@ -205,8 +187,5 @@ with gr.Blocks() as demo:
|
|
205 |
inputs=[input_image],
|
206 |
)
|
207 |
|
208 |
-
setup(models_info[0])
|
209 |
-
print(mean_vector_list)
|
210 |
-
|
211 |
demo.launch()
|
212 |
|
|
|
8 |
from PIL import Image
|
9 |
import base64
|
10 |
from io import BytesIO
|
11 |
+
import os
|
12 |
|
13 |
import dataset
|
14 |
from dataset import MyDataset, ImageKeypointDataset, load_filenames, load_keypoints
|
|
|
93 |
|
94 |
# ヒートマップの生成関数
|
95 |
@spaces.GPU
|
96 |
+
def get_heatmaps(model_info, source_num, x_coords, y_coords, uploaded_image):
|
|
|
97 |
if type(uploaded_image) == str:
|
98 |
uploaded_image = Image.open(uploaded_image)
|
99 |
if type(source_num) == str:
|
|
|
102 |
x_coords = int(x_coords)
|
103 |
if type(y_coords) == str:
|
104 |
y_coords = int(y_coords)
|
105 |
+
|
106 |
+
if type(model_info) == str:
|
107 |
+
model_info = eval(model_info)
|
108 |
+
model_index = models_info.index(model_info)
|
109 |
+
|
110 |
+
mean_vector_list = np.load(f"resources/mean_vector_list_{model_info['name']}.npy", allow_pickle=True)
|
111 |
+
mean_vector_list = torch.tensor(mean_vector_list).to(device)
|
112 |
|
113 |
dec5, _ = models[model_index](x)
|
114 |
feature_map = dec5
|
|
|
145 |
plt.close(fig)
|
146 |
return fig
|
147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
with gr.Blocks() as demo:
|
149 |
# title
|
150 |
gr.Markdown("# TripletGeoEncoder Feature Map Visualization")
|
|
|
157 |
"For further information, please contact me on X (formerly Twitter): @Yeq6X.")
|
158 |
|
159 |
gr.Markdown("## Heatmap Visualization")
|
160 |
+
|
161 |
+
model_info = gr.Dropdown(
|
162 |
+
choices=[str(model_info) for model_info in models_info],
|
163 |
+
container=False
|
164 |
+
)
|
165 |
input_image = gr.ImageEditor(label="Cropped Image", elem_id="input_image", crop_size=(112, 112), show_fullscreen_button=True)
|
166 |
+
output_plot = gr.Plot(value=None, elem_id="output_plot", show_label=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
inference = gr.Interface(
|
168 |
get_heatmaps,
|
169 |
inputs=[
|
170 |
+
model_info,
|
171 |
gr.Slider(0, batch_size - 1, step=1, label="Source Image Index"),
|
172 |
gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="X Coordinate"),
|
173 |
gr.Slider(0, image_size - 1, step=1, value=image_size // 2, label="Y Coordinate"),
|
|
|
187 |
inputs=[input_image],
|
188 |
)
|
189 |
|
|
|
|
|
|
|
190 |
demo.launch()
|
191 |
|
resources/mean_vector_list_ae_model_tf_2024-03-05_00-35-21.pth.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1a8f81b924edda413139a5743408c6f38ebd7930b5d39cc98a6b4dd49bd42dae
|
3 |
+
size 3328
|
resources/mean_vector_list_autoencoder-epoch=09-train_loss=1.00.ckpt.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:21c0d7cd81ee6e7fac9f0333209daec9e74a7c1e72e358b7732e8ecb3efea5f2
|
3 |
+
size 6528
|
resources/mean_vector_list_autoencoder-epoch=29-train_loss=1.01.ckpt.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f9edb6ad2b6a0b9121905ea04ac1a39618d709329045c9cf00673f6281fc412c
|
3 |
+
size 6528
|
resources/mean_vector_list_autoencoder-epoch=49-train_loss=1.01.ckpt.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4141deceac789a26c3d6cf02f20068e8caf37fb14b8af3f63ae365b32462d76f
|
3 |
+
size 6528
|
utils.py
CHANGED
@@ -132,7 +132,7 @@ def get_mean_vector(feature_map, points):
|
|
132 |
x_coords, y_coords = torch.round(points[:,i].t()).to(torch.long)
|
133 |
vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords] # 1次元ベクトルに合わせてサイズを調整
|
134 |
# mean_vector = vectors[0:10].mean(0) # 10個の特徴マップの平均ベクトルを取得
|
135 |
-
mean_vector = vectors.mean(0)
|
136 |
mean_vector_list.append(mean_vector)
|
137 |
return mean_vector_list
|
138 |
|
|
|
132 |
x_coords, y_coords = torch.round(points[:,i].t()).to(torch.long)
|
133 |
vectors = feature_map[torch.arange(feature_map.size(0)), :, y_coords, x_coords] # 1次元ベクトルに合わせてサイズを調整
|
134 |
# mean_vector = vectors[0:10].mean(0) # 10個の特徴マップの平均ベクトルを取得
|
135 |
+
mean_vector = vectors.mean(0).detach().cpu().numpy()
|
136 |
mean_vector_list.append(mean_vector)
|
137 |
return mean_vector_list
|
138 |
|