yeq6x committed
Commit 01a01d7
Parent: 02ba63a

add description and examples

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .gitignore +2 -0
  2. app.py +50 -17
  3. resources/examples/2488.jpg +0 -0
  4. resources/examples/2899.jpg +0 -0
  5. resources/trainB/0000.jpg +0 -0
  6. resources/trainB/0001.jpg +0 -0
  7. resources/trainB/0002.jpg +0 -0
  8. resources/trainB/0003.jpg +0 -0
  9. resources/trainB/0004.jpg +0 -0
  10. resources/trainB/0005.jpg +0 -0
  11. resources/trainB/0006.jpg +0 -0
  12. resources/trainB/0007.jpg +0 -0
  13. resources/trainB/0008.jpg +0 -0
  14. resources/trainB/0009.jpg +0 -0
  15. resources/trainB/0010.jpg +0 -0
  16. resources/trainB/0011.jpg +0 -0
  17. resources/trainB/0012.jpg +0 -0
  18. resources/trainB/0013.jpg +0 -0
  19. resources/trainB/0014.jpg +0 -0
  20. resources/trainB/0015.jpg +0 -0
  21. resources/trainB/0016.jpg +0 -0
  22. resources/trainB/0017.jpg +0 -0
  23. resources/trainB/0018.jpg +0 -0
  24. resources/trainB/0019.jpg +0 -0
  25. resources/trainB/0020.jpg +0 -0
  26. resources/trainB/0021.jpg +0 -0
  27. resources/trainB/0022.jpg +0 -0
  28. resources/trainB/0023.jpg +0 -0
  29. resources/trainB/0024.jpg +0 -0
  30. resources/trainB/0025.jpg +0 -0
  31. resources/trainB/0026.jpg +0 -0
  32. resources/trainB/0027.jpg +0 -0
  33. resources/trainB/0028.jpg +0 -0
  34. resources/trainB/0029.jpg +0 -0
  35. resources/trainB/0030.jpg +0 -0
  36. resources/trainB/0031.jpg +0 -0
  37. resources/trainB/0032.jpg +0 -0
  38. resources/trainB/0033.jpg +0 -0
  39. resources/trainB/0034.jpg +0 -0
  40. resources/trainB/0035.jpg +0 -0
  41. resources/trainB/0036.jpg +0 -0
  42. resources/trainB/0037.jpg +0 -0
  43. resources/trainB/0038.jpg +0 -0
  44. resources/trainB/0039.jpg +0 -0
  45. resources/trainB/0040.jpg +0 -0
  46. resources/trainB/0041.jpg +0 -0
  47. resources/trainB/0042.jpg +0 -0
  48. resources/trainB/0043.jpg +0 -0
  49. resources/trainB/0044.jpg +0 -0
  50. resources/trainB/0045.jpg +0 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+venv
+__pycache__/
app.py CHANGED
@@ -1,5 +1,5 @@
 import gradio as gr
-import spaces
+# import spaces
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
@@ -14,17 +14,17 @@ from io import BytesIO
 
 # Load the model and data
 def load_model():
-    model_path = "checkpoints/ae_model_tf_2024-03-05_00-35-21.pth"
-    feature_dim = 32
+    model_path = "checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt"
+    feature_dim = 64
     model = AutoencoderModule(feature_dim=feature_dim)
     state_dict = torch.load(model_path)
 
-    # Fix the state_dict keys
-    new_state_dict = {}
-    for key in state_dict:
-        new_key = "model." + key
-        new_state_dict[new_key] = state_dict[key]
-    model.load_state_dict(new_state_dict)
+    # # Fix the state_dict keys
+    # new_state_dict = {}
+    # for key in state_dict:
+    #     new_key = "model." + key
+    #     new_state_dict[new_key] = state_dict[key]
+    model.load_state_dict(state_dict['state_dict'])
     model.eval()
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -58,7 +58,9 @@ x = load_data(device)
 
 # Preprocess the uploaded image
 def preprocess_uploaded_image(uploaded_image, image_size):
-    uploaded_image = Image.fromarray(uploaded_image)
+    # Convert an ndarray to a PIL image if necessary
+    if type(uploaded_image) == np.ndarray:
+        uploaded_image = Image.fromarray(uploaded_image)
     uploaded_image = uploaded_image.convert("RGB")
     uploaded_image = uploaded_image.resize((image_size, image_size))
     uploaded_image = np.array(uploaded_image).transpose(2, 0, 1) / 255.0
@@ -66,8 +68,17 @@ def preprocess_uploaded_image(uploaded_image, image_size):
     return uploaded_image
 
 # Heatmap generation function
-@spaces.GPU
+# @spaces.GPU
 def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
+    if type(uploaded_image) == str:
+        uploaded_image = Image.open(uploaded_image)
+    if type(source_num) == str:
+        source_num = int(source_num)
+    if type(x_coords) == str:
+        x_coords = int(x_coords)
+    if type(y_coords) == str:
+        y_coords = int(y_coords)
+
     with torch.no_grad():
         dec5, _ = model(x)
         img = x
@@ -101,7 +112,7 @@ def get_heatmaps(source_num, x_coords, y_coords, uploaded_image):
     source_map = norm_batch_distance_map[source_num]
     target_map = norm_batch_distance_map[target_num]
 
-    alpha = 0.8
+    alpha = 0.7
     blended_source = (1 - alpha) * img[source_num] + alpha * torch.cat(((norm_batch_distance_map[source_num] / norm_batch_distance_map[source_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
     blended_target = (1 - alpha) * img[target_num] + alpha * torch.cat(((norm_batch_distance_map[target_num] / norm_batch_distance_map[target_num].max()).unsqueeze(0), torch.zeros(2, image_size, image_size, device=device)))
 
@@ -154,6 +165,7 @@ async () => {
     console.log(files);
     if (files && files.length > 0) {
         console.log("File selected");
+        document.querySelector("#input_file_button").style.display = "none";
        document.querySelector("#crop_view").style.display = "block";
        document.querySelector("#crop_button").style.display = "block";
        const url = URL.createObjectURL(files[0]);
@@ -183,6 +195,7 @@ async () => {
 
        document.getElementById("crop_view").style.display = "none";
        document.getElementById("crop_button").style.display = "none";
+       document.querySelector("#input_file_button").style.display = "block";
 
        cropper.destroy();
    }
@@ -194,6 +207,16 @@ async () => {
 """
 
 with gr.Blocks() as demo:
+    # title
+    gr.Markdown("# TripletGeoEncoder Feature Map Visualization")
+    # description
+    gr.Markdown("This demo visualizes the feature maps of a TripletGeoEncoder trained with self-supervised learning, without annotations, on only 1000 images from the CelebA dataset. "
+                "The feature maps are visualized as heatmaps: the source map shows the distance of each pixel in the source image to the selected pixel, and the target map shows the distance of each pixel in the target image to the selected pixel. "
+                "The blended source and target images show the source and target images with the source and target maps overlaid, respectively. "
+                "For further information, please contact me on X (formerly Twitter): @Yeq6X.")
+
     with gr.Row():
         with gr.Column():
             source_num = gr.Slider(0, batch_size - 1, step=1, label="Source Image Index")
@@ -202,20 +225,30 @@ with gr.Blocks() as demo:
 
             # Add a file-selection button via a hidden HTML file input
             gr.HTML('<input type="file" id="input_file" style="display:none;">')
-            input_file_button = gr.Button("画像を選択", elem_id="input_file_button")
+            input_file_button = gr.Button("Upload Image and Crop", elem_id="input_file_button", variant="primary")
+            crop_button = gr.Button("Crop", elem_id="crop_button", variant="primary")
             # Show the image with an HTML img tag rendered through Gradio
             gr.HTML('<img id="crop_view" style="max-width:100%;">')
             # Add a Gradio button component and give it an ID
-            crop_button = gr.Button("クロップ", elem_id="crop_button", variant="primary")
             # Textbox for the cropped image data (Base64)
             cropped_image_data = gr.Textbox(visible=False, elem_id="cropped_image_data")
-            input_image = gr.Image(label="Cropped Image", interactive=False)
+            input_image = gr.Image(label="Cropped Image", elem_id="input_image")
             # Call process_image when cropped_image_data is updated
             cropped_image_data.change(process_image, inputs=cropped_image_data, outputs=input_image)
 
+            # examples
+            gr.Markdown("# Examples")
+            gr.Examples(
+                examples=[
+                    ["0", "50", "50", "resources/examples/2488.jpg"],
+                    ["0", "50", "50", "resources/examples/2899.jpg"]
+                ],
+                inputs=[source_num, x_coords, y_coords, input_image],
+            )
         with gr.Column():
             output_plot = gr.Plot()
 
+
     # Event wiring in place of a Gradio Interface
     source_num.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
     x_coords.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
@@ -223,7 +256,7 @@ with gr.Blocks() as demo:
     input_image.change(get_heatmaps, inputs=[source_num, x_coords, y_coords, input_image], outputs=output_plot)
 
     # Load the JavaScript code
-    demo.load(None, None, None, js=scripts)
-
+    demo.load(None, None, None, js=scripts)
+
 demo.launch()
 
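Note on the checkpoint change: the old code loaded a raw `.pth` state dict and had to prepend `model.` to every key; the new `.ckpt` file is a PyTorch Lightning-style checkpoint, where `torch.load` returns a dictionary whose `"state_dict"` entry already carries module-prefixed keys. A minimal sketch of the new loading path, assuming `AutoencoderModule` is importable (the diff does not show its import location):

```python
import torch

# Hypothetical import path; the diff does not show where AutoencoderModule lives.
from model import AutoencoderModule

model = AutoencoderModule(feature_dim=64)

# A Lightning-style .ckpt is a dict with training metadata (epoch, optimizer
# states, ...); the weights sit under the "state_dict" key, already prefixed
# with the submodule names, so no key rewriting is needed.
ckpt = torch.load("checkpoints/autoencoder-epoch=49-train_loss=1.01.ckpt",
                  map_location="cpu")
model.load_state_dict(ckpt["state_dict"])
model.eval()
```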
 
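Note on the new type checks in `get_heatmaps` and `preprocess_uploaded_image`: the example rows added to the UI pass the slider values as strings and the image as a file path, so the handler normalizes its inputs before use. A self-contained sketch of the same defensive pattern (the function name is illustrative, not from the app):

```python
import numpy as np
from PIL import Image

def coerce_inputs(source_num, x_coords, y_coords, uploaded_image):
    """Normalize values that gr.Examples may deliver as strings."""
    # An example row supplies the image as a file path; the upload flow
    # supplies an ndarray. Either way, end up with a PIL image.
    if isinstance(uploaded_image, str):
        uploaded_image = Image.open(uploaded_image)
    elif isinstance(uploaded_image, np.ndarray):
        uploaded_image = Image.fromarray(uploaded_image)
    # Slider values arrive as strings from the example rows.
    return int(source_num), int(x_coords), int(y_coords), uploaded_image
```

`isinstance` is the idiomatic spelling of the diff's `type(x) == str` comparisons and also accepts subclasses.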
 
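Note on the cropper wiring: `demo.load(None, None, None, js=scripts)` runs the `scripts` JavaScript once the Blocks page loads, and the `elem_id` values (`input_file_button`, `crop_button`, `input_image`) give that script stable DOM hooks; this is how the new show/hide toggling of the upload and crop buttons works. A stripped-down sketch of the pattern, with a placeholder ID and behavior rather than the app's actual Cropper.js script:

```python
import gradio as gr

# Placeholder script: app.py instead attaches Cropper.js handlers and
# toggles #input_file_button / #crop_button visibility.
scripts = """
async () => {
    const btn = document.querySelector("#demo_button");
    btn.addEventListener("click", () => {
        btn.style.display = "none";  // hide the element once clicked
    });
}
"""

with gr.Blocks() as demo:
    gr.Button("Click me", elem_id="demo_button")
    # Run the script on page load (fn, inputs, and outputs are unused here).
    demo.load(None, None, None, js=scripts)

demo.launch()
```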
resources/examples/2488.jpg ADDED
resources/examples/2899.jpg ADDED
resources/trainB/0000.jpg ADDED
resources/trainB/0001.jpg ADDED
resources/trainB/0002.jpg ADDED
resources/trainB/0003.jpg ADDED
resources/trainB/0004.jpg ADDED
resources/trainB/0005.jpg ADDED
resources/trainB/0006.jpg ADDED
resources/trainB/0007.jpg ADDED
resources/trainB/0008.jpg ADDED
resources/trainB/0009.jpg ADDED
resources/trainB/0010.jpg ADDED
resources/trainB/0011.jpg ADDED
resources/trainB/0012.jpg ADDED
resources/trainB/0013.jpg ADDED
resources/trainB/0014.jpg ADDED
resources/trainB/0015.jpg ADDED
resources/trainB/0016.jpg ADDED
resources/trainB/0017.jpg ADDED
resources/trainB/0018.jpg ADDED
resources/trainB/0019.jpg ADDED
resources/trainB/0020.jpg ADDED
resources/trainB/0021.jpg ADDED
resources/trainB/0022.jpg ADDED
resources/trainB/0023.jpg ADDED
resources/trainB/0024.jpg ADDED
resources/trainB/0025.jpg ADDED
resources/trainB/0026.jpg ADDED
resources/trainB/0027.jpg ADDED
resources/trainB/0028.jpg ADDED
resources/trainB/0029.jpg ADDED
resources/trainB/0030.jpg ADDED
resources/trainB/0031.jpg ADDED
resources/trainB/0032.jpg ADDED
resources/trainB/0033.jpg ADDED
resources/trainB/0034.jpg ADDED
resources/trainB/0035.jpg ADDED
resources/trainB/0036.jpg ADDED
resources/trainB/0037.jpg ADDED
resources/trainB/0038.jpg ADDED
resources/trainB/0039.jpg ADDED
resources/trainB/0040.jpg ADDED
resources/trainB/0041.jpg ADDED
resources/trainB/0042.jpg ADDED
resources/trainB/0043.jpg ADDED
resources/trainB/0044.jpg ADDED
resources/trainB/0045.jpg ADDED