"""
 * Tag2Text
 * Written by Xinyu Huang
"""
import argparse
import random

import numpy as np
import torch
import torchvision.transforms as transforms
from models.tag2text import tag2text_caption
from PIL import Image

parser = argparse.ArgumentParser(
    description="Tag2Text inferece for tagging and captioning"
)
parser.add_argument(
    "--image",
    metavar="DIR",
    help="path to dataset",
    default="images/1641173_2291260800.jpg",
)
parser.add_argument(
    "--pretrained",
    metavar="DIR",
    help="path to pretrained model",
    default="pretrained/tag2text_swin_14m.pth",
)
parser.add_argument(
    "--image-size",
    default=384,
    type=int,
    metavar="N",
    help="input image size (default: 448)",
)
parser.add_argument(
    "--thre", default=0.68, type=float, metavar="N", help="threshold value"
)
parser.add_argument(
    "--specified-tags", default="None", help="User input specified tags"
)


def inference(image, model, input_tag="None"):
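    """Run tagging and captioning on a preprocessed image tensor.

    Returns a tuple (predicted_tags, user_tags_or_None, caption). When
    input_tag is empty/"none"/"None", the caption is generated from the
    model's own predicted tags; otherwise a second pass re-captions the
    image conditioned on the user-specified tags.
    """
    # First pass: let the model predict its own tags and caption from them.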
    with torch.no_grad():
        caption, tag_predict = model.generate(
            image, tag_input=None, max_length=50, return_tag_predict=True
        )

    if input_tag == "" or input_tag == "none" or input_tag == "None":
        return tag_predict[0], None, caption[0]

    # If user input specified tags:
    else:
        input_tag_list = []
        input_tag_list.append(input_tag.replace(",", " | "))

        with torch.no_grad():
            caption, input_tag = model.generate(
                image, tag_input=input_tag_list, max_length=50, return_tag_predict=True
            )

        return tag_predict[0], input_tag[0], caption[0]


if __name__ == "__main__":
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
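
    # Preprocess: resize to the model input size and apply the standard
    # ImageNet mean/std normalization.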
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    transform = transforms.Compose(
        [
            transforms.Resize((args.image_size, args.image_size)),
            transforms.ToTensor(),
            normalize,
        ]
    )

    # Tag indices deleted because they tend to disturb captioning:
    # 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
    delete_tag_index = [127, 2961, 3351, 3265, 3338, 3355, 3359]

    # Load the Tag2Text caption model (Swin-B backbone).
    model = tag2text_caption(
        pretrained=args.pretrained,
        image_size=args.image_size,
        vit="swin_b",
        delete_tag_index=delete_tag_index,
    )
    model.threshold = args.thre  # threshold for tagging
    model.eval()

    model = model.to(device)
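
    # Load and preprocess the input image; convert("RGB") guards against
    # grayscale/RGBA files that would break the 3-channel normalization.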
    raw_image = Image.open(args.image).convert("RGB").resize(
        (args.image_size, args.image_size)
    )
    image = transform(raw_image).unsqueeze(0).to(device)

    tags, user_tags, caption = inference(image, model, args.specified_tags)
    print("Model Identified Tags: ", tags)
    print("User Specified Tags: ", user_tags)
    print("Image Caption: ", caption)