File size: 2,920 Bytes
c426a27
 
 
 
 
 
 
 
ff883a7
 
c426a27
 
 
 
 
 
 
 
 
ff883a7
 
 
 
 
c426a27
 
 
ff883a7
c426a27
 
ff883a7
5c74464
c426a27
5c74464
c426a27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff883a7
c426a27
 
 
 
 
 
 
 
 
 
af88c78
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
from PIL import Image, ImageDraw, ImageOps
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import json
import pdb
import cv2
import numpy as np
from typing import Union

from tools import is_platform_win
from .base_captioner import BaseCaptioner

class BLIP2Captioner(BaseCaptioner):
    def __init__(self, device, dialogue: bool = False, enable_filter: bool = False):
        super().__init__(device, enable_filter)
        self.device = device
        self.dialogue = dialogue
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
        if is_platform_win():
            self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map="sequential", torch_dtype=self.torch_dtype)
        else:
            self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map='sequential', load_in_8bit=True)

    @torch.no_grad()
    def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
        if type(image) == str: # input path
            image = Image.open(image)

        if not self.dialogue:
            text_prompt = 'Question: what does the image show? Answer:'
            inputs = self.processor(image, text = text_prompt, return_tensors="pt").to(self.device, self.torch_dtype)
            out = self.model.generate(**inputs, max_new_tokens=50)
            captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
            if self.enable_filter and filter:
                captions = self.filter_caption(image, captions)
            print(f"\nProcessed ImageCaptioning by BLIP2Captioner, Output Text: {captions}")
            return captions
        else:
            context = []
            template = "Question: {} Answer: {}."
            while(True):
                input_texts = input()
                if input_texts == 'end':
                    break
                prompt = " ".join([template.format(context[i][0], context[i][1]) for i in range(len(context))]) + " Question: " + input_texts + " Answer:"
                inputs = self.processor(image, text = prompt, return_tensors="pt").to(self.device, self.torch_dtype)
                out = self.model.generate(**inputs, max_new_tokens=50)
                captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
                context.append((input_texts, captions))

        return captions

if __name__ == '__main__':

    dialogue = False
    model = BLIP2Captioner(device='cuda:4', dialogue = dialogue, cache_dir = '/nvme-ssd/fjj/Caption-Anything/model_cache')
    image_path = 'test_img/img2.jpg'
    seg_mask = np.zeros((224,224))
    seg_mask[50:200, 50:200] = 1
    print(f'process image {image_path}')
    print(model.inference_seg(image_path, seg_mask))