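"""
Evaluate a single image/question pair with a BLIVA model.

Example invocations (the script name, image path, and question are illustrative placeholders):

    python evaluate.py --answer_qs --model_name bliva_vicuna \
        --img_path examples/example.jpg --question "What is shown in the image?"

    python evaluate.py --answer_mc --model_name bliva_vicuna \
        --img_path examples/example.jpg --question "What animal is shown?" \
        --candidates "cat, dog, bird"
"""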
import argparse
import numpy as np
from PIL import Image
from bliva.models import load_model_and_preprocess

def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

def parse_args():
    """
    Parse arguments from command line.
    """
    parser = argparse.ArgumentParser(description="Arguments for Evaluation")
    parser.add_argument(
        "--answer_mc",
        action="store_true",
        default=False,
        help="Whether to evaluate a multiple-choice question with candidate answers."
    )
    parser.add_argument(
        "--answer_qs",
        action="store_true",
        default=False,
        help="Whether to evaluate a single open-ended question about the image."
    )

    parser.add_argument("--model_name", type=str, default="bliva_vicuna")
    parser.add_argument("--device", type=str, default="cuda:0", help="Specify which gpu device to use.")
    parser.add_argument("--img_path", type=str, required=True, help="the path to the image")
    parser.add_argument("--question", type=str, required=True,  help="the question to ask")
    parser.add_argument("--candidates", type=str,  help="list of choices for mulitple choice question")

    args = parser.parse_args()
    return args

def eval_one(image, question, model):
    """
    Evaluate one question
    """
    outputs = model.generate({"image": image, "prompt": question})
    print("=====================================")
    print("Question:", question[0])
    print("-------------------------------------")
    print("Outputs: ", outputs[0])


def eval_candidates(image, question, candidates, model):
    """
    Evaluate with candidates
    """
    outputs = model.predict_class({"image": image, "prompt": question}, candidates)
    print("=====================================")
    print("Question:", question[0])
    print("-------------------------------------")
    print("Candidates:", candidates)
    print("-------------------------------------")
    print("Outputs: ", candidates[outputs[0][0]])



def main(args):
    np.random.seed(0)

    disable_torch_init()
    
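    # Load the requested BLIVA variant together with its image preprocessors in eval mode.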
    if args.model_name == "bliva_vicuna":
        model, vis_processors, _ = load_model_and_preprocess(name=args.model_name, model_type="vicuna7b", is_eval=True, device=args.device)
    elif args.model_name == "bliva_flant5":
        model, vis_processors, _ = load_model_and_preprocess(name=args.model_name, model_type="flant5xxl", is_eval=True, device=args.device)
    else:
        raise ValueError(f"Unsupported model name: {args.model_name}")
    vis_processor = vis_processors["eval"]
    
    image = Image.open(args.img_path).convert('RGB')

    question = [args.question]
    
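    # Preprocess the image, add a batch dimension, and move the tensor to the target device.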
    image = vis_processor(image).unsqueeze(0).to(args.device)
   
    if args.answer_qs:
        eval_one(image, question, model)
    elif args.answer_mc:
        if not args.candidates:
            raise ValueError("--candidates must be provided when --answer_mc is set.")
        candidates = [candidate.strip() for candidate in args.candidates.split(",")]
        eval_candidates(image, question, candidates, model)


if __name__ == "__main__":
    args = parse_args()
    main(args)