Spaces:
Build error
Build error
import os | |
import argparse | |
import pdb | |
import time | |
from PIL import Image | |
import cv2 | |
import numpy as np | |
from caption_anything.captioner import build_captioner, BaseCaptioner | |
from caption_anything.segmenter import build_segmenter | |
from caption_anything.text_refiner import build_text_refiner | |
class CaptionAnything:
    """Caption the image region selected by a visual prompt.

    Pipeline: segment the region with ``self.segmenter``, caption the masked
    region with ``self.captioner``, then optionally rewrite the caption with
    ``self.text_refiner`` (probed with ``llm('hi')`` when it is built here).

    Args:
        args: configuration namespace. Attributes read by this class:
            ``captioner``, ``segmenter``, ``text_refiner``, ``device``,
            ``disable_gpt``, ``enable_morphologyex``, ``enable_reduce_tokens``,
            ``seg_crop_mode``, ``clip_filter``, ``disable_regular_box`` and
            ``context_captions``.
        api_key: forwarded to ``build_text_refiner`` when no refiner is given.
        captioner, segmenter, text_refiner: optional pre-built components.
            Any left as ``None`` is built from ``args``; the refiner is only
            created when ``args.disable_gpt`` is false.
    """

    def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
        self.args = args
        self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
        self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
        self.text_refiner = None
        if not args.disable_gpt:
            if text_refiner is not None:
                self.text_refiner = text_refiner
            else:
                self.init_refiner(api_key)

    # The properties below forward image state to the segmenter (and its
    # predictor). Without the decorators, the second ``def`` replaced the
    # getter and ``setup`` merely created plain attributes on ``self``,
    # leaving the segmenter untouched.
    @property
    def image_embedding(self):
        """Image embedding stored on the segmenter."""
        return self.segmenter.image_embedding

    @image_embedding.setter
    def image_embedding(self, image_embedding):
        self.segmenter.image_embedding = image_embedding

    @property
    def original_size(self):
        """``original_size`` stored on the segmenter's predictor."""
        return self.segmenter.predictor.original_size

    @original_size.setter
    def original_size(self, original_size):
        self.segmenter.predictor.original_size = original_size

    @property
    def input_size(self):
        """``input_size`` stored on the segmenter's predictor."""
        return self.segmenter.predictor.input_size

    @input_size.setter
    def input_size(self, input_size):
        self.segmenter.predictor.input_size = input_size

    def setup(self, image_embedding, original_size, input_size, is_image_set):
        """Install previously computed image state on the segmenter/predictor."""
        self.image_embedding = image_embedding
        self.original_size = original_size
        self.input_size = input_size
        self.segmenter.predictor.is_image_set = is_image_set

    def init_refiner(self, api_key):
        """Build the text refiner and probe it; leave it ``None`` on failure.

        Any ``Exception`` raised while building or probing disables refinement
        instead of aborting construction. ``KeyboardInterrupt``/``SystemExit``
        are no longer swallowed, and the failure reason is printed.
        """
        try:
            self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
            self.text_refiner.llm('hi')  # test call: raises if the refiner is unusable
        except Exception as err:
            self.text_refiner = None
            print(f'OpenAI GPT is not available ({type(err).__name__}: {err})')

    @staticmethod
    def _smooth_mask(seg_mask):
        """Morphologically open then close a binary mask (6x6 kernel); return a bool mask."""
        kernel = np.ones((6, 6), np.uint8)
        mask = 255 * seg_mask.astype(np.uint8)
        # One channel suffices: morphology is applied per channel, so the old
        # 3-channel stack produced the same values in channel 0.
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel=kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel=kernel)
        return mask > 0

    @staticmethod
    def _save_mask(seg_mask):
        """Save ``seg_mask`` as a black/white RGB PNG under ``result/``.

        Returns:
            ``(mask_save_path, seg_mask_img)``: the file path and the PIL image.
        """
        mask_save_path = f'result/mask_{time.time()}.png'
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(os.path.dirname(mask_save_path), exist_ok=True)
        # uint8 0/255 yields the same pixels as the former int->float64 route
        # without the large temporaries.
        seg_mask_img = Image.fromarray(seg_mask.astype(np.uint8) * 255).convert('RGB')
        seg_mask_img.save(mask_save_path)
        return mask_save_path, seg_mask_img

    def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
        """Segment, caption and (optionally) refine the region selected by ``prompt``.

        Args:
            image: passed unchanged to the segmenter and captioner (the demo
                below passes a file path).
            prompt: visual prompt forwarded to ``self.segmenter.inference``.
            controls: language controls forwarded to the text refiner.
            disable_gpt: skip the text refiner even if one is available.
            enable_wiki: forwarded to the text refiner.

        Returns:
            dict with keys ``generated_captions`` (refiner output, or
            ``{'raw_caption': caption}`` when no refiner is used),
            ``crop_save_path``, ``mask_save_path``, ``mask`` (PIL RGB image)
            and ``context_captions`` (list, empty unless
            ``args.context_captions`` is set).
        """
        # TODO: Add support to multiple seg masks.
        print("CA prompt: ", prompt, "CA controls", controls)
        # Segment with the prompt; only the first mask is used.
        seg_mask = self.segmenter.inference(image, prompt)[0, ...]
        if self.args.enable_morphologyex:
            seg_mask = self._smooth_mask(seg_mask)
        mask_save_path, seg_mask_img = self._save_mask(seg_mask)
        print('seg_mask path: ', mask_save_path)
        print("seg_mask.shape: ", seg_mask.shape)

        # Caption the masked region; both captioner entry points share options.
        caption_kwargs = dict(crop_mode=self.args.seg_crop_mode,
                              filter=self.args.clip_filter,
                              disable_regular_box=self.args.disable_regular_box)
        if self.args.enable_reduce_tokens:
            caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, **caption_kwargs)
        else:
            caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, **caption_kwargs)

        # Optional whole-image caption used as context for the refiner.
        context_captions = []
        if self.args.context_captions:
            context_captions.append(self.captioner.inference(image))

        # Refine with the TextRefiner when available and not disabled.
        if not disable_gpt and self.text_refiner is not None:
            refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions,
                                                          enable_wiki=enable_wiki)
        else:
            refined_caption = {'raw_caption': caption}

        return {'generated_captions': refined_caption,
                'crop_save_path': crop_save_path,
                'mask_save_path': mask_save_path,
                'mask': seg_mask_img,
                'context_captions': context_captions}
