Spaces:
Runtime error
Runtime error
File size: 3,198 Bytes
e62d81d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import argparse
import numpy as np
import argparse
from PIL import Image
from bliva.models import load_model_and_preprocess
def disable_torch_init():
    """
    Skip torch's default parameter initialization to speed up model creation.

    The ``reset_parameters`` method of ``torch.nn.Linear`` and
    ``torch.nn.LayerNorm`` is replaced, class-wide, by a function that does
    nothing and returns ``None``.
    """
    import torch

    def _skip_init(self):
        # Intentionally empty: stands in for the real initialization routine.
        return None

    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        layer_cls.reset_parameters = _skip_init
def parse_args(argv=None):
    """
    Parse arguments from command line.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``answer_mc``, ``answer_qs``,
        ``model_name``, ``device``, ``img_path``, ``question`` and
        ``candidates``.

    Exits with a usage error (``SystemExit``) when required options are
    missing, or when multiple choice mode is the one that would run but no
    ``--candidates`` were supplied.
    """
    parser = argparse.ArgumentParser(description="Arguments for Evaluation")
    parser.add_argument(
        "--answer_mc",
        action="store_true",
        help="Whether to evaluate multiple choice question with candidates.",
    )
    parser.add_argument(
        "--answer_qs",
        action="store_true",
        help="Whether to evaluate only one question image.",
    )
    parser.add_argument("--model_name", type=str, default="bliva_vicuna")
    parser.add_argument("--device", type=str, default="cuda:0", help="Specify which gpu device to use.")
    parser.add_argument("--img_path", type=str, required=True, help="the path to the image")
    parser.add_argument("--question", type=str, required=True, help="the question to ask")
    parser.add_argument("--candidates", type=str, help="list of choices for multiple choice question")
    args = parser.parse_args(argv)

    # --answer_qs takes precedence when both flags are given, so candidates
    # are only mandatory when multiple choice mode is what will actually run.
    if args.answer_mc and not args.answer_qs and not args.candidates:
        parser.error("--candidates is required when --answer_mc is set")
    return args
def eval_one(image, question, model):
    """
    Run a single open-ended question through the model and print the reply.

    Args:
        image: Image input, handed to the model under the "image" key.
        question: Sequence whose first element is the question text; the
            whole sequence is handed to the model under the "prompt" key.
        model: Object providing ``generate(samples)``. The first element of
            its return value is printed as the answer.
    """
    samples = {"image": image, "prompt": question}
    answers = model.generate(samples)

    banner = "====================================="
    divider = "-------------------------------------"

    print(banner)
    print("Question:", question[0])
    print(divider)
    print("Outputs: ", answers[0])
def eval_candidates(image, question, candidates, model):
    """
    Let the model choose among candidate answers and print its choice.

    Args:
        image: Image input, handed to the model under the "image" key.
        question: Sequence whose first element is the question text; the
            whole sequence is handed to the model under the "prompt" key.
        candidates: Indexable collection of answer options.
        model: Object providing ``predict_class(samples, candidates)``. The
            value at ``[0][0]`` of its result is used as an index into
            ``candidates``.
    """
    samples = {"image": image, "prompt": question}
    predictions = model.predict_class(samples, candidates)

    banner = "====================================="
    divider = "-------------------------------------"

    print(banner)
    print("Question:", question[0])
    print(divider)
    print("Candidates:", candidates)
    print(divider)
    chosen_index = predictions[0][0]
    print("Outputs: ", candidates[chosen_index])
# Maps each supported model name to the model_type handed to the loader.
_MODEL_TYPES = {
    "bliva_vicuna": "vicuna7b",
    "bliva_flant5": "flant5xxl",
}


def main(args):
    """
    Load the requested model, preprocess one image and answer one question.

    Args:
        args: Namespace-like object with the attributes ``model_name``,
            ``device``, ``img_path``, ``question``, ``candidates``,
            ``answer_qs`` and ``answer_mc``.

    Raises:
        ValueError: If ``args.model_name`` is not a supported model, or if
            multiple choice mode is the one that would run but
            ``args.candidates`` is missing or empty.

    When ``answer_qs`` is set the question is answered free-form; otherwise,
    when ``answer_mc`` is set, the model picks one of the comma-separated
    candidates. If neither flag is set nothing is evaluated.
    """
    # Validate everything up front so bad input fails before the model is
    # loaded, instead of surfacing later as an unbound-variable or
    # attribute error.
    model_type = _MODEL_TYPES.get(args.model_name)
    if model_type is None:
        raise ValueError(
            "Unsupported model_name {!r}; expected one of {}".format(
                args.model_name, sorted(_MODEL_TYPES)
            )
        )
    # answer_qs takes precedence, so candidates only matter when it is off.
    run_multiple_choice = args.answer_mc and not args.answer_qs
    if run_multiple_choice and not args.candidates:
        raise ValueError("candidates are required when answer_mc is set")

    np.random.seed(0)
    disable_torch_init()

    model, vis_processors, _ = load_model_and_preprocess(
        name=args.model_name, model_type=model_type, is_eval=True, device=args.device
    )
    vis_processor = vis_processors["eval"]

    # The context manager closes the underlying file once the converted
    # copy has been produced.
    with Image.open(args.img_path) as raw_image:
        image = raw_image.convert('RGB')

    question = [args.question]
    image = vis_processor(image).unsqueeze(0).to(args.device)

    if args.answer_qs:
        eval_one(image, question, model)
    elif args.answer_mc:
        candidates = [candidate.strip() for candidate in args.candidates.split(",")]
        eval_candidates(image, question, candidates, model)
if __name__ == "__main__":
    # Script entry point: parse the command line and run the evaluation.
    main(parse_args())
|