# visual-chatgpt-zh-vits/modules/image_captioning.py
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

from modules.utils import *  # provides the `prompts` tool-registration decorator


class ImageCaptioning:
    def __init__(self, device, pretrained_model_dir):
        print(f"Initializing ImageCaptioning to {device}")
        self.device = device
        # Half precision on GPU saves memory; fall back to fp32 on CPU.
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.processor = BlipProcessor.from_pretrained(f"{pretrained_model_dir}/blip-image-captioning-base")
        self.model = BlipForConditionalGeneration.from_pretrained(
            f"{pretrained_model_dir}/blip-image-captioning-base", torch_dtype=self.torch_dtype).to(self.device)

    @prompts(name="Get Photo Description",
             description="useful when you want to know what is inside the photo. "
                         "Receives image_path as input. The input to this tool "
                         "should be a string representing the image_path.")
    def inference(self, image_path):
        # Preprocess the image and match the model's device and dtype.
        inputs = self.processor(Image.open(image_path), return_tensors="pt").to(self.device, self.torch_dtype)
        out = self.model.generate(**inputs)
        captions = self.processor.decode(out[0], skip_special_tokens=True)
        print(f"\nProcessed ImageCaptioning, Input Image: {image_path}, Output Text: {captions}")
        return captions
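

# A minimal usage sketch, assuming the BLIP weights live under a local
# `./checkpoints/blip-image-captioning-base` directory and that a hypothetical
# `test.jpg` exists; both paths are illustrative, not part of this module.
if __name__ == "__main__":
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    captioner = ImageCaptioning(device=device, pretrained_model_dir="./checkpoints")
    print(captioner.inference("test.jpg"))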