Li Zhaoxu committed on
Commit
8216227
·
1 Parent(s): 9e21376
Files changed (3) hide show
  1. app.py +558 -77
  2. appv1.py +153 -0
  3. eval_configs/tinygptv_stage4_eval.yaml +2 -2
app.py CHANGED
@@ -1,15 +1,24 @@
1
  import argparse
2
  import os
3
  import random
 
 
 
 
4
 
5
  import numpy as np
 
6
  import torch
 
7
  import gradio as gr
 
 
8
  import torch.backends.cudnn as cudnn
 
9
  from minigpt4.common.config import Config
10
- from minigpt4.common.dist_utils import get_rank
11
  from minigpt4.common.registry import registry
12
- from minigpt4.conversation.conversation import Chat, CONV_VISION
13
 
14
  # imports modules for registration
15
  from minigpt4.datasets.builders import *
@@ -18,136 +27,608 @@ from minigpt4.processors import *
18
  from minigpt4.runners import *
19
  from minigpt4.tasks import *
20
 
 
21
  def parse_args():
22
  parser = argparse.ArgumentParser(description="Demo")
23
- parser.add_argument("--cfg-path", type=str, default='eval_configs/tinygptv_stage1_2_3_eval.yaml', help="path to configuration file.")
 
 
24
  parser.add_argument(
25
  "--options",
26
  nargs="+",
27
  help="override some settings in the used config, the key-value pair "
28
- "in xxx=yyy format will be merged into config file (deprecate), "
29
- "change to --cfg-options instead.",
30
  )
31
  args = parser.parse_args()
32
  return args
33
 
34
 
35
- def setup_seeds(config):
36
- seed = config.run_cfg.seed + get_rank()
37
-
38
- random.seed(seed)
39
- np.random.seed(seed)
40
- torch.manual_seed(seed)
41
-
42
- cudnn.benchmark = False
43
- cudnn.deterministic = True
44
- # ========================================
45
- # Model Initialization
46
- # ========================================
47
-
48
- SHARED_UI_WARNING = f'''### [NOTE] It is possible that you are waiting in a lengthy queue.
49
 
50
- You can duplicate and use it with a paid private GPU.
51
-
52
- <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Vision-CAIR/minigpt4?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
53
-
54
- Alternatively, you can also use the demo on our [project page](https://minigpt-4.github.io).
55
- '''
56
 
57
  print('Initializing Chat')
58
- cfg = Config(parse_args())
 
 
 
59
 
60
  model_config = cfg.model_cfg
 
61
  model_cls = registry.get_model_class(model_config.arch)
62
- model = model_cls.from_config(model_config).to('cuda:0')
 
63
 
64
  vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
65
  vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
66
- chat = Chat(model, vis_processor)
67
- print('Initialization Finished')
68
 
69
- # ========================================
70
- # Gradio Setting
71
- # ========================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  def gradio_reset(chat_state, img_list):
74
  if chat_state is not None:
75
  chat_state.messages = []
76
  if img_list is not None:
77
  img_list = []
78
- return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False), gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
- def upload_img(gr_img, text_input, chat_state):
81
- if gr_img is None:
82
- return None, None, gr.update(interactive=True), chat_state, None
83
- chat_state = CONV_VISION.copy()
84
- img_list = []
85
- llm_message = chat.upload_img(gr_img, chat_state, img_list)
86
 
87
- return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
88
 
89
- def gradio_ask(user_message, chatbot, chat_state):
90
  if len(user_message) == 0:
91
- return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  chat.ask(user_message, chat_state)
 
93
  chatbot = chatbot + [[user_message, None]]
94
- return '', chatbot, chat_state
95
 
 
 
 
 
 
 
 
96
 
97
- def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
98
- llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature, max_length=2000)[0]
 
 
 
 
 
99
  chatbot[-1][1] = llm_message
