# Spaces:
# Runtime error
# Runtime error
import os | |
import argparse | |
import pdb | |
import time | |
from PIL import Image | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
import easyocr | |
import copy | |
import time | |
from caption_anything.captioner import build_captioner, BaseCaptioner | |
from caption_anything.segmenter import build_segmenter, build_segmenter_densecap | |
from caption_anything.text_refiner import build_text_refiner | |
from caption_anything.utils.utils import prepare_segmenter, seg_model_map, load_image, get_image_shape | |
from caption_anything.utils.utils import mask_painter_foreground_all, mask_painter, xywh_to_x1y1x2y2, image_resize | |
from caption_anything.utils.densecap_painter import draw_bbox | |
class CaptionAnything:
    """Pipeline combining a segmenter, a captioner, an OCR reader and an optional text refiner.

    The segmenter produces masks from a prompt, the captioner describes each
    masked region, and the text refiner (when available) rewrites the raw
    caption according to user controls.
    """

    def __init__(self, args, api_key="", captioner=None, segmenter=None, ocr_reader=None, text_refiner=None):
        """Build (or reuse) the sub-models.

        Args:
            args: parsed options; this class reads ``captioner``, ``segmenter``,
                ``device``, ``disable_gpt``, ``text_refiner`` and the
                per-inference defaults used by :meth:`inference`.
            api_key: key handed to :meth:`init_refiner` when no ``text_refiner``
                is supplied; an empty string skips refiner creation.
            captioner: pre-built captioner to reuse instead of building one.
            segmenter: pre-built segmenter to reuse instead of building one.
            ocr_reader: pre-built OCR reader to reuse instead of ``easyocr.Reader``.
            text_refiner: pre-built text refiner to reuse.
        """
        self.args = args
        self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
        self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
        # The dense-caption segmenter shares the already loaded segmentation model.
        self.segmenter_densecap = build_segmenter_densecap(args.segmenter, args.device, args,
                                                           model=self.segmenter.model)
        self.ocr_lang = ["ch_tra", "en"]
        self.ocr_reader = ocr_reader if ocr_reader is not None else easyocr.Reader(self.ocr_lang)

        # Start the timer before the refiner is set up so the report below is meaningful.
        t0 = time.time()
        self.text_refiner = None
        if not args.disable_gpt:
            if text_refiner is not None:
                self.text_refiner = text_refiner
            elif api_key != "":
                self.init_refiner(api_key)
        self.require_caption_prompt = args.captioner == 'blip2'
        print('text_refiner init time: ', time.time() - t0)

    # The properties below forward to the segmenter / its predictor, so that
    # assigning them (see ``setup``) updates the underlying model state.
    @property
    def image_embedding(self):
        """Image embedding currently stored on the segmenter."""
        return self.segmenter.image_embedding

    @image_embedding.setter
    def image_embedding(self, image_embedding):
        self.segmenter.image_embedding = image_embedding

    @property
    def original_size(self):
        """``original_size`` stored on the segmenter's predictor."""
        return self.segmenter.predictor.original_size

    @original_size.setter
    def original_size(self, original_size):
        self.segmenter.predictor.original_size = original_size

    @property
    def input_size(self):
        """``input_size`` stored on the segmenter's predictor."""
        return self.segmenter.predictor.input_size

    @input_size.setter
    def input_size(self, input_size):
        self.segmenter.predictor.input_size = input_size

    def setup(self, image_embedding, original_size, input_size, is_image_set):
        """Restore previously computed predictor state onto the segmenter."""
        self.image_embedding = image_embedding
        self.original_size = original_size
        self.input_size = input_size
        self.segmenter.predictor.is_image_set = is_image_set

    def init_refiner(self, api_key):
        """Create the text refiner and probe it once; fall back to ``None`` on any failure."""
        try:
            self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
            self.text_refiner.llm('hi')  # test
        except Exception:
            # Best effort: the pipeline still works without a refiner (raw captions only).
            self.text_refiner = None
            print('OpenAI GPT is not available')

    def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False, verbose=False,
                  is_densecap=False, args=None):
        """Segment ``image`` with ``prompt`` and caption every resulting mask.

        Args:
            image: image passed unchanged to the segmenter and captioner.
            prompt: dict with at least ``'prompt_type'``; when that contains
                ``'everything'`` the segmenter is expected to return
                ``(masks, bboxes, areas)``, otherwise masks only.
            controls: passed to the text refiner.
            disable_gpt: skip the text refiner even if one is available.
            enable_wiki: forwarded to the text refiner.
            verbose: save each mask under ``result/`` and forward to the captioner.
            is_densecap: use the dense-caption segmenter instead of the default one.
            args: optional per-call overrides; missing keys fall back to
                ``self.args``. The dict is copied, never modified.

        Returns:
            A list with one dict per captioned mask (keys: ``generated_captions``,
            ``crop_save_path``, ``mask_save_path``, ``mask``, ``bbox``, ``area``,
            ``context_captions``, ``ppl_score``, ``clip_score``). Empty when no
            mask survives the area filter.
        """
        # segment with prompt
        print("CA prompt: ", prompt, "CA controls", controls)
        is_seg_everything = 'everything' in prompt['prompt_type']

        # Work on a private copy so neither the caller's dict nor a shared
        # default is mutated between calls.
        args = dict(args) if args is not None else {}
        for key in ('seg_crop_mode', 'clip_filter', 'disable_regular_box',
                    'context_captions', 'enable_reduce_tokens', 'enable_morphologyex'):
            args.setdefault(key, getattr(self.args, key))
        args['topN'] = args.get('topN', 10) if is_seg_everything else 1
        args.setdefault('min_mask_area', 0)

        segmenter = self.segmenter_densecap if is_densecap else self.segmenter
        seg_results = segmenter.inference(image, prompt)
        seg_masks, seg_bbox, seg_area = seg_results if is_seg_everything else (seg_results, None, None)

        if args['topN'] > 1:
            # Keep masks larger than the area threshold, then the first topN of them.
            samples = [sample for sample in zip(seg_masks, seg_bbox, seg_area)
                       if sample[2] > args['min_mask_area']]
            samples = samples[:args['topN']]
            if not samples:
                # zip(*[]) cannot be unpacked into three names; nothing to caption.
                return []
            seg_masks, seg_bbox, seg_area = zip(*samples)

        caption_with_mask = (self.captioner.inference_with_reduced_tokens
                             if args['enable_reduce_tokens'] else self.captioner.inference_seg)
        use_refiner = not disable_gpt and self.text_refiner is not None
        # The whole-image caption does not depend on the mask: compute it at most once.
        shared_context = None

        out_list = []
        for i, seg_mask in enumerate(seg_masks):
            if args['enable_morphologyex']:
                # Open then close with a 6x6 kernel to clean up the binary mask.
                kernel = np.ones((6, 6), np.uint8)
                seg_mask = 255 * seg_mask.astype(np.uint8)
                seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis=-1)
                seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel=kernel)
                seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel=kernel)
                seg_mask = seg_mask[:, :, 0] > 0

            seg_mask_img = Image.fromarray(seg_mask.astype('int') * 255.)
            mask_save_path = None

            if verbose:
                mask_save_path = f'result/mask_{time.time()}.png'
                os.makedirs(os.path.dirname(mask_save_path), exist_ok=True)

                if seg_mask_img.mode != 'RGB':
                    seg_mask_img = seg_mask_img.convert('RGB')
                seg_mask_img.save(mask_save_path)
                print('seg_mask path: ', mask_save_path)
                print("seg_mask.shape: ", seg_mask.shape)

            # captioning with mask
            result = caption_with_mask(image, seg_mask,
                                       crop_mode=args['seg_crop_mode'],
                                       filter=args['clip_filter'],
                                       disable_regular_box=args['disable_regular_box'],
                                       verbose=verbose,
                                       caption_args=args)
            caption = result.get('caption', None)
            crop_save_path = result.get('crop_save_path', None)

            # refining with TextRefiner
            context_captions = []
            if args['context_captions']:
                if shared_context is None:
                    shared_context = self.captioner.inference(image)['caption']
                context_captions.append(shared_context)
            if use_refiner:
                refined_caption = self.text_refiner.inference(query=caption, controls=controls,
                                                              context=context_captions,
                                                              enable_wiki=enable_wiki)
            else:
                refined_caption = {'raw_caption': caption}

            out_list.append({
                'generated_captions': refined_caption,
                'crop_save_path': crop_save_path,
                'mask_save_path': mask_save_path,
                'mask': seg_mask_img,
                'bbox': seg_bbox[i] if seg_bbox is not None else None,
                'area': seg_area[i] if seg_area is not None else None,
                'context_captions': context_captions,
                'ppl_score': result.get('ppl_score', -100.),
                'clip_score': result.get('clip_score', 0.),
            })
        return out_list

    def parse_dense_caption(self, image, topN=10, reference_caption=None, verbose=False):
        """Caption up to ``topN`` segmented regions and format them as a prompt fragment.

        Regions are filtered by a length-normalised ``ppl_score`` and by
        ``clip_score`` before formatting.

        Args:
            image: image handed to :meth:`inference`.
            topN: maximum number of regions to caption.
            reference_caption: forwarded to the captioner through the caption
                arguments; defaults to an empty list.
            verbose: also save mask and bounding-box visualisations under ``result/``.

        Returns:
            Comma-joined entries of the form
            ``(caption: X:cx, Y:cy, Width:w, Height:h)``.
        """
        if reference_caption is None:
            reference_caption = []
        prompt = {'prompt_type': ['everything']}
        densecap_args = {
            'return_ppl': True,
            'clip_filter': True,
            'reference_caption': reference_caption,
            'text_prompt': "",  # 'Question: what does the image show? Answer:'
            'seg_crop_mode': 'w_bg',
            'disable_regular_box': False,
            'topN': topN,
            'min_ppl_score': -1.8,
            'min_clip_score': 0.30,
            'min_mask_area': 2500,
        }

        dense_captions = self.inference(image, prompt,
                                        controls=None,
                                        disable_gpt=True,
                                        verbose=verbose,
                                        is_densecap=True,
                                        args=densecap_args)
        print('Process Dense Captioning: \n', dense_captions)

        def _normalised_ppl(cap):
            # ppl_score divided by (word count + 1) of the raw caption.
            return cap['ppl_score'] / (1 + len(cap['generated_captions']['raw_caption'].split()))

        dense_captions = [cap for cap in dense_captions
                          if _normalised_ppl(cap) >= densecap_args['min_ppl_score']]
        dense_captions = [cap for cap in dense_captions
                          if cap['clip_score'] >= densecap_args['min_clip_score']]

        dense_cap_prompt = []
        for cap in dense_captions:
            x, y, w, h = cap['bbox']
            cx, cy = x + w / 2, y + h / 2
            dense_cap_prompt.append("({}: X:{:.0f}, Y:{:.0f}, Width:{:.0f}, Height:{:.0f})".format(
                cap['generated_captions']['raw_caption'], cx, cy, w, h))

        if verbose:
            # No mask may have been saved (empty result), so make sure the folder exists.
            os.makedirs('result', exist_ok=True)
            all_masks = [np.array(item['mask'].convert('P')) for item in dense_captions]
            new_image = mask_painter_foreground_all(np.array(image), all_masks, background_alpha=0.4)
            save_path = 'result/dense_caption_mask.png'
            Image.fromarray(new_image).save(save_path)
            print(f'Dense captioning mask saved in {save_path}')

            vis_path = 'result/dense_caption_vis_{}.png'.format(time.time())
            dense_cap_painter_input = [{'bbox': xywh_to_x1y1x2y2(cap['bbox']),
                                        'caption': cap['generated_captions']['raw_caption']}
                                       for cap in dense_captions]
            draw_bbox(load_image(image, return_type='numpy'), vis_path, dense_cap_painter_input,
                      show_caption=True)
            print(f'Dense Captioning visualization saved in {vis_path}')
        return ','.join(dense_cap_prompt)

    def parse_ocr(self, image, thres=0.2):
        """Run OCR on ``image`` and format detections above ``thres`` as a prompt fragment.

        Returns:
            Newline-joined entries ``("text": X:cx, Y:cy)`` where ``cx``/``cy``
            are the mean of the four box corner coordinates.
        """
        image = load_image(image, return_type='numpy')
        bounds = self.ocr_reader.readtext(image)
        bounds = [bound for bound in bounds if bound[2] > thres]
        print('Process OCR Text:\n', bounds)

        ocr_prompt = []
        for box, text, conf in bounds:
            p0, p1, p2, p3 = box
            center_x = (p0[0] + p1[0] + p2[0] + p3[0]) / 4
            center_y = (p0[1] + p1[1] + p2[1] + p3[1]) / 4
            ocr_prompt.append('("{}": X:{:.0f}, Y:{:.0f})'.format(text, center_x, center_y))
        return '\n'.join(ocr_prompt)

    def inference_cap_everything(self, image, verbose=False):
        """Summarise the whole image into one paragraph via the text refiner's LLM.

        Combines the image size, a whole-image caption, dense captions and OCR
        text into one prompt. ``self.text_refiner`` must be initialised; with
        ``None`` the final ``llm`` call fails.
        """
        image = load_image(image, return_type='pil')
        image = image_resize(image, res=1024)
        width, height = get_image_shape(image)
        other_args = {'text_prompt': ""} if self.require_caption_prompt else {}
        img_caption = self.captioner.inference(image, filter=False, args=other_args)['caption']
        dense_caption_prompt = self.parse_dense_caption(image, topN=10, verbose=verbose, reference_caption=[])
        scene_text_prompt = self.parse_ocr(image, thres=0.2)
        # the summarize_prompt is modified from https://github.com/JialianW/GRiT and https://github.com/showlab/Image2Paragraph
        summarize_prompt = "Imagine you are a blind but intelligent image captioner. You should generate a descriptive, coherent and human-like paragraph based on the given information (a,b,c,d) instead of imagination:\na) Image Resolution: {image_size}\nb) Image Caption:{image_caption}\nc) Dense Caption: {dense_caption}\nd) Scene Text: {scene_text}\nThere are some rules for your response: Show objects with their attributes (e.g. position, color, size, shape, texture).\nPrimarily describe common objects with large size.\nProvide context of the image.\nShow relative position between objects.\nLess than 6 sentences.\nDo not appear number.\nDo not describe any individual letter.\nDo not show the image resolution.\nIngore the white background."
        prompt = summarize_prompt.format(
            image_size="width {} height {}".format(width, height),
            image_caption=img_caption,
            dense_caption=dense_caption_prompt,
            scene_text=scene_text_prompt,
        )
        print(f'caption everything prompt: {prompt}')
        response = self.text_refiner.llm(prompt).strip()
        return response
if __name__ == "__main__":
    # Demo entry point: caption every image found in ``img_dir`` and print its OCR text.
    from caption_anything.utils.parser import parse_augment

    args = parse_augment()

    # Example click prompt and refiner controls. They are not used by the loop
    # below; kept as a reference for the structure ``CaptionAnything.inference`` takes.
    prompts = [
        {
            "prompt_type": ["click"],
            "input_point": [[500, 300], [200, 500]],
            "input_label": [1, 0],
            "multimask_output": "True",
        },
    ]
    controls = {
        "length": "30",
        "sentiment": "positive",
        "imagination": "False",
        "language": "English",
    }

    # Raises KeyError when OPENAI_API_KEY is not set in the environment.
    model = CaptionAnything(args, os.environ['OPENAI_API_KEY'])
    img_dir = 'test_images/memes'
    for image_file in os.listdir(img_dir):
        image_path = os.path.join(img_dir, image_file)
        if not os.path.isfile(image_path):
            # Sub-directories cannot be loaded as images.
            continue
        print('image_path:', image_path)
        paragraph = model.inference_cap_everything(image_path, verbose=True)
        print('Caption Everything:\n', paragraph)
        ocr = model.parse_ocr(image_path)
        print('OCR', ocr)