# Spaces:
# Runtime error
# Runtime error
from collections import OrderedDict

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
from PIL import ImageOps
from transformers import AutoTokenizer

from model import MaskedAutoencoderViT, mae_vit_base_patch16
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Key prefix carried by the image-branch weights in the MUG checkpoint; it is
# stripped so the names match the standalone MAE model built below.
_CKPT_PREFIX = 'image_encoder.model.'

# NOTE(review): torch.load unpickles the file; it is downloaded from the
# 'tennant/MUG' hub repo. Consider weights_only=True if the installed torch
# supports it and the checkpoint holds only tensors.
ckpt = torch.load(hf_hub_download('tennant/MUG', 'mae_bert_vit_b_cc3m.pth'),
                  map_location='cpu')

# Only strip the prefix when it is actually present: slicing unconditionally
# would silently mangle any other key into a garbled name, hiding which key
# load_state_dict is complaining about.
new_dict = OrderedDict(
    (k[len(_CKPT_PREFIX):] if k.startswith(_CKPT_PREFIX) else k, v)
    for k, v in ckpt.items()
)

model = mae_vit_base_patch16(uni_dim=768, less_u=True)
model.load_state_dict(new_dict)
if torch.cuda.is_available():
    model.cuda()
model.eval()
def visual_recon(x, model, mask_ratio=0.75):
    """Randomly mask ``x``, run the MAE encoder/decoder and rebuild the image.

    Args:
        x: image batch in NCHW layout, already normalized.
        model: MAE model providing ``patchify``/``unpatchify``,
            ``forward_encoder``/``forward_decoder`` and ``patch_embed``.
        mask_ratio: fraction of patches the encoder drops.

    Returns:
        Tuple ``(masked, pasted, recon, latent)``. ``masked`` is the input with
        the dropped patches zeroed. ``pasted`` is the input with the
        reconstruction filled into the dropped patches. ``recon`` is the full
        reconstruction. These three images are channel-last (NHWC).
        ``latent`` is the encoder output.
    """
    patches = model.patchify(x)
    # The decoder output is de-normalized with the true patches' mean/std.
    # This presumably mirrors per-patch-normalized training targets.
    patch_mean = patches.mean(dim=-1, keepdim=True)
    patch_std = (patches.var(dim=-1, keepdim=True) + 1.e-6) ** .5

    latent, patch_mask, ids_restore, _ = model.forward_encoder(x, mask_ratio=mask_ratio)
    pred, _ = model.forward_decoder(latent, ids_restore)
    recon = model.unpatchify(pred * patch_std + patch_mean).permute(0, 2, 3, 1)

    # Broadcast the per-patch mask (1 = removed, 0 = kept) to every pixel value
    # of its patch, then lay it out as an image like the reconstruction.
    values_per_patch = model.patch_embed.patch_size[0] ** 2 * 3
    pixel_mask = patch_mask.unsqueeze(-1).repeat(1, 1, values_per_patch)
    pixel_mask = model.unpatchify(pixel_mask).permute(0, 2, 3, 1)

    image = x.permute(0, 2, 3, 1)
    visible = image * (1 - pixel_mask)
    return visible, visible + recon * pixel_mask, recon, latent
def caption_next_word(latent, model, tokenizer, prefix='a photo of a',
                      banned_ids=(0, 100, 101, 103)):
    """Greedily predict the token that follows ``prefix`` for a single image.

    Args:
        latent: encoder output for exactly one image (leading dim must be 1).
        model: provides ``embed_text``, ``multimodal_layers`` (pairs of
            cross-attention modules applied in order) and ``to_logits``.
        tokenizer: HF tokenizer used to encode ``prefix`` and decode the result.
        prefix: text the prediction is conditioned on.
        banned_ids: token ids that are never returned. The defaults are
            [PAD], [UNK], [CLS] and [MASK] in the ``bert-base-uncased`` vocab.

    Returns:
        The decoded token string. It may be a WordPiece continuation
        (e.g. '##ing') or '[SEP]'.

    Raises:
        ValueError: if ``latent`` holds more than one image.
    """
    if latent.shape[0] != 1:
        raise ValueError('can only caption one image at a time')
    # Drop the trailing [SEP] the tokenizer appends so the model continues
    # the text instead of seeing it as finished.
    ids = torch.tensor(tokenizer([prefix, ])['input_ids'])[:, :-1]
    if torch.cuda.is_available():
        ids = ids.cuda()
    with torch.no_grad():
        hidden = model.embed_text(ids)
        for cross_attn1, cross_attn2 in model.multimodal_layers:
            hidden = cross_attn1(hidden, latent)
            hidden = cross_attn2(hidden, latent)
        # Only the last position's logits matter for the next token.
        logits = model.to_logits(hidden)[0, -1]
        if banned_ids:
            # -inf (rather than a finite "large negative" value) guarantees
            # a banned token can never win the argmax.
            logits[list(banned_ids)] = float('-inf')
        next_id = logits.argmax(dim=-1).item()
    return tokenizer.decode(next_id)
