import torch from PIL import Image, ImageDraw, ImageOps from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering import json import pdb import cv2 import numpy as np from typing import Any, Union, List import time import clip from caption_anything.utils.utils import load_image def boundary(inputs): col = inputs.shape[1] inputs = inputs.reshape(-1) lens = len(inputs) start = np.argmax(inputs) end = lens - 1 - np.argmax(np.flip(inputs)) top = start // col bottom = end // col return top, bottom def new_seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]): if type(seg_mask) == str: seg_mask = Image.open(seg_mask) elif type(seg_mask) == np.ndarray: seg_mask = Image.fromarray(seg_mask) seg_mask = np.array(seg_mask) > 0 size = max(seg_mask.shape[0], seg_mask.shape[1]) top, bottom = boundary(seg_mask) left, right = boundary(seg_mask.T) return [left / size, top / size, right / size, bottom / size] def seg_to_box(seg_mask: Union[np.ndarray, Image.Image, str]): if type(seg_mask) == str: seg_mask = cv2.imread(seg_mask, cv2.IMREAD_GRAYSCALE) _, seg_mask = cv2.threshold(seg_mask, 127, 255, 0) elif type(seg_mask) == np.ndarray: assert seg_mask.ndim == 2 # only support single-channel segmentation mask seg_mask = seg_mask.astype('uint8') if seg_mask.dtype == 'bool': seg_mask = seg_mask * 255 contours, hierarchy = cv2.findContours(seg_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) contours = np.concatenate(contours, axis=0) rect = cv2.minAreaRect(contours) box = cv2.boxPoints(rect) if rect[-1] >= 45: newstart = box.argmin(axis=0)[1] # leftmost else: newstart = box.argmax(axis=0)[0] # topmost box = np.concatenate([box[newstart:], box[:newstart]], axis=0) box = np.int0(box) return box def get_w_h(rect_points): w = np.linalg.norm(rect_points[0] - rect_points[1], ord=2).astype('int') h = np.linalg.norm(rect_points[0] - rect_points[3], ord=2).astype('int') return w, h def cut_box(img, rect_points): w, h = get_w_h(rect_points) dst_pts = np.array([[h, 0], [h, w], [0, w], [0, 0], ], dtype="float32") transform = cv2.getPerspectiveTransform(rect_points.astype("float32"), dst_pts) cropped_img = cv2.warpPerspective(img, transform, (h, w)) return cropped_img class BaseCaptioner: def __init__(self, device, enable_filter=False): print(f"Initializing ImageCaptioning to {device}") self.device = device self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32 self.processor = None self.model = None self.enable_filter = enable_filter if enable_filter: self.filter, self.preprocess = clip.load('ViT-B/32', device) @torch.no_grad() def filter_caption(self, image: Union[np.ndarray, Image.Image, str], caption: str, reference_caption: List[str]=[]): image = load_image(image, return_type='pil') image = self.preprocess(image).unsqueeze(0).to(self.device) # (1, 3, 224, 224) captions = [caption] if len(reference_caption): captions.extend(reference_caption) text = clip.tokenize(captions).to(self.device) # (>1, 77) image_features = self.filter.encode_image(image) # (1, 512) text_features = self.filter.encode_text(text) # # (>1, 512) image_features /= image_features.norm(dim=-1, keepdim=True) text_features /= text_features.norm(dim=-1, keepdim=True) if len(reference_caption): similarity = torch.matmul(image_features, text_features.transpose(1, 0)) / 0.07 similarity = similarity.softmax(dim=1)[0, 0].item() else: similarity = torch.matmul(image_features, text_features.transpose(1, 0)).item() print(f'Clip score of the caption is {similarity}') return similarity def inference(self, image: Union[np.ndarray, Image.Image, str], filter: bool = False): raise NotImplementedError() def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, filter: bool = False): raise NotImplementedError() def inference_box(self, image: Union[np.ndarray, Image.Image, str], box: Union[list, np.ndarray], filter=False, verbose=False, caption_args={}): image = load_image(image, return_type="pil") if np.array(box).size == 4: # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners size = max(image.width, image.height) x1, y1, x2, y2 = box image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size))) elif np.array(box).size == 8: # four corners of an irregular rectangle image_crop = cut_box(np.array(image), box) crop_save_path = None if verbose: crop_save_path = f'result/crop_{time.time()}.png' Image.fromarray(image_crop).save(crop_save_path) print(f'croped image saved in {crop_save_path}') caption = self.inference(image_crop, filter, caption_args) caption.update({'crop_save_path': crop_save_path}) return caption def inference_seg(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str] = None, crop_mode="w_bg", filter=False, disable_regular_box=False, verbose=False, caption_args={}): if seg_mask is None: seg_mask = np.ones(image.size).astype(bool) image = load_image(image, return_type="pil") seg_mask = load_image(seg_mask, return_type="pil") seg_mask = seg_mask.resize(image.size) seg_mask = np.array(seg_mask) > 0 if crop_mode == "wo_bg": image = np.array(image) * seg_mask[:, :, np.newaxis] + (1 - seg_mask[:, :, np.newaxis]) * 255 image = np.uint8(image) else: image = np.array(image) if disable_regular_box: min_area_box = seg_to_box(seg_mask) else: min_area_box = new_seg_to_box(seg_mask) return self.inference_box(image, min_area_box, filter, verbose, caption_args) def generate_seg_cropped_image(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str], crop_mode="w_bg", disable_regular_box=False): image = load_image(image, return_type="pil") seg_mask = load_image(seg_mask, return_type="pil") seg_mask = seg_mask.resize(image.size) seg_mask = np.array(seg_mask) > 0 if crop_mode == "wo_bg": image = np.array(image) * seg_mask[:, :, np.newaxis] + (1 - seg_mask[:, :, np.newaxis]) * 255 else: image = np.array(image) if disable_regular_box: box = seg_to_box(seg_mask) else: box = new_seg_to_box(seg_mask) if np.array(box).size == 4: # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners size = max(image.shape[0], image.shape[1]) x1, y1, x2, y2 = box image_crop = np.array(image.crop((x1 * size, y1 * size, x2 * size, y2 * size))) elif np.array(box).size == 8: # four corners of an irregular rectangle image_crop = cut_box(np.array(image), box) crop_save_path = f'result/crop_{time.time()}.png' Image.fromarray(image_crop).save(crop_save_path) print(f'croped image saved in {crop_save_path}') return crop_save_path if __name__ == '__main__': model = BaseCaptioner(device='cuda:0') image_path = 'test_images/img2.jpg' seg_mask = np.zeros((15, 15)) seg_mask[5:10, 5:10] = 1 seg_mask = 'image/SAM/img10.jpg.raw_mask.png' print(model.inference_seg(image_path, seg_mask))