if __name__ == "__main__":
    # Demo: caption two clicked regions of a test image and print the results.
    from caption_anything.utils.parser import parse_augment

    args = parse_augment()
    # Alternative test image: 'test_images/img3.jpg'
    image_path = 'test_images/img1.jpg'
    # Visual controls forwarded to the segmenter (printed below as the SAM prompt).
    prompts = [
        {
            "prompt_type": ["click"],
            "input_point": [[500, 300], [200, 500]],
            "input_label": [1, 0],
            "multimask_output": "True",
        },
        {
            "prompt_type": ["click"],
            "input_point": [[300, 800]],
            "input_label": [1],
            "multimask_output": "True",
        },
    ]
    # Language controls forwarded to the text refiner.
    controls = {
        "length": "30",
        "sentiment": "positive",
        "imagination": "False",  # alternative value: "True"
        "language": "English",
    }
    # A missing key no longer raises KeyError here; an empty key is passed on
    # and CaptionAnything.init_refiner falls back to raw captions if the
    # refiner cannot be built or probed.
    model = CaptionAnything(args, os.environ.get('OPENAI_API_KEY', ''))
    # Open the image once (and close it afterwards) instead of leaking a new
    # handle on every iteration; inference itself receives the path.
    with Image.open(image_path) as image:
        for prompt in prompts:
            print('*' * 30)
            print('Image path: ', image_path)
            print(image)
            print('Visual controls (SAM prompt):\n', prompt)
            print('Language controls:\n', controls)
            out = model.inference(image_path, prompt, controls)
            print('Generated captions:\n', out['generated_captions'])