from typing import Union

import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL import Image
from transformers import BlipProcessor

from caption_anything.utils.utils import load_image
from .base_captioner import BaseCaptioner
from .modeling_blip import BlipForConditionalGeneration


class BLIPCaptioner(BaseCaptioner):
    def __init__(self, device, enable_filter=False):
        super().__init__(device, enable_filter)
        self.device = device
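        # Half precision on CUDA halves memory use; fall back to fp32 on CPU,
        # where fp16 inference is poorly supported.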
        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
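        # The processor is the stock Hugging Face one; the model class is this
        # repo's patched BLIP, whose generate() also accepts pixel_masks.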
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large",
                                                                  torch_dtype=self.torch_dtype).to(self.device)

    @torch.no_grad()
    def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False, args=None):
        image = load_image(image, return_type="pil")
        inputs = self.processor(image, return_tensors="pt").to(self.device, self.torch_dtype)
        out = self.model.generate(**inputs, max_new_tokens=50)
        captions = self.processor.decode(out[0], skip_special_tokens=True).strip()

        result = {}
        if self.enable_filter and filter:
            # Score the caption against the source image so weak captions can be filtered.
            clip_score = self.filter_caption(image, captions)
            result['clip_score'] = clip_score
        result.update({'caption': captions})
        print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
        # Return the full result so a computed clip_score is not silently dropped.
        return result

    @torch.no_grad()
    def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg",
                                      filter=False, disable_regular_box=False):
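        # Caption only the segmented region: the mask is resized to the model's
        # input grid and forwarded so the patched BLIP can reduce its vision
        # tokens to that region (hence "reduced tokens").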
        result = {}
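        # Crop the image around the mask and save it; the saved path is
        # returned to the caller alongside the caption.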
        crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode,
                                                         disable_regular_box=disable_regular_box)
        image = load_image(image, return_type="pil")
        inputs = self.processor(image, return_tensors="pt")
        pixel_values = inputs.pixel_values.to(self.device, self.torch_dtype)
        _, _, H, W = pixel_values.shape
        # PIL's resize expects (width, height), so pass (W, H); BLIP inputs are
        # square, but the explicit order keeps this correct in general.
        seg_mask = Image.fromarray(seg_mask.astype(float))
        seg_mask = seg_mask.resize((W, H))
        # Binarize the resized mask, then cast to float for the model.
        seg_mask = (F.pil_to_tensor(seg_mask) > 0.5).float()
        pixel_masks = seg_mask.unsqueeze(0).to(self.device)
        # The patched BLIP generate (see .modeling_blip) uses pixel_masks to
        # restrict the vision tokens to the masked region.
        out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
        captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
        if self.enable_filter and filter:
            clip_score = self.filter_caption(image, captions)
            result['clip_score'] = clip_score
        result.update({'caption': captions, 'crop_save_path': crop_save_path})
        print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
        return result


if __name__ == '__main__':
    model = BLIPCaptioner(device='cuda:0')
    # Example inputs; an alternative test asset is kept commented out.
    # image_path = 'test_images/img10.jpg'
    # seg_mask = 'test_images/img10.jpg.raw_mask.png'
    image_path = 'test_images/img2.jpg'
    seg_mask = 'test_images/img2.jpg.raw_mask.png'
    print(f'process image {image_path}')
    print(model.inference_with_reduced_tokens(image_path, seg_mask))