# File size: 3,348 Bytes (commit 58cfd1b)
import numpy as np
import torch
from transformers import GPT2TokenizerFast
from .models import VisionGPT2Model
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import matplotlib.pyplot as plt
from types import SimpleNamespace
import pathlib
from tkinter import filedialog
def download(url: str, filename: str) -> pathlib.Path:
    """Stream-download *url* to *filename* with a tqdm progress bar.

    Parameters
    ----------
    url : str
        Direct download link for the file.
    filename : str
        Destination path; parent directories are created as needed.

    Returns
    -------
    pathlib.Path
        The resolved path of the written file.

    Raises
    ------
    requests.HTTPError
        For 4xx/5xx responses (via ``raise_for_status``).
    RuntimeError
        For any other non-200 status that ``raise_for_status`` ignores.
    """
    import functools
    import shutil
    import requests
    from tqdm.auto import tqdm

    # Use the response as a context manager so the connection is released
    # even on error (the original leaked the Response object).
    with requests.get(url, stream=True, allow_redirects=True) as r:
        if r.status_code != 200:
            r.raise_for_status()  # Will only raise for 4xx codes, so...
            raise RuntimeError(
                f"Request to {url} returned status code {r.status_code}\n"
                " Please download the captioner.pt file manually from the "
                "link provided in the README.md file."
            )
        file_size = int(r.headers.get('Content-Length', 0))
        path = pathlib.Path(filename).expanduser().resolve()
        path.parent.mkdir(parents=True, exist_ok=True)
        desc = "(Unknown total file size)" if file_size == 0 else ""
        # Transparently decompress gzip/deflate content while streaming.
        r.raw.read = functools.partial(r.raw.read, decode_content=True)
        with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw:
            with path.open("wb") as f:
                shutil.copyfileobj(r_raw, f)
    return path
def main():
    """Load the captioning model (downloading weights if absent), prompt the
    user to select an image, and display the image with its generated caption.
    """
    # Architecture hyper-parameters; these must match the checkpoint below.
    model_config = SimpleNamespace(
        vocab_size=50257,        # GPT2 vocab size
        embed_dim=768,           # dim same for both VIT and GPT2
        num_heads=12,
        seq_len=1024,
        depth=12,
        attention_dropout=0.1,
        residual_dropout=0.1,
        mlp_ratio=4,
        mlp_dropout=0.1,
        emb_dropout=0.1,
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = VisionGPT2Model(model_config).to(device)
    try:
        sd = torch.load("captioner.pt", map_location=device)
    except FileNotFoundError:
        # Narrow exception: a bare `except:` would also swallow
        # KeyboardInterrupt and mask corrupt-checkpoint errors.
        print("Model not found. Downloading Model ")
        url = "https://drive.usercontent.google.com/download?id=1X51wAI7Bsnrhd2Pa4WUoHIXvvhIcRH7Y&export=download&authuser=0&confirm=t&uuid=ae5c4861-4411-4f81-88cd-66ea30b6fe2b&at=APZUnTWodeDt1upcQVMej2TDcADs%3A1722666079498"
        path = download(url, "captioner.pt")
        sd = torch.load(path, map_location=device)
    model.load_state_dict(sd)
    model.eval()
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    # Resize + normalize into the [-1, 1] range the model was trained with.
    # (`always_apply=True` dropped: Normalize already defaults to p=1.0 and
    # the kwarg is deprecated/removed in recent albumentations releases.)
    tfms = A.Compose([
        A.Resize(224, 224),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2(),
    ])
    test_img: str = filedialog.askopenfilename(
        title="Select an image",
        filetypes=(("jpeg files", "*.jpg"), ("png files", '*.png'), ("all files", "*.*")),
    )
    im = Image.open(test_img).convert("RGB")
    det = True        # generates deterministic results
    temp = 1.0        # when det is true, temp has no effect
    max_tokens = 50
    image = np.array(im)
    image: torch.Tensor = tfms(image=image)['image']
    image = image.unsqueeze(0).to(device)
    # Start the sequence with the BOS token.
    seq = torch.ones(1, 1).to(device).long() * tokenizer.bos_token_id
    # Inference only: disable autograd to save memory and time.
    with torch.no_grad():
        caption = model.generate(image, seq, max_tokens, temp, det)
    # .cpu() first — .numpy() raises on a CUDA tensor.
    caption = tokenizer.decode(caption.cpu().numpy(), skip_special_tokens=True)
    plt.imshow(im)
    plt.title(f"Predicted : {caption}")
    plt.axis('off')
    plt.show()
# Run only when executed as a script, not on import.
# (Fixed: stray ` |` extraction residue after main() broke the syntax.)
if __name__ == "__main__":
    main()