import torch
import numpy as np
from PIL import Image
from typing import Union
from transformers import AutoProcessor, Blip2ForConditionalGeneration
from tools import is_platform_win
from .base_captioner import BaseCaptioner

class BLIP2Captioner(BaseCaptioner):
    def __init__(self, device, dialogue: bool = False, enable_filter: bool = False):
        super().__init__(device, enable_filter)
        self.device = device
        self.dialogue = dialogue
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
        if is_platform_win():
            # 8-bit loading is not available on Windows, so fall back to fp16/fp32 weights.
            self.model = Blip2ForConditionalGeneration.from_pretrained(
                "Salesforce/blip2-opt-2.7b", device_map="sequential", torch_dtype=self.torch_dtype)
        else:
            self.model = Blip2ForConditionalGeneration.from_pretrained(
                "Salesforce/blip2-opt-2.7b", device_map="sequential", load_in_8bit=True)
    @torch.no_grad()
    def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
        if isinstance(image, str):  # input is a file path
            image = Image.open(image)

        if not self.dialogue:
            # Single-turn captioning with a fixed question prompt.
            text_prompt = 'Question: what does the image show? Answer:'
            inputs = self.processor(image, text=text_prompt, return_tensors="pt").to(self.device, self.torch_dtype)
            out = self.model.generate(**inputs, max_new_tokens=50)
            captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
            if self.enable_filter and filter:
                captions = self.filter_caption(image, captions)
            print(f"\nProcessed ImageCaptioning by BLIP2Captioner, Output Text: {captions}")
            return captions
        else:
            # Multi-turn dialogue: keep asking questions until the user types 'end'.
            context = []
            template = "Question: {} Answer: {}."
            captions = ""  # ensure a string is returned even if the user ends immediately
            while True:
                input_texts = input()
                if input_texts == 'end':
                    break
                # Prepend the conversation history so the model answers in context.
                prompt = " ".join([template.format(q, a) for q, a in context]) + " Question: " + input_texts + " Answer:"
                inputs = self.processor(image, text=prompt, return_tensors="pt").to(self.device, self.torch_dtype)
                out = self.model.generate(**inputs, max_new_tokens=50)
                captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
                context.append((input_texts, captions))
            return captions
if __name__ == '__main__':
    dialogue = False
    model = BLIP2Captioner(device='cuda:4', dialogue=dialogue)
    image_path = 'test_img/img2.jpg'
    seg_mask = np.zeros((224, 224))
    seg_mask[50:200, 50:200] = 1
    print(f'process image {image_path}')
    print(model.inference_seg(image_path, seg_mask))
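
    # A minimal sketch of the interactive dialogue mode (assumes a CUDA device and
    # that image_path exists); not exercised by the demo above, so it is left commented out.
    # dialogue_model = BLIP2Captioner(device='cuda:0', dialogue=True)
    # dialogue_model.inference(image_path)  # type questions at the prompt, 'end' to stop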