zhangtao-whu committed
Commit c26bbc8 · verified · 1 Parent(s): 2c84b32

Upload /app.py with huggingface_hub

Files changed (1):
  1. app.py +521 -54
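
The commit message indicates app.py was pushed programmatically with `huggingface_hub`. A minimal sketch of such an upload, assuming a Space repo and a placeholder repo id (neither is stated in this commit):

from huggingface_hub import HfApi

api = HfApi()  # uses the token from `huggingface-cli login` or the HF_TOKEN env var
api.upload_file(
    path_or_fileobj="app.py",            # local file to push
    path_in_repo="app.py",               # destination path inside the repo
    repo_id="your-username/your-space",  # placeholder Space id
    repo_type="space",
    commit_message="Upload /app.py with huggingface_hub",
)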
app.py CHANGED
@@ -1,63 +1,530 @@
  import gradio as gr
  from huggingface_hub import InferenceClient

- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
  ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )


  if __name__ == "__main__":
-     demo.launch()

  import gradio as gr
  from huggingface_hub import InferenceClient
+ import cv2
+ import random
+ import numpy as np
+ from PIL import Image
+ import torch.nn.functional as F
+ import sys
+ from omg_llava.tools.app_utils import process_markdown, show_mask_pred, parse_visual_prompts

+ import torch
+ from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
+                           BitsAndBytesConfig, CLIPImageProcessor,
+                           CLIPVisionModel, GenerationConfig)
+ from transformers.generation.streamers import TextStreamer

+ from xtuner.dataset.utils import expand2square, load_image
+ from omg_llava.dataset.utils import expand2square_bbox, expand2square_mask, expand2square_points
+ from xtuner.model.utils import prepare_inputs_labels_for_multimodal
+ from omg_llava.model.utils import prepare_inputs_labels_for_multimodal_with_visual_prompts
+ from xtuner.tools.utils import get_stop_criteria
+ from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
+                           PROMPT_TEMPLATE, SYSTEM_TEMPLATE)

+ import argparse
+ import os.path as osp
+
+ from mmengine.config import Config, DictAction
+ from mmengine.fileio import PetrelBackend, get_file_backend
+
+ from xtuner.configs import cfgs_name_path
+ from xtuner.model.utils import guess_load_checkpoint
+ from xtuner.registry import BUILDER
+
+ from gradio_image_prompter import ImagePrompter
+
+ TORCH_DTYPE_MAP = dict(
+     fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
+
+ def parse_args(args):
+     parser = argparse.ArgumentParser(description="OMG-LLaVA Demo")
+     parser.add_argument('--config', help='config file name or path.',
+                         default='./omg_llava/configs/finetune/hf_app.py')
+     parser.add_argument('--pth_model', help='pth model file',
+                         default='./pretrained/omg_llava/omg_llava_fintune_8gpus.pth')
+
+     parser.add_argument('--image', default=None, help='image')
+     parser.add_argument(
+         '--torch-dtype',
+         default='fp16',
+         choices=TORCH_DTYPE_MAP.keys(),
+         help='Override the default `torch.dtype` and load the model under '
+         'a specific `dtype`.')
+     parser.add_argument(
+         '--prompt-template',
+         choices=PROMPT_TEMPLATE.keys(),
+         default="internlm2_chat",
+         help='Specify a prompt template')
+     system_group = parser.add_mutually_exclusive_group()
+     system_group.add_argument(
+         '--system', default=None, help='Specify the system text')
+     system_group.add_argument(
+         '--system-template',
+         choices=SYSTEM_TEMPLATE.keys(),
+         default=None,
+         help='Specify a system template')
+     parser.add_argument(
+         '--bits',
+         type=int,
+         choices=[4, 8, None],
+         default=None,
+         help='LLM bits')
+     parser.add_argument(
+         '--bot-name', type=str, default='BOT', help='Name for Bot')
+     parser.add_argument(
+         '--with-plugins',
+         nargs='+',
+         choices=['calculate', 'solve', 'search'],
+         help='Specify plugins to use')
+     parser.add_argument(
+         '--no-streamer', action='store_true', help='Whether to disable the streamer')
+     parser.add_argument(
+         '--lagent', action='store_true', help='Whether to use lagent')
+     parser.add_argument(
+         '--stop-words', nargs='+', type=str, default=[], help='Stop words')
+     parser.add_argument(
+         '--offload-folder',
+         default=None,
+         help='The folder in which to offload the model weights (or where the '
+         'model weights are already offloaded).')
+     parser.add_argument(
+         '--max-new-tokens',
+         type=int,
+         default=2048,
+         help='Maximum number of new tokens allowed in generated text')
+     parser.add_argument(
+         '--temperature',
+         type=float,
+         default=0.1,
+         help='The value used to modulate the next token probabilities.')
+     parser.add_argument(
+         '--top-k',
+         type=int,
+         default=40,
+         help='The number of highest probability vocabulary tokens to '
+         'keep for top-k-filtering.')
+     parser.add_argument(
+         '--top-p',
+         type=float,
+         default=0.75,
+         help='If set to float < 1, only the smallest set of most probable '
+         'tokens with probabilities that add up to top_p or higher are '
+         'kept for generation.')
+     parser.add_argument(
+         '--repetition-penalty',
+         type=float,
+         default=1.0,
+         help='The parameter for repetition penalty. 1.0 means no penalty.')
+     parser.add_argument(
+         '--seed',
+         type=int,
+         default=0,
+         help='Random seed for reproducible text generation')
+     return parser.parse_args(args)
+
+ def get_points_embeddings(points, input_ids, width, height,
+                           mark_token_idx, mode='point'):
+     if points is None or len(points) == 0:
+         return []
+
+     mark_token_mask = input_ids == mark_token_idx
+     batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(1, input_ids.shape[1]).to(
+         input_ids.device)
+     batch_idxs = batch_idxs[mark_token_mask]  # (N, ) batch index of each prompt token
+
+     points = points.to(torch.float32)
+     # print(points.dtype, batch_idxs.dtype)
+
+     if mode == 'point':
+         marks_embeddings = visual_encoder.forward_point_sam(
+             points, batch_idxs, width=width, height=height
+         )[:, 0]  # (N, C)
+     elif mode == 'box':
+         marks_embeddings = visual_encoder.forward_box_sam(
+             points, batch_idxs, width=width, height=height
+         )[:, 0]  # (N, C)
+     else:
+         raise NotImplementedError
+
+     marks_embeddings = marks_embeddings.to(projector.model.query_proj.weight.dtype)
+     marks_embeddings = projector.model.query_proj(marks_embeddings)
+     marks_embeddings = projector.model.model(marks_embeddings)
+     print('marks_embeddings shape: ', marks_embeddings.shape)
+     return marks_embeddings  # (N, C)
+
+ def get_visual_prompts_embeddings(
+         height, width, input_ids,
  ):
+     points_prompts = global_infos.point_prompts
+     boxes_prompts = global_infos.box_prompts
+
+     if len(points_prompts) == 0:
+         points_mark_embedding = []
+     else:
+         points = np.array(points_prompts)
+         points = expand2square_points(points, height=height, width=width)
+         points[:, 0] = points[:, 0] / max(height, width) * 1024
+         points[:, 1] = points[:, 1] / max(height, width) * 1024
+         points = torch.from_numpy(points)
+         points = points.cuda()
+         mark_token_id = omg_llava.mark_token_idx
+
+         points_mark_embedding = get_points_embeddings(
+             points, input_ids,
+             1024, 1024,
+             mark_token_id)
+
+     if len(boxes_prompts) == 0:
+         boxes_mark_embedding = []
+     else:
+         boxes_prompts = np.array(boxes_prompts)
+         boxes_prompts = expand2square_bbox(boxes_prompts, height=height, width=width)
+         boxes_prompts[:, [0, 2]] = boxes_prompts[:, [0, 2]] / max(height, width) * 1024
+         boxes_prompts[:, [1, 3]] = boxes_prompts[:, [1, 3]] / max(height, width) * 1024
+         boxes_prompts = torch.from_numpy(boxes_prompts)
+         boxes_prompts = boxes_prompts.cuda()
+         # using <region> token
+         region_token_id = omg_llava.region_token_idx
+
+         boxes_mark_embedding = get_points_embeddings(
+             boxes_prompts, input_ids,
+             1024, 1024,
+             region_token_id, mode='box')
+     return points_mark_embedding, boxes_mark_embedding
+
+ def inference(input_str, all_inputs, follow_up):
+     input_str = input_str.replace('<point>', '<mark>')\
+         .replace('<box>', '<region>')
+     print("Received inputs!")
+     prompts = all_inputs['points']
+     visual_prompts = parse_visual_prompts(prompts)
+     input_image = all_inputs['image']
+
+     print("follow_up: ", follow_up)
+     print(prompts)
+     print("input_str: ", input_str, "input_image: ", input_image)
+
+     if not follow_up:
+         # reset conversation state
+         print('Log: History responses have been removed!')
+         global_infos.n_turn = 0
+         global_infos.inputs = ''
+         # reset prompts
+         global_infos.point_prompts = []
+         global_infos.box_prompts = []
+         global_infos.mask_prompts = []
+
+         # first conversation, add image tokens
+         text = DEFAULT_IMAGE_TOKEN + '\n' + input_str
+
+         # prepare image
+         image = load_image(input_image)
+         width, height = image.size
+         global_infos.image_width = width
+         global_infos.image_height = height
+         image = expand2square(
+             image, tuple(int(x * 255) for x in image_processor.image_mean))
+         global_infos.image_for_show = image
+         image = image_processor.preprocess(
+             image, return_tensors='pt')['pixel_values'][0]
+         image = image.cuda().unsqueeze(0).to(visual_encoder.dtype)
+         visual_outputs = visual_encoder(image, output_hidden_states=True)
+         pixel_values = projector(visual_outputs)
+         global_infos.panoptic_masks = omg_llava.visual_encoder.vis_binary_masks
+         global_infos.pixel_values = pixel_values
+
+         # crop range for removing the square padding later
+         if width == height:
+             sx, ex, sy, ey = 0, width, 0, height
+         elif width > height:
+             sy = int((width - height) / 2.0)
+             ey = width - sy
+             sx, ex = 0, width
+         else:
+             sx = int((height - width) / 2.0)
+             ex = height - sx
+             sy, ey = 0, height
+
+         global_infos.sx = sx
+         global_infos.sy = sy
+         global_infos.ex = ex
+         global_infos.ey = ey
+
+     else:
+         text = input_str
+         pixel_values = global_infos.pixel_values
+
+     # add current prompts into the global prompts
+     global_infos.point_prompts += visual_prompts['points']
+     global_infos.box_prompts += visual_prompts['boxes']
+
+     if args.prompt_template:
+         prompt_text = ''
+         template = PROMPT_TEMPLATE[args.prompt_template]
+         if 'SYSTEM' in template and global_infos.n_turn == 0:
+             system_text = None
+             if args.system_template is not None:
+                 system_text = SYSTEM_TEMPLATE[
+                     args.system_template].format(
+                     round=global_infos.n_turn + 1, bot_name=args.bot_name)
+             elif args.system is not None:
+                 system_text = args.system
+             if system_text is not None:
+                 prompt_text += template['SYSTEM'].format(
+                     system=system_text,
+                     round=global_infos.n_turn + 1,
+                     bot_name=args.bot_name)
+         prompt_text += template['INSTRUCTION'].format(
+             input=text, round=global_infos.n_turn + 1, bot_name=args.bot_name)
+     else:
+         prompt_text = text
+
+     print("prompt_text: ", prompt_text)
+     global_infos.inputs += prompt_text
+
+     # encode prompt text
+     chunk_encode = []
+     for idx, chunk in enumerate(global_infos.inputs.split(DEFAULT_IMAGE_TOKEN)):
+         if idx == 0 and global_infos.n_turn == 0:
+             cur_encode = tokenizer.encode(chunk)
+         else:
+             cur_encode = tokenizer.encode(
+                 chunk, add_special_tokens=False)
+         chunk_encode.append(cur_encode)
+     assert len(chunk_encode) == 2
+     ids = []
+     for idx, cur_chunk_encode in enumerate(chunk_encode):
+         ids.extend(cur_chunk_encode)
+         if idx != len(chunk_encode) - 1:
+             ids.append(IMAGE_TOKEN_INDEX)
+     ids = torch.tensor(ids).cuda().unsqueeze(0)
+
+     points_mark_embeddings, boxes_mark_embeddings = get_visual_prompts_embeddings(
+         height=global_infos.image_height,
+         width=global_infos.image_width, input_ids=ids
+     )
+
+     mark_embeddings = points_mark_embeddings
+
+     mark_token_id = omg_llava.mark_token_idx
+     mm_inputs = prepare_inputs_labels_for_multimodal_with_visual_prompts(
+         llm=llm, input_ids=ids, pixel_values=pixel_values,
+         mark_id=mark_token_id,
+         mark_feats=mark_embeddings, region_id=-9999)

+     # mm_inputs['inputs_embeds'] = mm_inputs['inputs_embeds'].to(torch.float16)
+
+     generate_output = llm.generate(
+         **mm_inputs,
+         generation_config=gen_config,
+         streamer=streamer,
+         bos_token_id=tokenizer.bos_token_id,
+         stopping_criteria=stop_criteria,
+         output_hidden_states=True,
+         return_dict_in_generate=True
+     )
+
+     predict = tokenizer.decode(
+         generate_output.sequences[0])
+
+     global_infos.inputs += predict
+     predict = predict.strip()
+     global_infos.n_turn += 1
+     global_infos.inputs += sep
+     if len(generate_output.sequences[0]) >= args.max_new_tokens:
+         print(
+             'Remove the memory of history responses, since '
+             f'it exceeds the length limitation {args.max_new_tokens}.')
+         global_infos.n_turn = 0
+         global_infos.inputs = ''
+
+     hidden_states = generate_output.hidden_states
+     last_hidden_states = [item[-1][0] for item in hidden_states]
+     last_hidden_states = torch.cat(last_hidden_states, dim=0)
+     seg_hidden_states = get_seg_hidden_states(
+         last_hidden_states, generate_output.sequences[0][:-1],
+         seg_id=omg_llava.seg_token_idx
+     )
+     # seg_hidden_states = seg_hidden_states.to(torch.float32)
+     if len(seg_hidden_states) != 0:
+         seg_hidden_states = projector_text2vision(seg_hidden_states)
+         batch_idxs = torch.zeros((seg_hidden_states.shape[0],),
+                                  dtype=torch.int64).to(seg_hidden_states.device)
+         pred_masks_list = omg_llava.visual_encoder.forward_llm_seg(seg_hidden_states, batch_idxs)
+         print((pred_masks_list[-1].flatten(2) > 0).sum(-1))
+         print(pred_masks_list[-1].shape)
+         image_mask_show, selected_colors = show_mask_pred(
+             global_infos.image_for_show, pred_masks_list[-1],
+             crop_range=(global_infos.sx, global_infos.ex, global_infos.sy, global_infos.ey)
+         )
+     else:
+         image_mask_show = global_infos.image_for_show.crop(
+             (global_infos.sx, global_infos.sy, global_infos.ex, global_infos.ey))
+         selected_colors = []
+
+     panoptic_show, _ = show_mask_pred(
+         global_infos.image_for_show, global_infos.panoptic_masks,
+         crop_range=(global_infos.sx, global_infos.ex, global_infos.sy, global_infos.ey)
+     )
+
+     predict = process_markdown(predict, selected_colors)
+     # return panoptic_show, image_mask_show, predict
+     return image_mask_show, predict
+
+ def init_models(args):
+     torch.manual_seed(args.seed)
+
+     # parse config
+     if not osp.isfile(args.config):
+         try:
+             args.config = cfgs_name_path[args.config]
+         except KeyError:
+             raise FileNotFoundError(f'Cannot find {args.config}')
+
+     # load config
+     cfg = Config.fromfile(args.config)
+
+     model_name = cfg.model.type if isinstance(cfg.model.type,
+                                               str) else cfg.model.type.__name__
+     if 'LLaVAModel' in model_name or 'OMG' in model_name:
+         cfg.model.pretrained_pth = None
+
+     model = BUILDER.build(cfg.model)
+
+     backend = get_file_backend(args.pth_model)
+     if isinstance(backend, PetrelBackend):
+         from xtuner.utils.fileio import patch_fileio
+         with patch_fileio():
+             state_dict = guess_load_checkpoint(args.pth_model)
+     else:
+         state_dict = guess_load_checkpoint(args.pth_model)
+
+     model.load_state_dict(state_dict, strict=False)
+     print(f'Load PTH model from {args.pth_model}')
+
+     image_processor = cfg.image_processor
+     image_processor_type = image_processor['type']
+     del image_processor['type']
+     image_processor = image_processor_type(**image_processor)
+
+     # build llm
+     quantization_config = None
+     load_in_8bit = False
+     if args.bits == 4:
+         quantization_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             load_in_8bit=False,
+             llm_int8_threshold=6.0,
+             llm_int8_has_fp16_weight=False,
+             bnb_4bit_compute_dtype=torch.float16,
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_quant_type='nf4')
+     elif args.bits == 8:
+         load_in_8bit = True
+     model_kwargs = {
+         'quantization_config': quantization_config,
+         'load_in_8bit': load_in_8bit,
+         'device_map': 'auto',
+         'offload_folder': args.offload_folder,
+         'trust_remote_code': True,
+         'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
+     }
+
+     inner_thoughts_open = False
+     calculate_open = False
+     solve_open = False
+     search_open = False
+
+     # grab the already-built LLM and tokenizer
+     llm = model.llm
+     tokenizer = model.tokenizer
+
+     model.cuda()
+     model.eval()
+     llm.eval()
+     visual_encoder = model.visual_encoder
+     projector = model.projector
+     projector_text2vision = model.projector_text2vision
+
+     visual_encoder.eval()
+     projector.eval()
+     projector_text2vision.eval()
+
+     return model, llm, tokenizer, image_processor, visual_encoder, projector, projector_text2vision
+
+ def get_seg_hidden_states(hidden_states, output_ids, seg_id):
+     seg_mask = output_ids == seg_id
+     n_out = len(seg_mask)
+     print(output_ids)
+     return hidden_states[-n_out:][seg_mask]
+
+ class global_infos:
+     inputs = ''
+     n_turn = 0
+     image_width = 0
+     image_height = 0
+
+     image_for_show = None
+     pixel_values = None
+     panoptic_masks = None
+
+     sx, sy, ex, ey = 0, 0, 1024, 1024
+
+     point_prompts = []
+     box_prompts = []
+     mask_prompts = []

  if __name__ == "__main__":
+     # parse args and set up the models
+     args = parse_args(sys.argv[1:])
+
+     omg_llava, llm, tokenizer, image_processor, visual_encoder, projector, projector_text2vision = \
+         init_models(args)
+
+     stop_words = args.stop_words
+     sep = ''
+     if args.prompt_template:
+         template = PROMPT_TEMPLATE[args.prompt_template]
+         stop_words += template.get('STOP_WORDS', [])
+         sep = template.get('SEP', '')
+     stop_criteria = get_stop_criteria(
+         tokenizer=tokenizer, stop_words=stop_words)
+
+     if args.no_streamer:
+         streamer = None
+     else:
+         streamer = TextStreamer(tokenizer, skip_prompt=True)
+
+     gen_config = GenerationConfig(
+         max_new_tokens=args.max_new_tokens,
+         do_sample=args.temperature > 0,
+         temperature=args.temperature,
+         top_p=args.top_p,
+         top_k=args.top_k,
+         repetition_penalty=args.repetition_penalty,
+         eos_token_id=tokenizer.eos_token_id,
+         pad_token_id=tokenizer.pad_token_id
+         if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
+     )
+
+     demo = gr.Interface(
+         inference, inputs=[gr.Textbox(lines=1, placeholder=None, label="Text Instruction"), ImagePrompter(
+             type='filepath', label='Input Image (Please click points or draw bboxes)', interactive=True,
+             elem_id='image_upload', height=360, visible=True, render=True
+         ),
+             gr.Checkbox(label="Follow up Question")],
+         outputs=[
+             # gr.Image(type="pil", label="Panoptic Segmentation", height=360),
+             gr.Image(type="pil", label="Output Image"),
+             gr.Markdown()],
+         theme=gr.themes.Soft(), allow_flagging="auto", )
+
+     demo.queue()
+     demo.launch(share=True)