"""Gradio demo: MAE reconstruction of a randomly masked image plus greedy captioning.

Loads the image-encoder weights from the ``tennant/MUG`` checkpoint into
``mae_vit_base_patch16`` and serves a 4-output interface: the masked input,
the input with masked patches filled in, the full reconstruction, and a caption.
"""
from collections import OrderedDict

import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoTokenizer

from model import MaskedAutoencoderViT, mae_vit_base_patch16

# Prefix under which the released checkpoint stores the image-encoder weights;
# the bare model expects parameter names without it.
CKPT_PREFIX = 'image_encoder.model.'
# Captions are capped at this many whitespace-separated pieces (prefix included).
CAPTION_MAX_LEN = 10
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
ckpt = torch.load(hf_hub_download('tennant/MUG', 'mae_bert_vit_b_cc3m.pth'), map_location='cpu')
new_dict = OrderedDict(
    (k[len(CKPT_PREFIX):] if k.startswith(CKPT_PREFIX) else k, v)
    for k, v in ckpt.items()
)
model = mae_vit_base_patch16(uni_dim=768, less_u=True)
model.load_state_dict(new_dict)
model.to(DEVICE)
model.eval()


@torch.no_grad()
def visual_recon(x, model, mask_ratio=0.75):
    """Randomly mask patches of ``x`` and reconstruct them with the MAE decoder.

    Args:
        x: normalized image batch in NCHW layout (as built by ``gr_caption``).
        model: MAE model providing ``patchify``/``unpatchify``,
            ``forward_encoder``, ``forward_decoder`` and ``patch_embed``.
        mask_ratio: fraction of patches hidden from the encoder.

    Returns:
        ``(masked, masked_recon, recon, latent)``. The first three are NHWC
        tensors: the input with masked patches zeroed, the input with masked
        patches replaced by the reconstruction, and the full reconstruction.
        ``latent`` is the encoder output, which is reused for captioning.
    """
    # Per-patch statistics of the input. The decoder output is mapped back to
    # pixel space with them (presumably a norm_pix_loss-trained MAE).
    target = model.patchify(x)
    mean = target.mean(dim=-1, keepdim=True)
    var = target.var(dim=-1, keepdim=True)

    latent, mask, ids_restore, _ = model.forward_encoder(x, mask_ratio=mask_ratio)
    y, _ = model.forward_decoder(latent, ids_restore)
    y = y * (var + 1.e-6) ** .5 + mean
    y = torch.einsum('nchw->nhwc', model.unpatchify(y))

    # Broadcast the per-patch mask to every pixel value: (N, H*W, p*p*3).
    patch_size = model.patch_embed.patch_size[0]
    mask = mask.unsqueeze(-1).repeat(1, 1, patch_size ** 2 * 3)
    mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
    mask = torch.einsum('nchw->nhwc', mask)

    x = torch.einsum('nchw->nhwc', x)
    masked = x * (1 - mask)
    return masked, masked + y * mask, y, latent


@torch.no_grad()
def caption_next_word(latent, model, tokenizer, prefix='a photo of a'):
    """Greedily predict the token that follows ``prefix`` for a single image.

    Args:
        latent: encoder output for exactly one image (batch size 1).
        model: model providing ``embed_text``, ``multimodal_layers`` and ``to_logits``.
        tokenizer: the BERT tokenizer the text head was trained with.
        prefix: text to continue.

    Returns:
        The decoded next token as a string. ``[SEP]`` may be returned and
        signals the end of the caption.
    """
    assert latent.shape[0] == 1, 'can only caption one image at a time'
    # Drop the final token (the [SEP] BERT's tokenizer appends) so the model
    # continues the text instead of seeing it as finished.
    x_l = torch.tensor(tokenizer([prefix])['input_ids'])[:, :-1]
    x_l = x_l.to(latent.device)

    x_l = model.embed_text(x_l)
    for cross_attn1, cross_attn2 in model.multimodal_layers:
        x_l = cross_attn1(x_l, latent)
        x_l = cross_attn2(x_l, latent)

    # Only the prediction at the last position matters for the next token.
    logits = model.to_logits(x_l)[0, -1]
    # Never emit padding/unknown/[CLS]/[MASK]; [SEP] stays allowed so captions can end.
    banned = [
        tokenizer.pad_token_id,
        tokenizer.unk_token_id,
        tokenizer.cls_token_id,
        tokenizer.mask_token_id,
    ]
    logits[banned] = float('-inf')
    return tokenizer.decode(int(logits.argmax(dim=-1)))


def caption(max_len, latent, model, tokenizer, prefix='a photo of a'):
    """Greedily extend ``prefix`` one token at a time.

    Generation stops when the caption reaches ``max_len`` whitespace-separated
    pieces (prefix included) or when the model predicts the separator token.
    The separator itself is not included in the returned caption.
    """
    words = prefix.split()
    while len(words) < max_len:
        next_word = caption_next_word(latent, model, tokenizer, prefix=' '.join(words))
        if next_word == tokenizer.sep_token:
            break
        # NOTE(review): WordPiece continuation pieces (e.g. '##s') are likely
        # appended as separate words here; consider merging them into words[-1].
        words.append(next_word)
    return ' '.join(words)


def _unnorm_pix(img):
    """Undo ImageNet normalization of a (1, H, W, C) tensor; clip to [0, 1] as a numpy array."""
    img = img.squeeze(0).cpu().detach().numpy()
    return np.clip(img * IMAGENET_STD + IMAGENET_MEAN, a_min=0., a_max=1.)


def gr_caption(x):
    """Gradio callback: reconstruct a masked view of ``x`` and caption it.

    Args:
        x: RGB image array from the ``gr.Image`` input, (H, W, 3). It is
            presumably in [0, 255], given the /255 scaling below.

    Returns:
        ``(masked, masked_recon, recon, caption)``. The images are (H, W, 3)
        float arrays in [0, 1]; ``caption`` is the generated text.

    Raises:
        gr.Error: if no image was supplied.
    """
    if x is None:
        raise gr.Error('Please upload an image first.')
    x = (np.array(x) / 255. - IMAGENET_MEAN) / IMAGENET_STD
    x = torch.tensor(x).float().unsqueeze(0)
    x = torch.einsum('nhwc->nchw', x).to(DEVICE)

    masked, masked_recon, recon, latent = visual_recon(x, model)
    caption_from_model = caption(CAPTION_MAX_LEN, latent, model, tokenizer)
    masked, masked_recon, recon = map(_unnorm_pix, (masked, masked_recon, recon))
    return masked, masked_recon, recon, caption_from_model


demo = gr.Interface(
    gr_caption,
    inputs=[gr.Image(shape=(224, 224))],
    outputs=[
        gr.Image(shape=(224, 224)),
        gr.Image(shape=(224, 224)),
        gr.Image(shape=(224, 224)),
        'text',
    ],
)
demo.launch()