# visual-chatgpt-zh-vits / modules/visual_question_answering.py
# The wildcard import mirrors the upstream repo layout: modules/utils.py is
# expected to re-export torch, PIL.Image, the Hugging Face BlipProcessor and
# BlipForQuestionAnswering classes, and the @prompts tool-registration decorator.
from modules.utils import *


class VisualQuestionAnswering:
    def __init__(self, device, pretrained_model_dir):
        print(f"Initializing VisualQuestionAnswering to {device}")
        # Use half precision on GPU to cut memory use; fall back to fp32 on CPU.
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.device = device
        self.processor = BlipProcessor.from_pretrained(f"{pretrained_model_dir}/blip-vqa-base")
        self.model = BlipForQuestionAnswering.from_pretrained(
            f"{pretrained_model_dir}/blip-vqa-base", torch_dtype=self.torch_dtype).to(self.device)

    @prompts(name="Answer Question About The Image",
             description="useful when you need an answer for a question based on an image. "
                         "like: what is the background color of the last image, how many cats in this figure, what is in this figure. "
                         "The input to this tool should be a comma separated string of two, representing the image_path and the question")
    def inference(self, inputs):
        # Split on the first comma only, so questions that themselves contain
        # commas are not truncated; strip stray whitespace around both parts.
        image_path, question = inputs.split(",", 1)
        image_path, question = image_path.strip(), question.strip()
        raw_image = Image.open(image_path).convert('RGB')
        # Use a distinct name so the `inputs` argument is not shadowed.
        model_inputs = self.processor(raw_image, question, return_tensors="pt").to(self.device, self.torch_dtype)
        out = self.model.generate(**model_inputs)
        answer = self.processor.decode(out[0], skip_special_tokens=True)
        print(f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input Question: {question}, "
              f"Output Answer: {answer}")
        return answer
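

# --- Usage sketch (not part of the upstream file) ---------------------------
# A minimal example of driving this tool directly, assuming a CUDA device and
# a local `pretrained_models/` directory containing the blip-vqa-base weights;
# both paths below are hypothetical placeholders, not values from the repo.
if __name__ == "__main__":
    vqa = VisualQuestionAnswering(device="cuda:0",
                                  pretrained_model_dir="pretrained_models")
    # The tool expects a single comma separated string: "<image_path>,<question>".
    print(vqa.inference("image/example.png,how many cats are in this figure"))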