turhancan97 committed
Commit 98a1feb · 1 Parent(s): ae13cd7

chore: Add top 75 PCA mode to image reconstruction interface

Files changed (2)
  1. app.py +74 -8
  2. model/top_75/vit-t-mae-pretrain.pt +3 -0
app.py CHANGED
@@ -12,6 +12,7 @@ from model import MAE_ViT, MAE_Encoder, MAE_Decoder, MAE_Encoder_FeatureExtracto
 
 path_1 = [['images/cat.jpg'], ['images/dog.jpg'], ['images/horse.jpg'], ['images/airplane.jpg'], ['images/truck.jpg']]
 path_2 = [['images/cat.jpg'], ['images/dog.jpg'], ['images/horse.jpg'], ['images/airplane.jpg'], ['images/truck.jpg']]
+path_3 = [['images/cat.jpg'], ['images/dog.jpg'], ['images/horse.jpg'], ['images/airplane.jpg'], ['images/truck.jpg']]
 device = torch.device("cpu")
 
 model_name = "model/no_mode/vit-t-mae-pretrain.pt"
@@ -20,9 +21,14 @@ model_no_mode.eval()
 model_no_mode.to(device)
 
 model_name = "model/bottom_25/vit-t-mae-pretrain.pt"
-model_pca_mode = torch.load(model_name, map_location='cpu')
-model_pca_mode.eval()
-model_pca_mode.to(device)
+model_pca_mode_bottom = torch.load(model_name, map_location='cpu')
+model_pca_mode_bottom.eval()
+model_pca_mode_bottom.to(device)
+
+model_name = "model/top_75/vit-t-mae-pretrain.pt"
+model_pca_mode_top = torch.load(model_name, map_location='cpu')
+model_pca_mode_top.eval()
+model_pca_mode_top.to(device)
 
 transform = v2.Compose([
     v2.Resize((96, 96)),
@@ -87,12 +93,54 @@ def visualize_single_image_no_mode(image_path):
 
     return np.array(plt.imread("output.png"))
 
-def visualize_single_image_pca_mode(image_path):
+def visualize_single_image_pca_mode_bottom(image_path):
+    img = load_image(image_path, transform).to(device)
+
+    # Run inference
+    with torch.no_grad():
+        predicted_img, mask = model_pca_mode_bottom(img)
+
+    # Convert the tensor back to a displayable image
+    # masked image
+    im_masked = img * (1 - mask)
+
+    # MAE reconstruction pasted with visible patches
+    im_paste = img * (1 - mask) + predicted_img * mask
+
+    # remove the batch dimension
+    img = img[0]
+    im_masked = im_masked[0]
+    predicted_img = predicted_img[0]
+    im_paste = im_paste[0]
+
+    # make the plt figure larger
+    plt.figure(figsize=(18, 8))
+
+    plt.subplot(1, 4, 1)
+    show_image(img, "original")
+
+    plt.subplot(1, 4, 2)
+    show_image(im_masked, "masked")
+
+    plt.subplot(1, 4, 3)
+    show_image(predicted_img, "reconstruction")
+
+    plt.subplot(1, 4, 4)
+    show_image(im_paste, "reconstruction + visible")
+
+    plt.tight_layout()
+
+    # convert the plt figure to a numpy array
+    plt.savefig("output.png")
+
+    return np.array(plt.imread("output.png"))
+
+def visualize_single_image_pca_mode_top(image_path):
     img = load_image(image_path, transform).to(device)
 
     # Run inference
     with torch.no_grad():
-        predicted_img, mask = model_pca_mode(img)
+        predicted_img, mask = model_pca_mode_top(img)
 
     # Convert the tensor back to a displayable image
     # masked image
@@ -145,6 +193,15 @@ outputs_image_2 = [
     gr.components.Image(type="numpy", label="Output Image"),
 ]
 
+inputs_image_3 = [
+    gr.components.Image(type="filepath", label="Input Image"),
+]
+
+outputs_image_3 = [
+    gr.components.Image(type="numpy", label="Output Image"),
+]
+
+
 inference_no_mode = gr.Interface(
     fn=visualize_single_image_no_mode,
     inputs=inputs_image_1,
@@ -155,8 +212,8 @@ inference_no_mode = gr.Interface(
     description="This is a demo of the MAE-ViT model for image reconstruction.",
 )
 
-inference_pca_mode = gr.Interface(
-    fn=visualize_single_image_pca_mode,
+inference_pca_mode_bottom = gr.Interface(
+    fn=visualize_single_image_pca_mode_bottom,
     inputs=inputs_image_2,
     outputs=outputs_image_2,
     examples=path_2,
@@ -164,7 +221,16 @@ inference_pca_mode = gr.Interface(
     description="This is a demo of the MAE-ViT model for image reconstruction.",
 )
 
+inference_pca_mode_top = gr.Interface(
+    fn=visualize_single_image_pca_mode_top,
+    inputs=inputs_image_3,
+    outputs=outputs_image_3,
+    examples=path_3,
+    title="MAE-ViT Image Reconstruction",
+    description="This is a demo of the MAE-ViT model for image reconstruction.",
+)
+
 gr.TabbedInterface(
-    [inference_no_mode, inference_pca_mode],
+    [inference_no_mode, inference_pca_mode_bottom, inference_pca_mode_top],
     tab_names=['Normal Mode', 'PCA Mode']
 ).queue().launch()
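
Two follow-up notes on the app.py diff above. First, the new visualize_single_image_pca_mode_bottom and visualize_single_image_pca_mode_top bodies are verbatim copies of visualize_single_image_no_mode apart from the model they call. A minimal consolidation sketch, reusing the load_image, show_image, transform, and device objects already defined in app.py; the make_visualizer name is an illustrative assumption, not part of this commit:

    # Sketch only, not part of this commit: one factory instead of three
    # near-identical functions. Assumes app.py's existing load_image,
    # show_image, transform, device, plt, np, and torch are in scope.
    def make_visualizer(model):
        def visualize(image_path):
            img = load_image(image_path, transform).to(device)
            with torch.no_grad():
                predicted_img, mask = model(img)
            im_masked = img * (1 - mask)                        # masked input
            im_paste = img * (1 - mask) + predicted_img * mask  # reconstruction + visible patches
            panels = [(img[0], "original"), (im_masked[0], "masked"),
                      (predicted_img[0], "reconstruction"), (im_paste[0], "reconstruction + visible")]
            plt.figure(figsize=(18, 8))
            for i, (im, title) in enumerate(panels, start=1):
                plt.subplot(1, 4, i)
                show_image(im, title)
            plt.tight_layout()
            plt.savefig("output.png")
            return np.array(plt.imread("output.png"))
        return visualize

    visualize_single_image_pca_mode_bottom = make_visualizer(model_pca_mode_bottom)
    visualize_single_image_pca_mode_top = make_visualizer(model_pca_mode_top)

Second, gr.TabbedInterface now receives three interfaces while tab_names still lists only two, so the third tab is left without an explicit label. A matching call would look like the sketch below; the two PCA tab labels are assumed names, not from this commit:

    gr.TabbedInterface(
        [inference_no_mode, inference_pca_mode_bottom, inference_pca_mode_top],
        tab_names=['Normal Mode', 'PCA Bottom 25 Mode', 'PCA Top 75 Mode']
    ).queue().launch()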
model/top_75/vit-t-mae-pretrain.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1ecb391af126dc7ba24e85043e4a383782ed3e642977dcf8ad68c835891752ae
+size 29121704
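
The checkpoint itself is tracked with Git LFS, so this commit records only the three-line pointer above; the 29121704-byte weight file lives in LFS storage and must be pulled (e.g. with `git lfs pull`) before app.py's torch.load call can succeed. A small sanity check mirroring the load code in app.py, illustrative and not part of this commit:

    # Illustrative check that the LFS object was materialized, not just the pointer.
    # Run from the repo root so the pickled model classes (model.MAE_ViT etc.) resolve.
    import os
    import torch

    path = "model/top_75/vit-t-mae-pretrain.pt"
    # A bare LFS pointer is ~130 bytes; the real checkpoint is ~29 MB.
    assert os.path.getsize(path) > 1_000_000, "pointer only -- run `git lfs pull` first"
    model = torch.load(path, map_location='cpu')  # same call app.py uses
    model.eval()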