Spaces:
Runtime error
Upload /app.py with huggingface_hub
app.py
CHANGED
@@ -1,63 +1,530 @@
Old file (63 lines), context and removals:

 1       import gradio as gr
 2       from huggingface_hub import InferenceClient
 3-16  -  (removed, apart from blank context lines; deleted content not preserved in this view)
17       ):
18-58 -  (removed; deleted content not preserved in this view)
59    -  )
60-61    (blank context)
62       if __name__ == "__main__":
63    -  (removed; deleted content not preserved in this view)

The new app.py:
import gradio as gr
from huggingface_hub import InferenceClient
import cv2
import random
import numpy as np
from PIL import Image
import torch.nn.functional as F
import sys
from omg_llava.tools.app_utils import process_markdown, show_mask_pred, parse_visual_prompts

import torch
from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, CLIPImageProcessor,
                          CLIPVisionModel, GenerationConfig)
from transformers.generation.streamers import TextStreamer

from xtuner.dataset.utils import expand2square, load_image
from omg_llava.dataset.utils import expand2square_bbox, expand2square_mask, expand2square_points
from xtuner.model.utils import prepare_inputs_labels_for_multimodal
from omg_llava.model.utils import prepare_inputs_labels_for_multimodal_with_visual_prompts
from xtuner.tools.utils import get_stop_criteria
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
                          PROMPT_TEMPLATE, SYSTEM_TEMPLATE)

import argparse
import os.path as osp

from mmengine.config import Config, DictAction
from mmengine.fileio import PetrelBackend, get_file_backend

from xtuner.configs import cfgs_name_path
from xtuner.model.utils import guess_load_checkpoint
from xtuner.registry import BUILDER

from gradio_image_prompter import ImagePrompter

TORCH_DTYPE_MAP = dict(
    fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')

def parse_args(args):
    parser = argparse.ArgumentParser(description="OMG-LLaVA Demo")
    parser.add_argument('--config', help='config file name or path.',
                        default='./omg_llava/configs/finetune/hf_app.py')
    parser.add_argument('--pth_model', help='pth model file',
                        default='./pretrained/omg_llava/omg_llava_fintune_8gpus.pth')

    parser.add_argument('--image', default=None, help='image')
    parser.add_argument(
        '--torch-dtype',
        default='fp16',
        choices=TORCH_DTYPE_MAP.keys(),
        help='Override the default `torch.dtype` and load the model under '
        'a specific `dtype`.')
    parser.add_argument(
        '--prompt-template',
        choices=PROMPT_TEMPLATE.keys(),
        default="internlm2_chat",
        help='Specify a prompt template')
    system_group = parser.add_mutually_exclusive_group()
    system_group.add_argument(
        '--system', default=None, help='Specify the system text')
    system_group.add_argument(
        '--system-template',
        choices=SYSTEM_TEMPLATE.keys(),
        default=None,
        help='Specify a system template')
    parser.add_argument(
        '--bits',
        type=int,
        choices=[4, 8, None],
        default=None,
        help='LLM bits')
    parser.add_argument(
        '--bot-name', type=str, default='BOT', help='Name for Bot')
    parser.add_argument(
        '--with-plugins',
        nargs='+',
        choices=['calculate', 'solve', 'search'],
        help='Specify plugins to use')
    parser.add_argument(
        '--no-streamer', action='store_true',
        help='Whether to disable the streamer')
    parser.add_argument(
        '--lagent', action='store_true', help='Whether to use lagent')
    parser.add_argument(
        '--stop-words', nargs='+', type=str, default=[], help='Stop words')
    parser.add_argument(
        '--offload-folder',
        default=None,
        help='The folder in which to offload the model weights (or where the '
        'model weights are already offloaded).')
    parser.add_argument(
        '--max-new-tokens',
        type=int,
        default=2048,
        help='Maximum number of new tokens allowed in generated text')
    parser.add_argument(
        '--temperature',
        type=float,
        default=0.1,
        help='The value used to modulate the next token probabilities.')
    parser.add_argument(
        '--top-k',
        type=int,
        default=40,
        help='The number of highest probability vocabulary tokens to '
        'keep for top-k-filtering.')
    parser.add_argument(
        '--top-p',
        type=float,
        default=0.75,
        help='If set to float < 1, only the smallest set of most probable '
        'tokens with probabilities that add up to top_p or higher are '
        'kept for generation.')
    parser.add_argument(
        '--repetition-penalty',
        type=float,
        default=1.0,
        help='The parameter for repetition penalty. 1.0 means no penalty.')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='Random seed for reproducible text generation')
    return parser.parse_args(args)

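# For reference, a sketch of a typical local launch using the flags defined
# above (paths are illustrative and depend on where the config and checkpoint
# live in your checkout):
#
#     python app.py \
#         --config ./omg_llava/configs/finetune/hf_app.py \
#         --pth_model ./pretrained/omg_llava/omg_llava_fintune_8gpus.pth \
#         --torch-dtype fp16 --prompt-template internlm2_chat
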
def get_points_embeddings(points, input_ids, width, height,
                          mark_token_idx, mode='point'):
    if points is None or len(points) == 0:
        return []

    mark_token_mask = input_ids == mark_token_idx
    batch_idxs = torch.arange(input_ids.shape[0]).unsqueeze(1).repeat(
        1, input_ids.shape[1]).to(input_ids.device)
    batch_idxs = batch_idxs[mark_token_mask]  # (N, ) batch index of each mark token

    points = points.to(torch.float32)

    if mode == 'point':
        marks_embeddings = visual_encoder.forward_point_sam(
            points, batch_idxs, width=width, height=height
        )[:, 0]  # (N, C)
    elif mode == 'box':
        marks_embeddings = visual_encoder.forward_box_sam(
            points, batch_idxs, width=width, height=height
        )[:, 0]  # (N, C)
    else:
        raise NotImplementedError

    # project the prompt embeddings into the LLM embedding space
    marks_embeddings = marks_embeddings.to(projector.model.query_proj.weight.dtype)
    marks_embeddings = projector.model.query_proj(marks_embeddings)
    marks_embeddings = projector.model.model(marks_embeddings)
    print('marks_embeddings shape: ', marks_embeddings.shape)
    return marks_embeddings  # (N, C)

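# Shape walkthrough (illustrative): input_ids is (1, L); mark_token_mask picks
# the N positions holding <mark>/<region> tokens, so batch_idxs is (N,). The
# SAM-style prompt encoder returns one embedding per visual prompt, and the
# two projector stages then map it to the LLM embedding width, giving (N, C).
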
def get_visual_prompts_embeddings(
        height, width, input_ids,
):
    points_prompts = global_infos.point_prompts
    boxes_prompts = global_infos.box_prompts

    if len(points_prompts) == 0:
        points_mark_embedding = []
    else:
        points = np.array(points_prompts)
        points = expand2square_points(points, height=height, width=width)
        points[:, 0] = points[:, 0] / max(height, width) * 1024
        points[:, 1] = points[:, 1] / max(height, width) * 1024
        points = torch.from_numpy(points)
        points = points.cuda()
        mark_token_id = omg_llava.mark_token_idx

        points_mark_embedding = get_points_embeddings(
            points, input_ids,
            1024, 1024,
            mark_token_id)

    if len(boxes_prompts) == 0:
        boxes_mark_embedding = []
    else:
        boxes_prompts = np.array(boxes_prompts)
        boxes_prompts = expand2square_bbox(boxes_prompts, height=height, width=width)
        boxes_prompts[:, [0, 2]] = boxes_prompts[:, [0, 2]] / max(height, width) * 1024
        boxes_prompts[:, [1, 3]] = boxes_prompts[:, [1, 3]] / max(height, width) * 1024
        boxes_prompts = torch.from_numpy(boxes_prompts)
        boxes_prompts = boxes_prompts.cuda()
        # boxes use the <region> token and the box branch of the prompt encoder
        region_token_id = omg_llava.region_token_idx

        boxes_mark_embedding = get_points_embeddings(
            boxes_prompts, input_ids,
            1024, 1024,
            region_token_id,
            mode='box')
    return points_mark_embedding, boxes_mark_embedding

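# The coordinate handling above maps raw image coordinates into the padded
# 1024x1024 square. A minimal, self-contained sketch of that mapping
# (assuming expand2square_points pads the short side equally on both ends):
#
#     def to_square_1024(points, height, width):
#         side = float(max(height, width))
#         points = points.astype(np.float64).copy()
#         points[:, 0] += (side - width) / 2.0   # shift x by the left padding
#         points[:, 1] += (side - height) / 2.0  # shift y by the top padding
#         return points / side * 1024.0          # rescale to the 1024 canvas
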
def inference(input_str, all_inputs, follow_up):
    input_str = input_str.replace('<point>', '<mark>')\
        .replace('<box>', '<region>')
    print("Received inputs!")
    prompts = all_inputs['points']
    visual_prompts = parse_visual_prompts(prompts)
    input_image = all_inputs['image']

    print("follow_up: ", follow_up)
    print(prompts)
    print("input_str: ", input_str, "input_image: ", input_image)

    if not follow_up:
        # reset the conversation state
        print('Log: History responses have been removed!')
        global_infos.n_turn = 0
        global_infos.inputs = ''
        # reset prompts
        global_infos.point_prompts = []
        global_infos.box_prompts = []
        global_infos.mask_prompts = []

        # first conversation turn: add the image token
        text = DEFAULT_IMAGE_TOKEN + '\n' + input_str

        # prepare the image
        image = load_image(input_image)
        width, height = image.size
        global_infos.image_width = width
        global_infos.image_height = height
        image = expand2square(
            image, tuple(int(x * 255) for x in image_processor.image_mean))
        global_infos.image_for_show = image
        image = image_processor.preprocess(
            image, return_tensors='pt')['pixel_values'][0]
        image = image.cuda().unsqueeze(0).to(visual_encoder.dtype)
        visual_outputs = visual_encoder(image, output_hidden_states=True)
        pixel_values = projector(visual_outputs)
        global_infos.panoptic_masks = omg_llava.visual_encoder.vis_binary_masks
        global_infos.pixel_values = pixel_values

        # crop range for removing the square padding later
        if width == height:
            sx, ex, sy, ey = 0, width, 0, height
        elif width > height:
            sy = int((width - height) / 2.0)
            ey = width - sy
            sx, ex = 0, width
        else:
            sx = int((height - width) / 2.0)
            ex = height - sx
            sy, ey = 0, height

        global_infos.sx = sx
        global_infos.sy = sy
        global_infos.ex = ex
        global_infos.ey = ey

    else:
        text = input_str
        pixel_values = global_infos.pixel_values

    # add the current visual prompts to the global prompts
    global_infos.point_prompts += visual_prompts['points']
    global_infos.box_prompts += visual_prompts['boxes']

    if args.prompt_template:
        prompt_text = ''
        template = PROMPT_TEMPLATE[args.prompt_template]
        if 'SYSTEM' in template and global_infos.n_turn == 0:
            system_text = None
            if args.system_template is not None:
                system_text = SYSTEM_TEMPLATE[
                    args.system_template].format(
                        round=global_infos.n_turn + 1, bot_name=args.bot_name)
            elif args.system is not None:
                system_text = args.system
            if system_text is not None:
                prompt_text += template['SYSTEM'].format(
                    system=system_text,
                    round=global_infos.n_turn + 1,
                    bot_name=args.bot_name)
        prompt_text += template['INSTRUCTION'].format(
            input=text, round=global_infos.n_turn + 1, bot_name=args.bot_name)
    else:
        prompt_text = text

    print("prompt_text: ", prompt_text)
    global_infos.inputs += prompt_text

    # encode the prompt text around the image token
    chunk_encode = []
    for idx, chunk in enumerate(global_infos.inputs.split(DEFAULT_IMAGE_TOKEN)):
        if idx == 0 and global_infos.n_turn == 0:
            cur_encode = tokenizer.encode(chunk)
        else:
            cur_encode = tokenizer.encode(
                chunk, add_special_tokens=False)
        chunk_encode.append(cur_encode)
    assert len(chunk_encode) == 2
    ids = []
    for idx, cur_chunk_encode in enumerate(chunk_encode):
        ids.extend(cur_chunk_encode)
        if idx != len(chunk_encode) - 1:
            ids.append(IMAGE_TOKEN_INDEX)
    ids = torch.tensor(ids).cuda().unsqueeze(0)

    points_mark_embeddings, boxes_mark_embeddings = get_visual_prompts_embeddings(
        height=global_infos.image_height,
        width=global_infos.image_width, input_ids=ids
    )

    mark_embeddings = points_mark_embeddings

    mark_token_id = omg_llava.mark_token_idx
    mm_inputs = prepare_inputs_labels_for_multimodal_with_visual_prompts(
        llm=llm, input_ids=ids, pixel_values=pixel_values,
        mark_id=mark_token_id,
        mark_feats=mark_embeddings, region_id=-9999)

    generate_output = llm.generate(
        **mm_inputs,
        generation_config=gen_config,
        streamer=streamer,
        bos_token_id=tokenizer.bos_token_id,
        stopping_criteria=stop_criteria,
        output_hidden_states=True,
        return_dict_in_generate=True
    )

    predict = tokenizer.decode(
        generate_output.sequences[0])

    global_infos.inputs += predict
    predict = predict.strip()
    global_infos.n_turn += 1
    global_infos.inputs += sep
    if len(generate_output.sequences[0]) >= args.max_new_tokens:
        print(
            'Remove the memory of history responses, since '
            f'it exceeds the length limitation {args.max_new_tokens}.')
        global_infos.n_turn = 0
        global_infos.inputs = ''

    # decode [SEG] hidden states into masks
    hidden_states = generate_output.hidden_states
    last_hidden_states = [item[-1][0] for item in hidden_states]
    last_hidden_states = torch.cat(last_hidden_states, dim=0)
    seg_hidden_states = get_seg_hidden_states(
        last_hidden_states, generate_output.sequences[0][:-1],
        seg_id=omg_llava.seg_token_idx
    )
    if len(seg_hidden_states) != 0:
        seg_hidden_states = projector_text2vision(seg_hidden_states)
        batch_idxs = torch.zeros((seg_hidden_states.shape[0],),
                                 dtype=torch.int64).to(seg_hidden_states.device)
        pred_masks_list = omg_llava.visual_encoder.forward_llm_seg(
            seg_hidden_states, batch_idxs)
        image_mask_show, selected_colors = show_mask_pred(
            global_infos.image_for_show, pred_masks_list[-1],
            crop_range=(global_infos.sx, global_infos.ex, global_infos.sy, global_infos.ey)
        )
    else:
        image_mask_show = global_infos.image_for_show.crop(
            (global_infos.sx, global_infos.sy, global_infos.ex, global_infos.ey))
        selected_colors = []

    # panoptic_show is computed but only the masked image is returned below
    panoptic_show, _ = show_mask_pred(
        global_infos.image_for_show, global_infos.panoptic_masks,
        crop_range=(global_infos.sx, global_infos.ex, global_infos.sy, global_infos.ey)
    )

    predict = process_markdown(predict, selected_colors)
    # return panoptic_show, image_mask_show, predict
    return image_mask_show, predict

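# Conversation flow in inference(), summarized: a fresh question resets the
# state, pads the image to a square, runs the visual encoder once, and caches
# pixel_values plus the crop range in global_infos; follow-up turns reuse the
# cached features and only append new prompt text, so the multi-turn context
# is the accumulated global_infos.inputs string fed back to the LLM.
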
def init_models(args):
    torch.manual_seed(args.seed)

    # parse config
    if not osp.isfile(args.config):
        try:
            args.config = cfgs_name_path[args.config]
        except KeyError:
            raise FileNotFoundError(f'Cannot find {args.config}')

    # load config
    cfg = Config.fromfile(args.config)

    model_name = cfg.model.type if isinstance(cfg.model.type,
                                              str) else cfg.model.type.__name__
    if 'LLaVAModel' in model_name or 'OMG' in model_name:
        cfg.model.pretrained_pth = None

    model = BUILDER.build(cfg.model)

    backend = get_file_backend(args.pth_model)
    if isinstance(backend, PetrelBackend):
        from xtuner.utils.fileio import patch_fileio
        with patch_fileio():
            state_dict = guess_load_checkpoint(args.pth_model)
    else:
        state_dict = guess_load_checkpoint(args.pth_model)

    model.load_state_dict(state_dict, strict=False)
    print(f'Load PTH model from {args.pth_model}')

    # build image processor
    image_processor = cfg.image_processor
    image_processor_type = image_processor['type']
    del image_processor['type']
    image_processor = image_processor_type(**image_processor)

    # quantization settings for the LLM
    quantization_config = None
    load_in_8bit = False
    if args.bits == 4:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')
    elif args.bits == 8:
        load_in_8bit = True
    # HF loading kwargs for the LLM (quantization, offload, dtype)
    model_kwargs = {
        'quantization_config': quantization_config,
        'load_in_8bit': load_in_8bit,
        'device_map': 'auto',
        'offload_folder': args.offload_folder,
        'trust_remote_code': True,
        'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
    }

    # plugin switches (unused in this demo)
    inner_thoughts_open = False
    calculate_open = False
    solve_open = False
    search_open = False

    # take the LLM and tokenizer from the built model
    llm = model.llm
    tokenizer = model.tokenizer

    model.cuda()
    model.eval()
    llm.eval()
    visual_encoder = model.visual_encoder
    projector = model.projector
    projector_text2vision = model.projector_text2vision

    visual_encoder.eval()
    projector.eval()
    projector_text2vision.eval()

    return model, llm, tokenizer, image_processor, visual_encoder, projector, projector_text2vision

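# Checkpoint note: the full model comes from the mmengine config via
# BUILDER.build, and the .pth file is loaded with strict=False, so missing or
# unexpected keys are tolerated rather than raising.
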
def get_seg_hidden_states(hidden_states, output_ids, seg_id):
    # keep only the hidden states at [SEG] positions among the generated tokens
    seg_mask = output_ids == seg_id
    n_out = len(seg_mask)
    return hidden_states[-n_out:][seg_mask]

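# How grounded segmentation is produced, in short: each generated [SEG]
# token's final-layer hidden state is collected here, projected back into the
# vision space by projector_text2vision, and decoded into a binary mask by the
# visual encoder's forward_llm_seg head inside inference().
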
class global_infos:
    inputs = ''
    n_turn = 0
    image_width = 0
    image_height = 0

    image_for_show = None
    pixel_values = None
    panoptic_masks = None

    sx, sy, ex, ey = 0, 0, 1024, 1024

    point_prompts = []
    box_prompts = []
    mask_prompts = []

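# Note that global_infos keeps session state in module-level class attributes,
# so a single running process shares one conversation state across all
# concurrent users; this is a simplification inherent to this demo.
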
if __name__ == "__main__":
    # parse args and set up the models
    args = parse_args(sys.argv[1:])

    omg_llava, llm, tokenizer, image_processor, visual_encoder, projector, projector_text2vision = \
        init_models(args)

    stop_words = args.stop_words
    sep = ''
    if args.prompt_template:
        template = PROMPT_TEMPLATE[args.prompt_template]
        stop_words += template.get('STOP_WORDS', [])
        sep = template.get('SEP', '')
    stop_criteria = get_stop_criteria(
        tokenizer=tokenizer, stop_words=stop_words)

    if args.no_streamer:
        streamer = None
    else:
        streamer = TextStreamer(tokenizer, skip_prompt=True)

    gen_config = GenerationConfig(
        max_new_tokens=args.max_new_tokens,
        do_sample=args.temperature > 0,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        repetition_penalty=args.repetition_penalty,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    )

    demo = gr.Interface(
        inference,
        inputs=[
            gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),
            ImagePrompter(
                type='filepath',
                label='Input Image (Please click points or draw bboxes)',
                interactive=True, elem_id='image_upload', height=360,
                visible=True, render=True),
            gr.Checkbox(label="Follow up Question")],
        outputs=[
            # gr.Image(type="pil", label="Panoptic Segmentation", height=360),
            gr.Image(type="pil", label="Output Image"),
            gr.Markdown()],
        theme=gr.themes.Soft(), allow_flagging="auto")

    demo.queue()
    demo.launch(share=True)