File size: 2,493 Bytes
b479d0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320097c
 
 
b479d0e
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import gradio as gr 
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import ViTMAEForPreTraining, ViTFeatureExtractor
from PIL import Image
import uuid

# Single checkpoint id so the preprocessor and the model are always loaded
# from the same pretrained weights.
_CHECKPOINT = "facebook/vit-mae-base"

# Loaded once at import time and shared by every request handler below.
feature_extractor = ViTFeatureExtractor.from_pretrained(_CHECKPOINT)
model = ViTMAEForPreTraining.from_pretrained(_CHECKPOINT)

# Normalization statistics reported by the feature extractor, kept as NumPy
# arrays so they can be used to undo the normalization before plotting.
imagenet_mean = np.asarray(feature_extractor.image_mean)
imagenet_std = np.asarray(feature_extractor.image_std)

def show_image(image, title=''):
    """Render a normalized image to a PNG file and return the file name.

    The normalization is undone with the module-level ``imagenet_std`` and
    ``imagenet_mean`` and the result is scaled to the 0-255 range before
    plotting.

    Args:
        image: Tensor laid out as [H, W, 3] holding normalized pixel values.
        title: Optional title drawn above the image. The default empty
            string draws no title.

    Returns:
        The file name ``"<uuid4>.png"`` of the saved figure, relative to the
        current working directory.
    """
    # image is [H, W, 3]
    assert image.shape[2] == 3
    out_path = f"{uuid.uuid4()}.png"

    # Draw on a dedicated figure instead of pyplot's implicit global one.
    # Reusing the global figure stacks every new image on top of the previous
    # ones and keeps all of them alive for the lifetime of the process.
    fig, ax = plt.subplots()
    try:
        ax.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int())
        if title:
            ax.set_title(title)
        ax.axis('off')
        fig.savefig(out_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    # NOTE: the PNG is left on disk; nothing in this function removes it.
    return out_path

def visualize(image):
    """Run the masked autoencoder on an image and build the gallery output.

    Args:
        image: Input image as delivered by the Gradio image component, or
            ``None`` when the component has been cleared.

    Returns:
        A list of ``(png_path, label)`` tuples for the original image, the
        masked image, the raw reconstruction, and the reconstruction pasted
        together with the visible patches. An empty list is returned when
        ``image`` is ``None``.
    """
    # The change event also fires when the user clears the image; there is
    # nothing to reconstruct in that case, so show an empty gallery.
    if image is None:
        return []

    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values

    # Pure inference: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = model(pixel_values)
        y = model.unpatchify(outputs.logits)

        # Expand the per-patch mask to per-pixel values so it can be
        # unpatchified like an image.
        mask = outputs.mask
        mask = mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size**2 * 3)  # (N, H*W, p*p*3)
        mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping

    # Move channels last so the tensors match the [H, W, 3] layout that
    # show_image expects.
    y = torch.einsum('nchw->nhwc', y).cpu()
    mask = torch.einsum('nchw->nhwc', mask).cpu()
    x = torch.einsum('nchw->nhwc', pixel_values)

    # masked image
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    gallery_labels = ["Original Image", "Masked Image", "Reconstruction", "Reconstruction with Patches"]
    gallery_out = [show_image(out) for out in [x[0], im_masked[0], y[0], im_paste[0]]]

    return list(zip(gallery_out, gallery_labels))

  


# Build the demo UI: an image input next to a gallery of reconstruction results.
with gr.Blocks() as demo:
    gr.Markdown("## ViTMAE Demo")
    gr.Markdown("**ViTMAE is an architecture that combines masked autoencoder and Vision Transformer (ViT) for self-supervised pre-training.**")
    gr.Markdown("**By pre-training a ViT to reconstruct pixel values for masked patches, one can get results after fine-tuning that outperform supervised pre-training.**")
    gr.Markdown("**This application demonstrates the reconstruction. To start, simply upload an image.**")
    with gr.Row():
        input_img = gr.Image()
        output = gr.Gallery()

    # Re-run the reconstruction whenever the input image changes.
    input_img.change(visualize, inputs=input_img, outputs=output)

    gr.Examples([["./cat.png"]], inputs=input_img, outputs=output, fn=visualize)

# Only start the server when executed as a script, so that importing this
# module (e.g. from a reloader or a test) does not launch a blocking server.
if __name__ == "__main__":
    demo.launch(debug=True)