100
- return chatbot, chat_state, img_list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
- title = """<h1 align="center">Demo of TinyGPT-V</h1>"""
103
- description = """<h3>This is the demo of TinyGPT-V. Upload your images and start chatting!</h3>"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  article = """<div style='display:flex; gap: 0.25rem; '><a href='https://github.com/DLYuanGod/TinyGPT-V'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://arxiv.org/abs/2312.16862'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>
105
  """
106
 
107
- #TODO show examples below
 
 
 
 
 
 
 
 
 
 
108
 
 
 
109
  with gr.Blocks() as demo:
110
  gr.Markdown(title)
111
-
112
- gr.Markdown(description)
113
  gr.Markdown(article)
114
 
115
  with gr.Row():
116
  with gr.Column(scale=0.5):
117
- image = gr.Image(type="pil")
118
- upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
119
- clear = gr.Button("Restart")
120
-
121
- num_beams = gr.Slider(
122
- minimum=1,
123
- maximum=5,
124
- value=1,
125
- step=1,
126
- interactive=True,
127
- label="beam search numbers)",
128
- )
129
-
130
  temperature = gr.Slider(
131
  minimum=0.1,
132
- maximum=2.0,
133
- value=1.0,
134
  step=0.1,
135
  interactive=True,
136
  label="Temperature",
137
  )
138
-
 
 
 
139
 
140
  with gr.Column():
141
- chat_state = gr.State()
142
- img_list = gr.State()
143
  chatbot = gr.Chatbot(label='TinyGPT-V')
144
- text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
145
-
146
- upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
147
-
148
- text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
149
- gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  )
151
- clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
152
 
153
- demo.launch(enable_queue=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import argparse
2
  import os
3
  import random
4
+ from collections import defaultdict
5
+
6
+ import cv2
7
+ import re
8
 
9
  import numpy as np
10
+ from PIL import Image
11
  import torch
12
+ import html
13
  import gradio as gr
14
+
15
+ import torchvision.transforms as T
16
  import torch.backends.cudnn as cudnn
17
+
18
  from minigpt4.common.config import Config
19
+
20
  from minigpt4.common.registry import registry
21
+ from minigpt4.conversation.conversation import Conversation, SeparatorStyle, Chat
22
 
23
  # imports modules for registration
24
  from minigpt4.datasets.builders import *
 
27
  from minigpt4.runners import *
28
  from minigpt4.tasks import *
29
 
30
+
31
  def parse_args():
32
  parser = argparse.ArgumentParser(description="Demo")
33
+ parser.add_argument("--cfg-path", default='eval_configs/tinygptv_stage4_eval.yaml',
34
+ help="path to configuration file.")
35
+ parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
36
  parser.add_argument(
37
  "--options",
38
  nargs="+",
39
  help="override some settings in the used config, the key-value pair "
40
+ "in xxx=yyy format will be merged into config file (deprecate), "
41
+ "change to --cfg-options instead.",
42
  )
43
  args = parser.parse_args()
44
  return args
45
 
46
 
47
+ random.seed(42)
48
+ np.random.seed(42)
49
+ torch.manual_seed(42)
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ cudnn.benchmark = False
52
+ cudnn.deterministic = True
 
 
 
 
53
 
54
  print('Initializing Chat')
55
+ args = parse_args()
56
+ cfg = Config(args)
57
+
58
+ device = 'cuda:{}'.format(args.gpu_id)
59
 
60
  model_config = cfg.model_cfg
61
+ model_config.device_8bit = args.gpu_id
62
  model_cls = registry.get_model_class(model_config.arch)
63
+ model = model_cls.from_config(model_config).to(device)
64
+ bounding_box_size = 100
65
 
66
  vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
67
  vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
 
 
68
 
69
+ model = model.eval()
70
+
71
+ CONV_VISION = Conversation(
72
+ system="",
73
+ roles=(r"<s>[INST] ", r" [/INST]"),
74
+ messages=[],
75
+ offset=2,
76
+ sep_style=SeparatorStyle.SINGLE,
77
+ sep="",
78
+ )
79
+
80
+
81
def extract_substrings(string):
    """Return the ``<p>...}`` segments found in a model generation.

    A segment starts right after an opening ``<p>`` tag and runs up to (but not
    including) the first ``}`` that is not immediately followed by ``<``,
    e.g. ``'cat</p>{<1><2><3><4>'``.

    Args:
        string: Generated text, possibly ending in an unfinished ``{...`` group.

    Returns:
        list[str]: The captured segments in order of appearance; empty when
        nothing matches.
    """
    # Cut the text after the last '}' so an unfinished trailing bracket group
    # cannot be picked up.
    index = string.rfind('}')
    if index != -1:
        string = string[:index + 1]

    # re.findall already returns a new list of the captured groups, so no
    # extra copy is needed.
    return re.findall(r'<p>(.*?)\}(?!<)', string)
92
+
93
+
94
def is_overlapping(rect1, rect2):
    """Tell whether two ``(x1, y1, x2, y2)`` rectangles intersect.

    Rectangles that merely share an edge or a corner count as overlapping.
    """
    a_x1, a_y1, a_x2, a_y2 = rect1
    b_x1, b_y1, b_x2, b_y2 = rect2
    # The rectangles overlap when their extents intersect on both axes.
    overlaps_in_x = a_x2 >= b_x1 and a_x1 <= b_x2
    overlaps_in_y = a_y2 >= b_y1 and a_y1 <= b_y2
    return overlaps_in_x and overlaps_in_y
98
+
99
+
100
def computeIoU(bbox1, bbox2):
    """Compute the intersection-over-union of two boxes.

    Both boxes are ``(x1, y1, x2, y2)`` tuples. Widths and heights are measured
    with a ``+ 1`` term, i.e. both corner coordinates are counted as part of
    the box.

    Args:
        bbox1: First box.
        bbox2: Second box.

    Returns:
        float: Intersection area divided by union area, or ``0.0`` when the
        union area is not positive (degenerate boxes), instead of raising
        ``ZeroDivisionError``.
    """
    x1, y1, x2, y2 = bbox1
    x3, y3, x4, y4 = bbox2

    # Clamp at zero so disjoint boxes contribute no intersection.
    intersection_width = max(0, min(x2, x4) - max(x1, x3) + 1)
    intersection_height = max(0, min(y2, y4) - max(y1, y3) + 1)
    intersection_area = intersection_width * intersection_height

    bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1)
    bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1)
    union_area = bbox1_area + bbox2_area - intersection_area

    # Degenerate input (e.g. x2 == x1 - 1) yields an empty union.
    if union_area <= 0:
        return 0.0
    return intersection_area / union_area
113
+
114
+
115
def save_tmp_img(visual_img):
    """Save an image to a uniquely named ``.jpg`` file and return its path.

    The file is created in the system temporary directory with a ``gradio``
    name prefix. ``tempfile.mkstemp`` guarantees a fresh name on every call.
    The previous scheme used five digits from the globally seeded ``random``
    module, which repeats after a restart and can overwrite another
    conversation's image.

    Args:
        visual_img: Object exposing ``save(path)``; callers in this file pass
            the image returned by ``visualize_all_bbox_together``.

    Returns:
        str: Path of the written file.
    """
    import tempfile

    fd, file_path = tempfile.mkstemp(prefix="gradio", suffix=".jpg")
    # Only the reserved name is needed; the image object writes the file itself.
    os.close(fd)
    try:
        visual_img.save(file_path)
    except Exception:
        # Do not leave an empty placeholder behind when saving fails.
        os.remove(file_path)
        raise
    return file_path
120
+
121
+
122
def mask2bbox(mask, size=100):
    """Convert a drawn mask into a ``{<x0><y0><x1><y1>}`` box string.

    The mask is resized to ``size`` x ``size`` with nearest-neighbour
    resampling, and the box is the tightest rectangle around its non-zero
    pixels, expressed in that grid.

    Args:
        mask: Mask image or ``None``. Assumed to be a PIL image (it must expose
            ``resize``) -- TODO confirm against the Gradio sketch output.
        size: Side length of the coordinate grid. Defaults to 100, the value
            that was previously hard-coded.

    Returns:
        str: ``'{<cmin><rmin><cmax><rmax>}'``, or ``''`` when ``mask`` is
        ``None`` or contains no non-zero pixel.
    """
    if mask is None:
        return ''

    mask = np.array(mask.resize([size, size], resample=Image.NEAREST))
    # Multi-channel masks: only the first channel is inspected. Single-channel
    # masks are already 2-D and used as-is (indexing a third axis would fail).
    if mask.ndim == 3:
        mask = mask[:, :, 0]

    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)
    if not rows.any():
        return ''

    # First and last occupied row/column give top/bottom and left/right.
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    return '{{<{}><{}><{}><{}>}}'.format(cmin, rmin, cmax, rmax)
140
+
141
+
142
def escape_markdown(text):
    """Return ``text`` with every ``<`` and ``>`` prefixed by a backslash."""
    return text.replace('<', '\\<').replace('>', '\\>')
151
+
152
+
153
def reverse_escape(text):
    """Return ``text`` with the backslash removed from every ``\\<`` and ``\\>``."""
    for escaped, plain in (('\\<', '<'), ('\\>', '>')):
        text = text.replace(escaped, plain)
    return text
160
+
161
+
162
+ colors = [
163
+ (255, 0, 0),
164
+ (0, 255, 0),
165
+ (0, 0, 255),
166
+ (210, 210, 0),
167
+ (255, 0, 255),
168
+ (0, 255, 255),
169
+ (114, 128, 250),
170
+ (0, 165, 255),
171
+ (0, 128, 0),
172
+ (144, 238, 144),
173
+ (238, 238, 175),
174
+ (255, 191, 0),
175
+ (0, 128, 0),
176
+ (226, 43, 138),
177
+ (255, 0, 255),
178
+ (0, 215, 255),
179
+ ]
180
+
181
+ color_map = {
182
+ f"{color_id}": f"#{hex(color[2])[2:].zfill(2)}{hex(color[1])[2:].zfill(2)}{hex(color[0])[2:].zfill(2)}" for
183
+ color_id, color in enumerate(colors)
184
+ }
185
+
186
+ used_colors = colors
187
+
188
+
189
+ def visualize_all_bbox_together(image, generation):
190
+ if image is None:
191
+ return None, ''
192
+
193
+ generation = html.unescape(generation)
194
+
195
+ image_width, image_height = image.size
196
+ image = image.resize([500, int(500 / image_width * image_height)])
197
+ image_width, image_height = image.size
198
+
199
+ string_list = extract_substrings(generation)
200
+ if string_list: # it is grounding or detection
201
+ mode = 'all'
202
+ entities = defaultdict(list)
203
+ i = 0
204
+ j = 0
205
+ for string in string_list:
206
+ try:
207
+ obj, string = string.split('</p>')
208
+ except ValueError:
209
+ print('wrong string: ', string)
210
+ continue
211
+ bbox_list = string.split('<delim>')
212
+ flag = False
213
+ for bbox_string in bbox_list:
214
+ integers = re.findall(r'-?\d+', bbox_string)
215
+ if len(integers) == 4:
216
+ x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
217
+ left = x0 / bounding_box_size * image_width
218
+ bottom = y0 / bounding_box_size * image_height
219
+ right = x1 / bounding_box_size * image_width
220
+ top = y1 / bounding_box_size * image_height
221
+
222
+ entities[obj].append([left, bottom, right, top])
223
+
224
+ j += 1
225
+ flag = True
226
+ if flag:
227
+ i += 1
228
+ else:
229
+ integers = re.findall(r'-?\d+', generation)
230
+
231
+ if len(integers) == 4: # it is refer
232
+ mode = 'single'
233
+
234
+ entities = list()
235
+ x0, y0, x1, y1 = int(integers[0]), int(integers[1]), int(integers[2]), int(integers[3])
236
+ left = x0 / bounding_box_size * image_width
237
+ bottom = y0 / bounding_box_size * image_height
238
+ right = x1 / bounding_box_size * image_width
239
+ top = y1 / bounding_box_size * image_height
240
+ entities.append([left, bottom, right, top])
241
+ else:
242
+ # don't detect any valid bbox to visualize
243
+ return None, ''
244
+
245
+ if len(entities) == 0:
246
+ return None, ''
247
+
248
+ if isinstance(image, Image.Image):
249
+ image_h = image.height
250
+ image_w = image.width
251
+ image = np.array(image)
252
+
253
+ elif isinstance(image, str):
254
+ if os.path.exists(image):
255
+ pil_img = Image.open(image).convert("RGB")
256
+ image = np.array(pil_img)[:, :, [2, 1, 0]]
257
+ image_h = pil_img.height
258
+ image_w = pil_img.width
259
+ else:
260
+ raise ValueError(f"invaild image path, {image}")
261
+ elif isinstance(image, torch.Tensor):
262
+
263
+ image_tensor = image.cpu()
264
+ reverse_norm_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])[:, None, None]
265
+ reverse_norm_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])[:, None, None]
266
+ image_tensor = image_tensor * reverse_norm_std + reverse_norm_mean
267
+ pil_img = T.ToPILImage()(image_tensor)
268
+ image_h = pil_img.height
269
+ image_w = pil_img.width
270
+ image = np.array(pil_img)[:, :, [2, 1, 0]]
271
+ else:
272
+ raise ValueError(f"invaild image format, {type(image)} for {image}")
273
+
274
+ indices = list(range(len(entities)))
275
+
276
+ new_image = image.copy()
277
+
278
+ previous_bboxes = []
279
+ # size of text
280
+ text_size = 0.5
281
+ # thickness of text
282
+ text_line = 1 # int(max(1 * min(image_h, image_w) / 512, 1))
283
+ box_line = 2
284
+ (c_width, text_height), _ = cv2.getTextSize("F", cv2.FONT_HERSHEY_COMPLEX, text_size, text_line)
285
+ base_height = int(text_height * 0.675)
286
+ text_offset_original = text_height - base_height
287
+ text_spaces = 2
288
+
289
+ # num_bboxes = sum(len(x[-1]) for x in entities)
290
+ used_colors = colors # random.sample(colors, k=num_bboxes)
291
+
292
+ color_id = -1
293
+ for entity_idx, entity_name in enumerate(entities):
294
+ if mode == 'single' or mode == 'identify':
295
+ bboxes = entity_name
296
+ bboxes = [bboxes]
297
+ else:
298
+ bboxes = entities[entity_name]
299
+ color_id += 1
300
+ for bbox_id, (x1_norm, y1_norm, x2_norm, y2_norm) in enumerate(bboxes):
301
+ skip_flag = False
302
+ orig_x1, orig_y1, orig_x2, orig_y2 = int(x1_norm), int(y1_norm), int(x2_norm), int(y2_norm)
303
+
304
+ color = used_colors[entity_idx % len(used_colors)] # tuple(np.random.randint(0, 255, size=3).tolist())
305
+ new_image = cv2.rectangle(new_image, (orig_x1, orig_y1), (orig_x2, orig_y2), color, box_line)
306
+
307
+ if mode == 'all':
308
+ l_o, r_o = box_line // 2 + box_line % 2, box_line // 2 + box_line % 2 + 1
309
+
310
+ x1 = orig_x1 - l_o
311
+ y1 = orig_y1 - l_o
312
+
313
+ if y1 < text_height + text_offset_original + 2 * text_spaces:
314
+ y1 = orig_y1 + r_o + text_height + text_offset_original + 2 * text_spaces
315
+ x1 = orig_x1 + r_o
316
+
317
+ # add text background
318
+ (text_width, text_height), _ = cv2.getTextSize(f" {entity_name}", cv2.FONT_HERSHEY_COMPLEX, text_size,
319
+ text_line)
320
+ text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2 = x1, y1 - (
321
+ text_height + text_offset_original + 2 * text_spaces), x1 + text_width, y1
322
+
323
+ for prev_bbox in previous_bboxes:
324
+ if computeIoU((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']) > 0.95 and \
325
+ prev_bbox['phrase'] == entity_name:
326
+ skip_flag = True
327
+ break
328
+ while is_overlapping((text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), prev_bbox['bbox']):
329
+ text_bg_y1 += (text_height + text_offset_original + 2 * text_spaces)
330
+ text_bg_y2 += (text_height + text_offset_original + 2 * text_spaces)
331
+ y1 += (text_height + text_offset_original + 2 * text_spaces)
332
+
333
+ if text_bg_y2 >= image_h:
334
+ text_bg_y1 = max(0, image_h - (text_height + text_offset_original + 2 * text_spaces))
335
+ text_bg_y2 = image_h
336
+ y1 = image_h
337
+ break
338
+ if not skip_flag:
339
+ alpha = 0.5
340
+ for i in range(text_bg_y1, text_bg_y2):
341
+ for j in range(text_bg_x1, text_bg_x2):
342
+ if i < image_h and j < image_w:
343
+ if j < text_bg_x1 + 1.35 * c_width:
344
+ # original color
345
+ bg_color = color
346
+ else:
347
+ # white
348
+ bg_color = [255, 255, 255]
349
+ new_image[i, j] = (alpha * new_image[i, j] + (1 - alpha) * np.array(bg_color)).astype(
350
+ np.uint8)
351
+
352
+ cv2.putText(
353
+ new_image, f" {entity_name}", (x1, y1 - text_offset_original - 1 * text_spaces),
354
+ cv2.FONT_HERSHEY_COMPLEX, text_size, (0, 0, 0), text_line, cv2.LINE_AA
355
+ )
356
+
357
+ previous_bboxes.append(
358
+ {'bbox': (text_bg_x1, text_bg_y1, text_bg_x2, text_bg_y2), 'phrase': entity_name})
359
+
360
+ if mode == 'all':
361
+ def color_iterator(colors):
362
+ while True:
363
+ for color in colors:
364
+ yield color
365
+
366
+ color_gen = color_iterator(colors)
367
+
368
+ # Add colors to phrases and remove <p></p>
369
+ def colored_phrases(match):
370
+ phrase = match.group(1)
371
+ color = next(color_gen)
372
+ return f'<span style="color:rgb{color}">{phrase}</span>'
373
+
374
+ generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
375
+ generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
376
+ else:
377
+ generation_colored = ''
378
+
379
+ pil_image = Image.fromarray(new_image)
380
+ return pil_image, generation_colored
381
+
382
 
383
  def gradio_reset(chat_state, img_list):
384
  if chat_state is not None:
385
  chat_state.messages = []
386
  if img_list is not None:
387
  img_list = []
388
+ return None, gr.update(value=None, interactive=True), gr.update(placeholder='Upload your image and chat',
389
+ interactive=True), chat_state, img_list
390
+
391
+
392
def image_upload_trigger(upload_flag, replace_flag, img_list):
    """Compute the upload/replace flags after a new image arrives.

    The upload flag is always raised. The replace flag is raised as well when
    ``img_list`` is non-empty, i.e. an earlier image (and its conversation)
    exists and has to be reset later; otherwise it is passed through unchanged.

    Returns:
        tuple: ``(upload_flag, replace_flag)``.
    """
    updated_replace_flag = 1 if img_list else replace_flag
    return 1, updated_replace_flag
399
+
400
+
401
def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
    """Compute the upload/replace flags when an example is selected.

    The upload flag is always raised. The replace flag is raised when an image
    is already present in ``img_list`` or when it was already equal to 1;
    otherwise it is passed through unchanged. ``text_input`` and ``image`` are
    accepted but not used.

    Returns:
        tuple: ``(upload_flag, replace_flag)``.
    """
    must_reset = bool(img_list) or replace_flag == 1
    return 1, (1 if must_reset else replace_flag)
 
 
 
 
 
409
 
 
410
 
411
+ def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
412
  if len(user_message) == 0:
413
+ text_box_show = 'Input should not be empty!'
414
+ else:
415
+ text_box_show = ''
416
+
417
+ if isinstance(gr_img, dict):
418
+ gr_img, mask = gr_img['image'], gr_img['mask']
419
+ else:
420
+ mask = None
421
+
422
+ if '[identify]' in user_message:
423
+ # check if user provide bbox in the text input
424
+ integers = re.findall(r'-?\d+', user_message)
425
+ if len(integers) != 4: # no bbox in text
426
+ bbox = mask2bbox(mask)
427
+ user_message = user_message + bbox
428
+
429
+ if chat_state is None:
430
+ chat_state = CONV_VISION.copy()
431
+
432
+ if upload_flag:
433
+ if replace_flag:
434
+ chat_state = CONV_VISION.copy() # new image, reset everything
435
+ replace_flag = 0
436
+ chatbot = []
437
+ img_list = []
438
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
439
+ upload_flag = 0
440
+
441
  chat.ask(user_message, chat_state)
442
+
443
  chatbot = chatbot + [[user_message, None]]
 
444
 
445
+ if '[identify]' in user_message:
446
+ visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
447
+ if visual_img is not None:
448
+ file_path = save_tmp_img(visual_img)
449
+ chatbot = chatbot + [[(file_path,), None]]
450
+
451
+ return text_box_show, chatbot, chat_state, img_list, upload_flag, replace_flag
452
 
453
+
454
+ def gradio_answer(chatbot, chat_state, img_list, temperature):
455
+ llm_message = chat.answer(conv=chat_state,
456
+ img_list=img_list,
457
+ temperature=temperature,
458
+ max_new_tokens=500,
459
+ max_length=2000)[0]
460
  chatbot[-1][1] = llm_message
461
+ return chatbot, chat_state
462
+
463
+
464
+ def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
465
+ if len(img_list) > 0:
466
+ if not isinstance(img_list[0], torch.Tensor):
467
+ chat.encode_img(img_list)
468
+ streamer = chat.stream_answer(conv=chat_state,
469
+ img_list=img_list,
470
+ temperature=temperature,
471
+ max_new_tokens=500,
472
+ max_length=2000)
473
+ output = ''
474
+ for new_output in streamer:
475
+ if '###' in new_output:
476
+
477
+ new_output = new_output.split('###')[0]
478
+ output += escape_markdown(new_output)
479
+ chatbot[-1][1] = output
480
+
481
+ yield chatbot, chat_state
482
+ break
483
+ escapped = escape_markdown(new_output)
484
+ output += escapped
485
+ chatbot[-1][1] = output
486
+ yield chatbot, chat_state
487
+
488
+ return chatbot, chat_state
489
+
490
 
491
+ def gradio_visualize(chatbot, gr_img):
492
+ if isinstance(gr_img, dict):
493
+ gr_img, mask = gr_img['image'], gr_img['mask']
494
+
495
+ unescaped = reverse_escape(chatbot[-1][1])
496
+ visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
497
+ if visual_img is not None:
498
+ if len(generation_color):
499
+ chatbot[-1][1] = generation_color
500
+ file_path = save_tmp_img(visual_img)
501
+ chatbot = chatbot + [[None, (file_path,)]]
502
+
503
+ return chatbot
504
+
505
+
506
def gradio_taskselect(idx):
    """Return the prompt prefix and the usage hint for a task shortcut.

    Args:
        idx: Position of the selected shortcut: 0 no tag, 1 grounding, 2 refer,
            3 detection, 4 identify, 5 vqa.

    Returns:
        tuple[str, str]: ``(prompt_prefix, hint_markdown)``.

    Raises:
        IndexError: If ``idx`` is outside the list of shortcuts.
    """
    prompt_list = [
        '',
        '[grounding] describe this image in detail',
        '[refer] ',
        '[detection] ',
        '[identify] what is this ',
        '[vqa] ',
    ]
    instruct_list = [
        '**Hint:** Type in whatever you want',
        '**Hint:** Send the command to generate a grounded image description',
        '**Hint:** Type in a phrase about an object in the image and send the command',
        '**Hint:** Type in a caption or phrase, and see object locations in the image',
        # Typo fix in the user-facing hint: "botton" -> "button".
        '**Hint:** Draw a bounding box on the uploaded image then send the command. Click the "clear" button on the top right of the image before redraw',
        '**Hint:** Send a question to get a short answer',
    ]
    return prompt_list[idx], instruct_list[idx]
524
+
525
+
526
+
527
+
528
+ chat = Chat(model, vis_processor, device=device)
529
+
530
+ title = """<h1 align="center">TinyGPT-V Demo</h1>"""
531
+ description = 'Welcome to Our TinyGPT-V Chatbot Demo!'
532
  article = """<div style='display:flex; gap: 0.25rem; '><a href='https://github.com/DLYuanGod/TinyGPT-V'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://arxiv.org/abs/2312.16862'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>
533
  """
534
 
535
+ introduction = '''
536
+ For Abilities Involving Visual Grounding:
537
+ 1. Grounding: CLICK **Send** to generate a grounded image description.
538
+ 2. Refer: Input a referring object and CLICK **Send**.
539
+ 3. Detection: Write a caption or phrase, and CLICK **Send**.
540
+ 4. Identify: Draw the bounding box on the uploaded image window and CLICK **Send** to generate the bounding box. (CLICK "clear" button before re-drawing next time).
541
+ 5. VQA: Input a visual question and CLICK **Send**.
542
+ 6. No Tag: Input whatever you want and CLICK **Send** without any tagging
543
+
544
+ You can also simply chat in free form!
545
+ '''
546
 
547
+ text_input = gr.Textbox(placeholder='Upload your image and chat', interactive=True, show_label=False, container=False,
548
+ scale=8)
549
  with gr.Blocks() as demo:
550
  gr.Markdown(title)
551
+ # gr.Markdown(description)
 
552
  gr.Markdown(article)
553
 
554
  with gr.Row():
555
  with gr.Column(scale=0.5):
556
+ image = gr.Image(type="pil", tool='sketch', brush_radius=20)
557
+
 
 
 
 
 
 
 
 
 
 
 
558
  temperature = gr.Slider(
559
  minimum=0.1,
560
+ maximum=1.5,
561
+ value=0.6,
562
  step=0.1,
563
  interactive=True,
564
  label="Temperature",
565
  )
566
+
567
+ clear = gr.Button("Restart")
568
+
569
+ gr.Markdown(introduction)
570
 
571
  with gr.Column():
572
+ chat_state = gr.State(value=None)
573
+ img_list = gr.State(value=[])
574
  chatbot = gr.Chatbot(label='TinyGPT-V')
575
+
576
+ dataset = gr.Dataset(
577
+ components=[gr.Textbox(visible=False)],
578
+ samples=[['No Tag'], ['Grounding'], ['Refer'], ['Detection'], ['Identify'], ['VQA']],
579
+ type="index",
580
+ label='Task Shortcuts',
581
+ )
582
+ task_inst = gr.Markdown('**Hint:** Upload your image and chat')
583
+ with gr.Row():
584
+ text_input.render()
585
+ send = gr.Button("Send", variant='primary', size='sm', scale=1)
586
+
587
+ upload_flag = gr.State(value=0)
588
+ replace_flag = gr.State(value=0)
589
+ image.upload(image_upload_trigger, [upload_flag, replace_flag, img_list], [upload_flag, replace_flag])
590
+
591
+
592
+
593
+ dataset.click(
594
+ gradio_taskselect,
595
+ inputs=[dataset],
596
+ outputs=[text_input, task_inst],
597
+ show_progress="hidden",
598
+ postprocess=False,
599
+ queue=False,
600
+ )
601
+
602
+ text_input.submit(
603
+ gradio_ask,
604
+ [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
605
+ [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
606
+ ).success(
607
+ gradio_stream_answer,
608
+ [chatbot, chat_state, img_list, temperature],
609
+ [chatbot, chat_state]
610
+ ).success(
611
+ gradio_visualize,
612
+ [chatbot, image],
613
+ [chatbot],
614
+ queue=False,
615
  )
 
616
 
617
+ send.click(
618
+ gradio_ask,
619
+ [text_input, chatbot, chat_state, image, img_list, upload_flag, replace_flag],
620
+ [text_input, chatbot, chat_state, img_list, upload_flag, replace_flag], queue=False
621
+ ).success(
622
+ gradio_stream_answer,
623
+ [chatbot, chat_state, img_list, temperature],
624
+ [chatbot, chat_state]
625
+ ).success(
626
+ gradio_visualize,
627
+ [chatbot, image],
628
+ [chatbot],
629
+ queue=False,
630
+ )
631
+
632
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, chat_state, img_list], queue=False)
633
+
634
+ demo.launch(share=True, enable_queue=True)
appv1.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import gradio as gr
8
+ import torch.backends.cudnn as cudnn
9
+ from minigpt4.common.config import Config
10
+ from minigpt4.common.dist_utils import get_rank
11
+ from minigpt4.common.registry import registry
12
+ from minigpt4.conversation.conversation import Chat, CONV_VISION
13
+
14
+ # imports modules for registration
15
+ from minigpt4.datasets.builders import *
16
+ from minigpt4.models import *
17
+ from minigpt4.processors import *
18
+ from minigpt4.runners import *
19
+ from minigpt4.tasks import *
20
+
21
+ def parse_args():
22
+ parser = argparse.ArgumentParser(description="Demo")
23
+ parser.add_argument("--cfg-path", type=str, default='eval_configs/tinygptv_stage1_2_3_eval.yaml', help="path to configuration file.")
24
+ parser.add_argument(
25
+ "--options",
26
+ nargs="+",
27
+ help="override some settings in the used config, the key-value pair "
28
+ "in xxx=yyy format will be merged into config file (deprecate), "
29
+ "change to --cfg-options instead.",
30
+ )
31
+ args = parser.parse_args()
32
+ return args
33
+
34
+
35
+ def setup_seeds(config):
36
+ seed = config.run_cfg.seed + get_rank()
37
+
38
+ random.seed(seed)
39
+ np.random.seed(seed)
40
+ torch.manual_seed(seed)
41
+
42
+ cudnn.benchmark = False
43
+ cudnn.deterministic = True
44
+ # ========================================
45
+ # Model Initialization
46
+ # ========================================
47
+
48
+ SHARED_UI_WARNING = f'''### [NOTE] It is possible that you are waiting in a lengthy queue.
49
+
50
+ You can duplicate and use it with a paid private GPU.
51
+
52
+ <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Vision-CAIR/minigpt4?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>
53
+
54
+ Alternatively, you can also use the demo on our [project page](https://minigpt-4.github.io).
55
+ '''
56
+
57
+ print('Initializing Chat')
58
+ cfg = Config(parse_args())
59
+
60
+ model_config = cfg.model_cfg
61
+ model_cls = registry.get_model_class(model_config.arch)
62
+ model = model_cls.from_config(model_config).to('cuda:0')
63
+
64
+ vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
65
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
66
+ chat = Chat(model, vis_processor)
67
+ print('Initialization Finished')
68
+
69
+ # ========================================
70
+ # Gradio Setting
71
+ # ========================================
72
+
73
+ def gradio_reset(chat_state, img_list):
74
+ if chat_state is not None:
75
+ chat_state.messages = []
76
+ if img_list is not None:
77
+ img_list = []
78
+ return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', interactive=False), gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
79
+
80
+ def upload_img(gr_img, text_input, chat_state):
81
+ if gr_img is None:
82
+ return None, None, gr.update(interactive=True), chat_state, None
83
+ chat_state = CONV_VISION.copy()
84
+ img_list = []
85
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
86
+
87
+ return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
88
+
89
+ def gradio_ask(user_message, chatbot, chat_state):
90
+ if len(user_message) == 0:
91
+ return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
92
+ chat.ask(user_message, chat_state)
93
+ chatbot = chatbot + [[user_message, None]]
94
+ return '', chatbot, chat_state
95
+
96
+
97
+ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
98
+ llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=1, temperature=temperature, max_length=2000)[0]
99
+ chatbot[-1][1] = llm_message
100
+ return chatbot, chat_state, img_list
101
+
102
+ title = """<h1 align="center">Demo of TinyGPT-V</h1>"""
103
+ description = """<h3>This is the demo of TinyGPT-V. Upload your images and start chatting!</h3>"""
104
+ article = """<div style='display:flex; gap: 0.25rem; '><a href='https://github.com/DLYuanGod/TinyGPT-V'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://arxiv.org/abs/2312.16862'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>
105
+ """
106
+
107
+ #TODO show examples below
108
+
109
+ with gr.Blocks() as demo:
110
+ gr.Markdown(title)
111
+
112
+ gr.Markdown(description)
113
+ gr.Markdown(article)
114
+
115
+ with gr.Row():
116
+ with gr.Column(scale=0.5):
117
+ image = gr.Image(type="pil")
118
+ upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
119
+ clear = gr.Button("Restart")
120
+
121
+ num_beams = gr.Slider(
122
+ minimum=1,
123
+ maximum=5,
124
+ value=1,
125
+ step=1,
126
+ interactive=True,
127
+ label="beam search numbers)",
128
+ )
129
+
130
+ temperature = gr.Slider(
131
+ minimum=0.1,
132
+ maximum=2.0,
133
+ value=1.0,
134
+ step=0.1,
135
+ interactive=True,
136
+ label="Temperature",
137
+ )
138
+
139
+
140
+ with gr.Column():
141
+ chat_state = gr.State()
142
+ img_list = gr.State()
143
+ chatbot = gr.Chatbot(label='TinyGPT-V')
144
+ text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
145
+
146
+ upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
147
+
148
+ text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
149
+ gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
150
+ )
151
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
152
+
153
+ demo.launch(enable_queue=True)
eval_configs/tinygptv_stage4_eval.yaml CHANGED
@@ -4,8 +4,8 @@ model:
4
  max_txt_len: 500
5
  bos_token_id: "###"
6
  low_resource: False
7
- prompt_template: 'Instruct: {} /n Output: '
8
- ckpt: "/home/li0007xu/LLM/TinyGPT-V/TinyGPT-V_for_Stage4.pth"
9
  lora_r: 64
10
  lora_alpha: 16
11
 
 
4
  max_txt_len: 500
5
  bos_token_id: "###"
6
  low_resource: False
7
+ prompt_template: '###Human: {} ###Assistant: '
8
+ ckpt: "checkpoint_49.pth"
9
  lora_r: 64
10
  lora_alpha: 16
11