ttengwang committed
Commit 108f2df • Parent(s): b7e072a

share ocr_reader to accelerate inference
Files changed:
- app.py +11 -3
- caption_anything/captioner/blip2.py +2 -2
- caption_anything/model.py +8 -5
app.py CHANGED
@@ -17,7 +17,7 @@ from caption_anything.text_refiner import build_text_refiner
 from caption_anything.segmenter import build_segmenter
 from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name
 from segment_anything import sam_model_registry
-
+import easyocr
 
 args = parse_augment()
 args.segmenter = "huge"
@@ -30,6 +30,8 @@ else:
 
 shared_captioner = build_captioner(args.captioner, args.device, args)
 shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
+ocr_lang = ["ch_tra", "en"]
+shared_ocr_reader = easyocr.Reader(ocr_lang)
 tools_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.chat_tools_dict.split(',')}
 shared_chatbot_tools = build_chatbot_tools(tools_dict)
 
@@ -57,13 +59,13 @@ class ImageSketcher(gr.Image):
         return super().preprocess(x)
 
 
-def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None,
+def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, ocr_reader=None, text_refiner=None,
                                        session_id=None):
     segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
     captioner = captioner
     if session_id is not None:
         print('Init caption anything for session {}'.format(session_id))
-    return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, text_refiner=text_refiner)
+    return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, ocr_reader=ocr_reader, text_refiner=text_refiner)
 
 
 def init_openai_api_key(api_key=""):
@@ -146,6 +148,7 @@ def upload_callback(image_input, state, visual_chatgpt=None):
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
+        ocr_reader=shared_ocr_reader,
         session_id=iface.app_id
     )
     model.segmenter.set_image(image_input)
@@ -154,6 +157,7 @@ def upload_callback(image_input, state, visual_chatgpt=None):
     input_size = model.input_size
 
     if visual_chatgpt is not None:
+        print('upload_callback: add caption to chatGPT memory')
         new_image_path = get_new_image_name('chat_image', func_name='upload')
         image_input.save(new_image_path)
         visual_chatgpt.current_image = new_image_path
@@ -192,6 +196,7 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
+        ocr_reader=shared_ocr_reader,
         text_refiner=text_refiner,
         session_id=iface.app_id
     )
@@ -213,6 +218,7 @@ def inference_click(image_input, point_prompt, click_mode, enable_wiki, language
     x, y = input_points[-1]
 
     if visual_chatgpt is not None:
+        print('inference_click: add caption to chatGPT memory')
         new_crop_save_path = get_new_image_name('chat_image', func_name='crop')
         Image.open(out["crop_save_path"]).save(new_crop_save_path)
         point_prompt = f'You should primarly use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is a part of the whole image (path: {visual_chatgpt.current_image}). If human mentioned some objects not in the selected region, you can use tools on the whole image.'
@@ -273,6 +279,7 @@ def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuali
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
+        ocr_reader=shared_ocr_reader,
         text_refiner=text_refiner,
         session_id=iface.app_id
     )
@@ -325,6 +332,7 @@ def cap_everything(image_input, visual_chatgpt, text_refiner):
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
+        ocr_reader=shared_ocr_reader,
         text_refiner=text_refiner,
         session_id=iface.app_id
    )
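Note on the app.py change (not part of the commit): easyocr.Reader loads its detection and recognition models when it is constructed, so building one reader at module scope and injecting it into every per-session CaptionAnything avoids repeating that load on each request. A minimal sketch of the sharing pattern, where build_ocr_pipeline is a hypothetical stand-in for build_caption_anything_with_models:

import easyocr

# Construct the reader once at startup; this is where easyocr loads its
# detection/recognition models, the slow step the commit avoids repeating.
ocr_lang = ["ch_tra", "en"]          # Traditional Chinese + English, as in the diff
shared_ocr_reader = easyocr.Reader(ocr_lang)

def build_ocr_pipeline(ocr_reader=None):
    # Reuse the injected reader; only build a fresh one when nothing was passed in,
    # mirroring CaptionAnything.__init__ after this commit.
    return ocr_reader if ocr_reader is not None else easyocr.Reader(ocr_lang)

# Every "session" reuses the same already-loaded instance.
reader_a = build_ocr_pipeline(ocr_reader=shared_ocr_reader)
reader_b = build_ocr_pipeline(ocr_reader=shared_ocr_reader)
assert reader_a is reader_b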
caption_anything/captioner/blip2.py CHANGED
@@ -6,6 +6,7 @@ from transformers import AutoProcessor, Blip2ForConditionalGeneration
 
 from caption_anything.utils.utils import is_platform_win, load_image
 from .base_captioner import BaseCaptioner
+import time
 
 class BLIP2Captioner(BaseCaptioner):
     def __init__(self, device, dialogue: bool = False, enable_filter: bool = False):
@@ -33,8 +34,7 @@ class BLIP2Captioner(BaseCaptioner):
         if not self.dialogue:
             inputs = self.processor(image, text = args['text_prompt'], return_tensors="pt").to(self.device, self.torch_dtype)
             out = self.model.generate(**inputs, return_dict_in_generate=True, output_scores=True, max_new_tokens=50)
-
-            caption = [caption.strip() for caption in captions][0]
+            caption = self.processor.decode(out.sequences[0], skip_special_tokens=True).strip()
             if self.enable_filter and filter:
                 print('reference caption: {}, caption: {}'.format(args['reference_caption'], caption))
                 clip_score = self.filter_caption(image, caption, args['reference_caption'])
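Context for the blip2.py hunk: with return_dict_in_generate=True, generate() returns an output object whose .sequences field holds the generated token ids, and the processor's decode() turns them back into text; the removed line referenced a captions list that is not defined anywhere in this hunk. A rough standalone sketch of that decode path (the checkpoint name, prompt, and image path are illustrative, and the real checkpoint needs several GB of memory):

import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", torch_dtype=dtype
).to(device)

image = Image.open("example.jpg")    # placeholder image path
inputs = processor(image, text="a photo of", return_tensors="pt").to(device, dtype)

# return_dict_in_generate=True makes generate() return an object whose
# .sequences attribute holds the generated token ids.
out = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, max_new_tokens=50)
caption = processor.decode(out.sequences[0], skip_special_tokens=True).strip()
print(caption)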
caption_anything/model.py CHANGED
@@ -8,6 +8,7 @@ import numpy as np
 from PIL import Image
 import easyocr
 import copy
+import time
 from caption_anything.captioner import build_captioner, BaseCaptioner
 from caption_anything.segmenter import build_segmenter, build_segmenter_densecap
 from caption_anything.text_refiner import build_text_refiner
@@ -16,14 +17,15 @@ from caption_anything.utils.utils import mask_painter_foreground_all, mask_paint
 from caption_anything.utils.densecap_painter import draw_bbox
 
 class CaptionAnything:
-    def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
+    def __init__(self, args, api_key="", captioner=None, segmenter=None, ocr_reader=None, text_refiner=None):
         self.args = args
         self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
         self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
         self.segmenter_densecap = build_segmenter_densecap(args.segmenter, args.device, args, model=self.segmenter.model)
+        self.ocr_lang = ["ch_tra", "en"]
+        self.ocr_reader = ocr_reader if ocr_reader is not None else easyocr.Reader(self.ocr_lang)
 
-
-        self.reader = easyocr.Reader(self.lang)
+
         self.text_refiner = None
         if not args.disable_gpt:
             if text_refiner is not None:
@@ -31,6 +33,7 @@ class CaptionAnything:
             elif api_key != "":
                 self.init_refiner(api_key)
         self.require_caption_prompt = args.captioner == 'blip2'
+        print('text_refiner init time: ', time.time() - t0)
 
     @property
     def image_embedding(self):
@@ -213,7 +216,7 @@ class CaptionAnything:
     def parse_ocr(self, image, thres=0.2):
         width, height = get_image_shape(image)
         image = load_image(image, return_type='numpy')
-        bounds = self.reader.readtext(image)
+        bounds = self.ocr_reader.readtext(image)
         bounds = [bound for bound in bounds if bound[2] > thres]
         print('Process OCR Text:\n', bounds)
 
@@ -257,7 +260,7 @@
 if __name__ == "__main__":
     from caption_anything.utils.parser import parse_augment
     args = parse_augment()
-    image_path = '
+    image_path = 'result/wt/memes/87226084.jpg'
    image = Image.open(image_path)
    prompts = [
        {
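Context for parse_ocr: easyocr's readtext() returns one (bounding_box, text, confidence) tuple per detected region, and the method keeps only detections whose confidence exceeds thres. A small self-contained sketch of that filtering step (the image path is hypothetical, and joining the surviving strings is only one way to surface the result):

import easyocr

reader = easyocr.Reader(["ch_tra", "en"])   # in the app this is the shared instance

def read_text_above_threshold(image_path, thres=0.2):
    # readtext() returns a list of (bounding_box, text, confidence) tuples.
    bounds = reader.readtext(image_path)
    # Keep only detections whose confidence exceeds the threshold,
    # mirroring `bound[2] > thres` in CaptionAnything.parse_ocr.
    bounds = [bound for bound in bounds if bound[2] > thres]
    return '. '.join(bound[1] for bound in bounds)

# print(read_text_above_threshold("example.jpg"))   # hypothetical image path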