serdaryildiz's picture
Upload 24 files
af06dba verified
raw
history blame
1.6 kB
import argparse
import glob
import os
import cv2
import numpy
import torch
from PIL import Image
from Model import TRCaptionNet, clip_transform
def _load_model(ckpt_path, device):
    """Build TRCaptionNet with the demo's fixed architecture and load weights.

    Args:
        ckpt_path: Path to a checkpoint whose dict holds the weights under "model".
        device: ``torch.device`` the weights are mapped to and the model is moved to.

    Returns:
        The model on *device*, switched to eval mode.
    """
    model = TRCaptionNet({
        "max_length": 35,
        "clip": "ViT-L/14",
        "bert": "dbmdz/bert-base-turkish-cased",
        "proj": True,
        "proj_num_head": 16,
    })
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code -- only load checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint["model"], strict=True)
    model = model.to(device)
    model.eval()
    return model


def _show_image(img, display_height=800):
    """Display a PIL image in the "image" OpenCV window and wait for a key press.

    The image is rescaled to *display_height* pixels tall, keeping its aspect ratio.
    """
    # Convert to RGB first so grayscale/palette/RGBA inputs all become a
    # 3-channel array; then reverse channels because OpenCV expects BGR.
    bgr = numpy.array(img.convert("RGB"))[:, :, ::-1]
    h, w = bgr.shape[:2]
    new_w = int(display_height * (w / h))
    bgr = cv2.resize(bgr, (new_w, display_height))
    cv2.imshow("image", bgr)
    cv2.waitKey(0)  # blocks until a key is pressed in the window


def demo(opt):
    """Caption every ``*.jpg`` image in a directory and display each one.

    Args:
        opt: Namespace with ``model_ckpt`` (checkpoint path), ``input_dir``
            (directory scanned for ``*.jpg`` files) and ``device`` (string
            accepted by ``torch.device``, e.g. ``"cuda:0"``).

    Side effects:
        Prints ``"<file name> : <caption>"`` for each image, in sorted path order.
        Shows each image in an OpenCV window and waits for a key press before
        moving to the next one.
    """
    device = torch.device(opt.device)
    preprocess = clip_transform(224)
    model = _load_model(opt.model_ckpt, device)

    image_paths = sorted(glob.glob(os.path.join(opt.input_dir, '*.jpg')))
    if not image_paths:
        print(f"No '*.jpg' images found in {opt.input_dir!r}")
        return

    for image_path in image_paths:
        # basename handles the platform's separator (split('/') breaks on Windows).
        img_name = os.path.basename(image_path)
        # Context manager closes the underlying file once we are done with it.
        with Image.open(image_path) as img0:
            batch = preprocess(img0).unsqueeze(0).to(device)
            # Inference only: skip autograd bookkeeping to save memory.
            with torch.no_grad():
                caption = model.generate(batch, min_length=11, repetition_penalty=1.6)[0]
            print(f"{img_name} :", caption)
            _show_image(img0)

    cv2.destroyAllWindows()
if __name__ == '__main__':
    # Command-line entry point. Every option is a plain string with a
    # default, so the demo runs out of the box from the repository root.
    arg_parser = argparse.ArgumentParser(description='Turkish-Image-Captioning!')
    for flag, fallback in (
        ('--model-ckpt', './checkpoints/TRCaptionNet_L14_berturk.pth'),
        ('--input-dir', './images/'),
        ('--device', 'cuda:0'),
    ):
        arg_parser.add_argument(flag, type=str, default=fallback)
    demo(arg_parser.parse_args())