|
import torch |
|
from modules.utils import * |
|
|
|
class ImageCaptioning:
    """Generate a text caption for an image file with a BLIP captioning model.

    The processor and the model are loaded once, at construction time, from
    the ``blip-image-captioning-base`` sub-directory of
    ``pretrained_model_dir``.
    """

    def __init__(self, device, pretrained_model_dir):
        """Load the BLIP processor and model and move the model to ``device``.

        Args:
            device: Target device, e.g. ``"cpu"`` or ``"cuda:0"``. A
                ``torch.device`` instance is accepted as well.
            pretrained_model_dir: Directory that contains the
                ``blip-image-captioning-base`` checkpoint folder.
        """
        print(f"Initializing ImageCaptioning to {device}")
        self.device = device
        # Compare against str(device): a torch.device does not support the
        # `in` operator, so the plain membership test would raise TypeError.
        # Half precision is used on CUDA devices, full precision elsewhere.
        self.torch_dtype = torch.float16 if 'cuda' in str(device) else torch.float32

        # Processor and model are read from the same checkpoint directory.
        checkpoint = f"{pretrained_model_dir}/blip-image-captioning-base"
        self.processor = BlipProcessor.from_pretrained(checkpoint)
        self.model = BlipForConditionalGeneration.from_pretrained(
            checkpoint, torch_dtype=self.torch_dtype).to(self.device)

    @prompts(name="Get Photo Description",
             description="useful when you want to know what is inside the photo. receives image_path as input. "
                         "The input to this tool should be a string, representing the image_path. ")
    def inference(self, image_path):
        """Return a caption describing the image stored at ``image_path``.

        Args:
            image_path: Path of the image file to caption.

        Returns:
            The decoded caption text of the first generated sequence, with
            special tokens removed.
        """
        # Preprocess while the file is open, then release the handle right
        # away instead of leaving it to the garbage collector.
        # NOTE(review): assumes `Image` is PIL.Image (context-manager capable).
        with Image.open(image_path) as image:
            inputs = self.processor(image, return_tensors="pt")
        inputs = inputs.to(self.device, self.torch_dtype)

        out = self.model.generate(**inputs)
        captions = self.processor.decode(out[0], skip_special_tokens=True)
        print(f"\nProcessed ImageCaptioning, Input Image: {image_path}, Output Text: {captions}")
        return captions