serdaryildiz's picture
Update demo.py
7c4ef2b verified
import argparse
import glob
import os
import cv2
import numpy
import torch
from PIL import Image
from Model import TRCaptionNet, clip_transform
def _collect_image_paths(input_dir, extensions=('.jpg', '.jpeg', '.png')):
    """Return the sorted paths of the image files directly inside ``input_dir``.

    The extension match is case-insensitive, so ``photo.JPG`` is found on
    case-sensitive file systems too. A missing directory yields ``[]``.
    """
    return sorted(
        path
        for path in glob.glob(os.path.join(input_dir, '*'))
        if os.path.isfile(path) and path.lower().endswith(extensions)
    )


def _to_display_bgr(rgb_image, height=800):
    """Convert a PIL RGB image to a BGR array resized to ``height`` px, keeping aspect ratio."""
    # cvtColor returns a contiguous BGR copy; the previous negative-stride
    # slice ``[:, :, ::-1]`` is rejected by some OpenCV bindings.
    bgr = cv2.cvtColor(numpy.asarray(rgb_image), cv2.COLOR_RGB2BGR)
    h, w = bgr.shape[:2]
    # max(1, ...) keeps very tall, narrow images from producing width 0,
    # which cv2.resize refuses.
    new_w = max(1, int(height * (w / h)))
    return cv2.resize(bgr, (new_w, height))


def demo(opt):
    """Caption every image in ``opt.input_dir`` with TRCaptionNet and display it.

    Args:
        opt: Parsed command-line namespace providing ``model_ckpt`` (path to
            the checkpoint file), ``input_dir`` (directory scanned for
            .jpg/.jpeg/.png files) and ``device`` (torch device string,
            e.g. ``'cuda:0'`` or ``'cpu'``).

    For each image (in sorted path order), prints ``"<file name> : <caption>"``.
    It then shows the image in an OpenCV window scaled to 800 px high, and
    waits for a key press before moving to the next image.
    """
    preprocess = clip_transform(224)
    # Architecture hyper-parameters; presumably these must match the ones the
    # checkpoint was trained with, since the weights are loaded with strict=True.
    model = TRCaptionNet({
        "max_length": 35,
        "clip": "ViT-L/14",
        "bert": "dbmdz/bert-base-turkish-cased",
        "proj": True,
        "proj_num_head": 16,
    })
    device = torch.device(opt.device)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(opt.model_ckpt, map_location=device)
    model.load_state_dict(checkpoint["model"], strict=True)
    model = model.to(device)
    model.eval()

    image_paths = _collect_image_paths(opt.input_dir)
    if not image_paths:
        print(f"No images found in {opt.input_dir!r}")
        return

    for image_path in image_paths:
        # basename is portable; splitting on '/' fails for Windows paths.
        img_name = os.path.basename(image_path)
        # Decode fully and release the file handle. convert("RGB") also
        # normalises grayscale / CMYK / RGBA inputs to 3 channels.
        with Image.open(image_path) as img:
            img0 = img.convert("RGB")
        batch = preprocess(img0).unsqueeze(0).to(device)
        # Inference only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            caption = model.generate(batch, min_length=11, repetition_penalty=1.6)[0]
        print(f"{img_name} :", caption)
        cv2.imshow("image", _to_display_bgr(img0))
        cv2.waitKey(0)
    cv2.destroyAllWindows()
def build_arg_parser():
    """Build the command-line parser for the captioning demo.

    ``--device`` defaults to ``'cuda:0'`` when CUDA is available and to
    ``'cpu'`` otherwise, so the demo also runs on CPU-only machines.
    """
    parser = argparse.ArgumentParser(description='Turkish-Image-Captioning!')
    parser.add_argument('--model-ckpt', type=str,
                        default='./checkpoints/TRCaptionNet_L14_berturk.pth',
                        help='path to the TRCaptionNet checkpoint file')
    parser.add_argument('--input-dir', type=str, default='./images/',
                        help='directory containing the images to caption')
    parser.add_argument('--device', type=str,
                        default='cuda:0' if torch.cuda.is_available() else 'cpu',
                        help="torch device to run on, e.g. 'cuda:0' or 'cpu'")
    return parser


if __name__ == '__main__':
    demo(build_arg_parser().parse_args())