import torch
from transformers import AutoModelForVision2Seq, AutoProcessor
from transformers.tools import PipelineTool
from transformers.tools.base import get_default_device
from transformers.utils import requires_backends


class InstructBLIPImageQuestionAnsweringTool(PipelineTool):
    # default_checkpoint = "Salesforce/blip2-opt-2.7b"
    # default_checkpoint = "Salesforce/instructblip-flan-t5-xl"
    default_checkpoint = "Salesforce/instructblip-vicuna-7b"
    # default_checkpoint = "Salesforce/instructblip-vicuna-13b"
    description = (
        "This is a tool that answers a question about an image. It takes an input named `image` which should be the "
        "image containing the information, as well as a `question` which should be the question in English. It "
        "returns a text that is the answer to the question."
    )
    name = "image_qa"
    pre_processor_class = AutoProcessor
    model_class = AutoModelForVision2Seq
    inputs = ["image", "text"]
    outputs = ["text"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["vision"])
        super().__init__(*args, **kwargs)

    def setup(self):
        """
        Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
        """
        if isinstance(self.pre_processor, str):
            self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)

        if isinstance(self.model, str):
            # `load_in_4bit=True` quantizes the checkpoint at load time and requires the
            # `bitsandbytes` and `accelerate` packages to be installed.
            self.model = self.model_class.from_pretrained(
                self.model,
                **self.model_kwargs,
                **self.hub_kwargs,
                load_in_4bit=True,
                torch_dtype=torch.float16,
            )

        if self.post_processor is None:
            self.post_processor = self.pre_processor
        elif isinstance(self.post_processor, str):
            self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)

        if self.device is None:
            if self.device_map is not None:
                self.device = list(self.model.hf_device_map.values())[0]
            else:
                self.device = get_default_device()

        self.is_initialized = True

    def encode(self, image, question: str):
        # Cast floating-point inputs to float16 to match the model dtype, and use the
        # device resolved in `setup()` rather than hard-coding "cuda".
        return self.pre_processor(images=image, text=question, return_tensors="pt").to(
            device=self.device, dtype=torch.float16
        )

    def forward(self, inputs):
        # Beam search over 5 beams. Note that `top_p` and `temperature` only take
        # effect when `do_sample=True`, so generation here stays deterministic.
        outputs = self.model.generate(
            **inputs,
            num_beams=5,
            max_new_tokens=256,
            min_length=1,
            top_p=0.9,
            repetition_penalty=1.5,
            length_penalty=1.0,
            temperature=0.7,
        )
        return outputs

    def decode(self, outputs):
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
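

# Minimal usage sketch: instantiate the tool and query a local image directly, as the
# `description` above suggests. "photo.jpg" is a hypothetical path; running this assumes
# a CUDA-capable GPU and the `Pillow`, `bitsandbytes`, and `accelerate` packages.
if __name__ == "__main__":
    from PIL import Image

    tool = InstructBLIPImageQuestionAnsweringTool()
    image = Image.open("photo.jpg")
    print(tool(image, "What is shown in this image?"))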