import os
import argparse
import time

import cv2
import numpy as np
from PIL import Image

from caption_anything.captioner import build_captioner, BaseCaptioner
from caption_anything.segmenter import build_segmenter
from caption_anything.text_refiner import build_text_refiner

class CaptionAnything:
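    """Caption user-prompted image regions by chaining a segmenter (SAM),
    a captioner, and an optional GPT-based text refiner."""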

    def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
        self.args = args
        self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
        self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
        self.text_refiner = None
        if not args.disable_gpt:
            if text_refiner is not None:
                self.text_refiner = text_refiner
            else:
                self.init_refiner(api_key)

    @property
    def image_embedding(self):
        return self.segmenter.image_embedding

    @image_embedding.setter
    def image_embedding(self, image_embedding):
        self.segmenter.image_embedding = image_embedding

    @property
    def original_size(self):
        return self.segmenter.predictor.original_size

    @original_size.setter
    def original_size(self, original_size):
        self.segmenter.predictor.original_size = original_size

    @property
    def input_size(self):
        return self.segmenter.predictor.input_size

    @input_size.setter
    def input_size(self, input_size):
        self.segmenter.predictor.input_size = input_size
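
    # Inject precomputed SAM predictor state so repeated prompts on the same
    # image can reuse the cached embedding instead of re-encoding the image.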
    def setup(self, image_embedding, original_size, input_size, is_image_set):
        self.image_embedding = image_embedding
        self.original_size = original_size
        self.input_size = input_size
        self.segmenter.predictor.is_image_set = is_image_set

    def init_refiner(self, api_key):
        try:
            self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
            self.text_refiner.llm('hi')  # sanity check that the API key is valid
        except Exception:
            self.text_refiner = None
            print('OpenAI GPT is not available')
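
    # Full pipeline: segment the prompted region, caption the masked crop,
    # and optionally refine the raw caption using the language controls.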
    def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
        # TODO: add support for multiple segmentation masks.
        # Segment with the visual prompt.
        print("CA prompt: ", prompt, "CA controls: ", controls)
        seg_mask = self.segmenter.inference(image, prompt)[0, ...]
        if self.args.enable_morphologyex:
            # Opening removes speckles; closing fills small holes in the mask.
            seg_mask = 255 * seg_mask.astype(np.uint8)
            seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis=-1)
            seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel=np.ones((6, 6), np.uint8))
            seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel=np.ones((6, 6), np.uint8))
            seg_mask = seg_mask[:, :, 0] > 0
        mask_save_path = f'result/mask_{time.time()}.png'
        os.makedirs(os.path.dirname(mask_save_path), exist_ok=True)
        # Cast to uint8 so PIL can build an 8-bit image from the boolean mask.
        seg_mask_img = Image.fromarray(seg_mask.astype(np.uint8) * 255)
        if seg_mask_img.mode != 'RGB':
            seg_mask_img = seg_mask_img.convert('RGB')
        seg_mask_img.save(mask_save_path)
        print('seg_mask path: ', mask_save_path)
        print('seg_mask.shape: ', seg_mask.shape)
        # captioning with mask
        if self.args.enable_reduce_tokens:
            caption, crop_save_path = self.captioner.inference_with_reduced_tokens(
                image, seg_mask,
                crop_mode=self.args.seg_crop_mode,
                filter=self.args.clip_filter,
                disable_regular_box=self.args.disable_regular_box)
        else:
            caption, crop_save_path = self.captioner.inference_seg(
                image, seg_mask,
                crop_mode=self.args.seg_crop_mode,
                filter=self.args.clip_filter,
                disable_regular_box=self.args.disable_regular_box)
        # refining with TextRefiner
        context_captions = []
        if self.args.context_captions:
            context_captions.append(self.captioner.inference(image))
        if not disable_gpt and self.text_refiner is not None:
            refined_caption = self.text_refiner.inference(query=caption, controls=controls,
                                                          context=context_captions, enable_wiki=enable_wiki)
        else:
            refined_caption = {'raw_caption': caption}
        out = {'generated_captions': refined_caption,
               'crop_save_path': crop_save_path,
               'mask_save_path': mask_save_path,
               'mask': seg_mask_img,
               'context_captions': context_captions}
        return out
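

# Demo: run the pipeline on a test image with two click prompts and fixed
# language controls; caption refinement is skipped if no OpenAI key is set.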
if __name__ == "__main__":
    from caption_anything.utils.parser import parse_augment

    args = parse_augment()
    # image_path = 'test_images/img3.jpg'
    image_path = 'test_images/img1.jpg'
    prompts = [
        {
            "prompt_type": ["click"],
            "input_point": [[500, 300], [200, 500]],
            "input_label": [1, 0],  # 1 = foreground click, 0 = background click
            "multimask_output": "True",
        },
        {
            "prompt_type": ["click"],
            "input_point": [[300, 800]],
            "input_label": [1],
            "multimask_output": "True",
        },
    ]
    controls = {
        "length": "30",
        "sentiment": "positive",
        # "imagination": "True",
        "imagination": "False",
        "language": "English",
    }

    model = CaptionAnything(args, os.environ.get('OPENAI_API_KEY', ''))
    for prompt in prompts:
        print('*' * 30)
        print('Image path: ', image_path)
        image = Image.open(image_path)
        print(image)
        print('Visual controls (SAM prompt):\n', prompt)
        print('Language controls:\n', controls)
        out = model.inference(image_path, prompt, controls)