def caption(max_len, latent, model, tokenizer, prefix='a photo of a'):
    """Greedily extend ``prefix`` one token at a time into a caption.

    Args:
        max_len: cap on the number of prefix words plus generated tokens.
            It is compared with ``<``, so a float such as a ``gr.Number``
            value works.
        latent, model, tokenizer: forwarded to ``caption_next_word``.
        prefix: whitespace-separated words the caption starts with.

    Returns:
        The caption text. Generation stops at '[SEP]', which is not included
        in the returned string.
    """
    words = prefix.split()
    # Count model steps separately from words: merging a '##' piece does not
    # add a word, so a bound on the word count alone might never be reached.
    steps = len(words)
    while steps < max_len:
        token = caption_next_word(latent, model, tokenizer, prefix=' '.join(words))
        steps += 1
        if token == '[SEP]':
            break
        if token.startswith('##') and words:
            # WordPiece continuation: glue it onto the previous word. Joined
            # with a space, the BERT tokenizer would split the '#'s into
            # punctuation tokens and corrupt the next step's prefix.
            words[-1] += token[2:]
        else:
            words.append(token)
    return ' '.join(words)
def gr_caption(x, mask_ratio=0.75, max_len=20, prefix='a', img_size=224):
    """Gradio callback: mask and reconstruct ``x``, then caption it.

    Args:
        x: input image from ``gr.Image``. This is a numpy array with Gradio's
            default ``type='numpy'``; a PIL image is also accepted.
        mask_ratio: fraction of patches the encoder drops.
        max_len: caption length cap forwarded to ``caption``.
        prefix: text the caption starts with.
        img_size: side length the image is center-cropped/resized to before
            encoding. 224 matches the size the interface originally requested.

    Returns:
        An (H, 3*W, 3) float image in [0, 1] laid out as
        [masked | masked + reconstruction | reconstruction], and the caption.
    """
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])

    # Resize here instead of relying on gr.Image(shape=...), which Gradio 4
    # removed. ImageOps.fit center-crops to the target aspect ratio before
    # resizing, so the picture is not distorted. convert('RGB') also copes
    # with RGBA/grayscale uploads, so the 3-channel normalization applies.
    pil_img = x if isinstance(x, Image.Image) else Image.fromarray(np.asarray(x).astype(np.uint8))
    pil_img = ImageOps.fit(pil_img.convert('RGB'), (img_size, img_size))

    arr = (np.asarray(pil_img) / 255. - imagenet_mean) / imagenet_std
    x = torch.tensor(arr).float().unsqueeze(0)
    x = torch.einsum('nhwc->nchw', x)
    if torch.cuda.is_available():
        x = x.cuda()

    def unnorm_pix(img):
        # Undo the ImageNet normalization and drop the batch dim for display.
        img = img.squeeze(0).cpu().detach().numpy()
        img = img * imagenet_std + imagenet_mean
        return np.clip(img, a_min=0., a_max=1.)

    # Inference only: skip autograd bookkeeping for the encoder/decoder pass
    # and the repeated caption decoding steps.
    with torch.no_grad():
        masked, masked_recon, recon, latent = visual_recon(x, model, mask_ratio=mask_ratio)
        caption_from_model = caption(max_len, latent, model, tokenizer, prefix=prefix)
    masked, masked_recon, recon = map(unnorm_pix, (masked, masked_recon, recon))
    return_img = np.concatenate([masked, masked_recon, recon], axis=1)
    return return_img, caption_from_model


import gradio as gr

# No `shape=` arguments: that parameter no longer exists in Gradio 4, and
# gr_caption now does its own resizing, so this works on Gradio 3 and 4.
demo = gr.Interface(
    gr_caption,
    inputs=[gr.Image(value='cat.jpeg'),
            gr.Number(value=0.75, label='mask ratio'),
            gr.Number(value=20, label='max length'),
            gr.Textbox(value='a photo of a', label='caption prefix')],
    outputs=[gr.Image(), 'text'],
    # examples=[['cat.jpeg', 0.75, 20, 'a photo of a']],
)
demo.launch()