chenlin committed
Commit
49bae8f
1 Parent(s): f0b9014

support multi-mode infer

Files changed (3):
  1. .gitattributes +2 -0
  2. SimHei.ttf +3 -0
  3. app.py +252 -21
.gitattributes CHANGED
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  *.mp4 filter=lfs diff=lfs merge=lfs -text
+ *.ttf filter=lfs diff=lfs merge=lfs -text
+
SimHei.ttf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6625b9b91a5054faa413b69d171020c3f6d9a872345d8a3c5e3df61809291b7f
+ size 10043912
app.py CHANGED
@@ -1,10 +1,18 @@
+ import base64
  import os
  import shutil
  import tempfile
+ from io import BytesIO

- import spaces
  import gradio as gr
+ import numpy as np
  import torch
+ import torchvision.transforms as transforms
+ from decord import VideoReader
+ from PIL import Image, ImageDraw, ImageFont
+ from transformers import AutoModel, AutoTokenizer
+
+ import spaces

  title_markdown = ("""
  <div style="display: flex; justify-content: flex-start; align-items: center; text-align: center;">
@@ -33,31 +41,261 @@ The service is a research preview intended for non-commercial use only, subject
  """)


- model_path = ''
- device = 'cuda'
- load_8bit = False
- load_4bit = False
- dtype = torch.float16
+ new_path = 'Lin-Chen/ShareCaptioner-Video'
+ tokenizer = AutoTokenizer.from_pretrained(new_path, trust_remote_code=True)
+ model = AutoModel.from_pretrained(
+     new_path, torch_dtype=torch.bfloat16, trust_remote_code=True).cuda().eval()
+ model.tokenizer = tokenizer
+
+
+ def padding_336(b, pad=336):
+     # Pad the height with white so it becomes a multiple of `pad`; width is untouched.
+     width, height = b.size
+     tar = int(np.ceil(height / pad) * pad)
+     top_padding = int((tar - height) / 2)
+     bottom_padding = tar - height - top_padding
+     left_padding = 0
+     right_padding = 0
+     b = transforms.functional.pad(
+         b, [left_padding, top_padding, right_padding, bottom_padding], fill=[255, 255, 255])
+
+     return b
+
+
+ def HD_transform(img, hd_num=25):
+     # Resize so the long side is a multiple of 336 and the image tiles into at most
+     # hd_num 336x336 patches; portrait images are transposed, processed, transposed back.
+     width, height = img.size
+     trans = False
+     if width < height:
+         img = img.transpose(Image.TRANSPOSE)
+         trans = True
+         width, height = img.size
+     ratio = (width / height)
+     scale = 1
+     while scale * np.ceil(scale / ratio) <= hd_num:
+         scale += 1
+     scale -= 1
+     new_w = int(scale * 336)
+     new_h = int(new_w / ratio)
+
+     img = transforms.functional.resize(img, [new_h, new_w])
+     img = padding_336(img, 336)
+     width, height = img.size
+     if trans:
+         img = img.transpose(Image.TRANSPOSE)
+
+     return img
+
+
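A minimal sketch of the geometry above, with an illustrative 1280x720 frame (not part of the commit): the long side snaps to a multiple of 336, and the short side is padded so the result tiles into at most hd_num patches of 336x336.

    frame = Image.new('RGB', (1280, 720), 'gray')  # stand-in frame
    out = HD_transform(frame, hd_num=9)
    print(out.size)  # (1008, 672): a 3x2 grid of 336px tiles, within the hd_num=9 budget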
+ def get_seq_frames(total_num_frames, desired_num_frames, start=None, end=None):
+     # Uniformly sample `desired_num_frames` indices: the first frame, the midpoints
+     # of equal segments, and the last frame.
+     if start is None:
+         assert end is None
+         start, end = 0, total_num_frames
+     print(f"{start=}, {end=}")
+     desired_num_frames -= 2
+     end = min(total_num_frames, end)
+     start = max(start, 0)
+     seg_size = float(end - start) / desired_num_frames
+     seq = [start]
+
+     for i in range(desired_num_frames):
+         s = int(np.round(seg_size * i))
+         e = int(np.round(seg_size * (i + 1)))
+         seq.append(min(int(start + (s + e) // 2), total_num_frames - 1))
+     return seq + [end - 1]
+
+
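get_seq_frames is not referenced elsewhere in this commit (load_quota_video below samples at a fixed 2-second interval instead); a quick sanity check of its contract, with illustrative numbers:

    idx = get_seq_frames(total_num_frames=300, desired_num_frames=10)
    print(len(idx))         # 10: the first frame, 8 segment midpoints, the last frame
    print(idx[0], idx[-1])  # 0 299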
+ def model_gen(model, text, images, need_bos=True, hd_num=25, max_new_token=2048, beam=3, do_sample=False):
+     # Interleave text and image embeddings, build the image mask, and generate.
+     pt1 = 0
+     embeds = []
+     im_mask = []
+     if images is None:
+         images = []
+         images_loc = []
+     else:
+         images = [images]
+         images_loc = [0]
+     for i, pts in enumerate(images_loc + [len(text)]):
+         subtext = text[pt1:pts]
+         if need_bos or len(subtext) > 0:
+             text_embeds = model.encode_text(
+                 subtext, add_special_tokens=need_bos)
+             embeds.append(text_embeds)
+             im_mask.append(torch.zeros(text_embeds.shape[:2]).cuda())
+             need_bos = False
+         if i < len(images):
+             try:
+                 image = Image.open(images[i]).convert('RGB')
+             except Exception:
+                 image = images[i].convert('RGB')
+
+             image = HD_transform(image, hd_num=hd_num)
+             image = model.vis_processor(image).unsqueeze(0).cuda()
+             image_embeds = model.encode_img(image)
+             print(image_embeds.shape)
+             embeds.append(image_embeds)
+             im_mask.append(torch.ones(image_embeds.shape[:2]).cuda())
+         pt1 = pts
+     embeds = torch.cat(embeds, dim=1)
+     im_mask = torch.cat(im_mask, dim=1)
+     im_mask = im_mask.bool()
+     outputs = model.generate(inputs_embeds=embeds, im_mask=im_mask,
+                              temperature=1.0, max_new_tokens=max_new_token, num_beams=beam,
+                              do_sample=do_sample, repetition_penalty=1.00)
+
+     output_token = outputs[0]
+     if output_token[0] == 0 or output_token[0] == 1:
+         output_token = output_token[1:]
+     output_text = model.tokenizer.decode(
+         output_token, add_special_tokens=False)
+     output_text = output_text.split('[UNUSED_TOKEN_145]')[0].strip()
+     return output_text
+
+
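The three entry points below all drive model_gen the same way: a templated query string plus at most one PIL image (or an image path), everything else defaulted. A minimal sketch, assuming the model is loaded as above; 'frame.jpg' is a placeholder path:

    q = 'Describe this frame in detail.'
    query = f'[UNUSED_TOKEN_146]user\n{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
    with torch.cuda.amp.autocast():
        caption = model_gen(model, query, 'frame.jpg', hd_num=9)
    print(caption)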
+ def img_process(imgs):
+     # Stack frames vertically on a white canvas, labeling each '<IMAGE idx>' in a left
+     # margin drawn with the SimHei font, with a separator line between frames.
+     new_w = 0
+     new_h = 0
+     for im in imgs:
+         w, h = im.size
+         new_w = max(new_w, w)
+         new_h += h + 20
+     pad = max(new_w // 4, 100)
+     new_w += 20
+     new_h += 20
+     font = ImageFont.truetype("SimHei.ttf", pad // 5)
+     new_img = Image.new('RGB', (new_w + pad, new_h), 'white')
+     draw = ImageDraw.Draw(new_img)
+     curr_h = 10
+     for idx, im in enumerate(imgs):
+         w, h = im.size
+         new_img.paste(im, (pad, curr_h))
+         draw.text((0, curr_h + h // 2),
+                   f'<IMAGE {idx}>', font=font, fill='black')
+         if idx + 1 < len(imgs):
+             draw.line([(0, curr_h + h + 10), (new_w + pad,
+                        curr_h + h + 10)], fill='black', width=2)
+         curr_h += h + 20
+     return new_img
+
+
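An illustrative call with stand-in frames (img_process needs the SimHei.ttf added by this commit in the working directory):

    frames = [Image.new('RGB', (336, 336), c) for c in ('red', 'green', 'blue')]
    sheet = img_process(frames)  # one column of frames, '<IMAGE idx>' labels in the left margin
    print(sheet.size)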
+ def load_quota_video(vis_path, start=None, end=None):
+     # Decode the video with decord and keep one frame every ~2 seconds,
+     # optionally restricted to the [start, end] window (in seconds).
+     vr = VideoReader(vis_path)
+     total_frame_num = len(vr)
+     fps = vr.get_avg_fps()
+     if start is not None:
+         assert end is not None
+         start_frame = int(start * fps)
+         end_frame = min(int(end * fps), total_frame_num)
+     else:
+         start_frame = 0
+         end_frame = total_frame_num
+     interval = int(2 * fps)
+     frame_idx = list(range(start_frame, end_frame, interval))
+     img_array = vr.get_batch(frame_idx).asnumpy()
+     num_frm, H, W, _ = img_array.shape
+     img_array = img_array.reshape(
+         (1, num_frm, img_array.shape[-3], img_array.shape[-2], img_array.shape[-1]))
+     clip_imgs = []
+     for j in range(num_frm):
+         clip_imgs.append(Image.fromarray(img_array[0, j]))
+     return clip_imgs
+
+
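Sampling behavior at a glance ('input.mp4' is a placeholder; decord must be able to open it):

    frames = load_quota_video('input.mp4')                   # whole clip, one frame per ~2s
    window = load_quota_video('input.mp4', start=5, end=15)  # only the 5s-15s window
    print(len(frames), frames[0].size)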
+ def resize_image(image_path, max_size=1024):
+     with Image.open(image_path) as img:
+         width, height = img.size
+         if width > max_size or height > max_size:
+             if width > height:
+                 new_width = max_size
+                 new_height = int(height * (max_size / width))
+             else:
+                 new_height = max_size
+                 new_width = int(width * (max_size / height))
+         else:
+             new_width = width
+             new_height = height
+         resized_img = img.resize((new_width, new_height))
+         print(f"resized_img_size: {resized_img.size}")
+         return resized_img
+
+
+ def encode_resized_image(image_path, max_size=1024):
+     resized_img = resize_image(image_path, max_size)
+     try:
+         with BytesIO() as buffer:
+             resized_img.save(buffer, format="JPEG")
+             return base64.b64encode(buffer.getvalue()).decode('utf-8')
+     except Exception:
+         # JPEG cannot encode e.g. RGBA/P images; retry after converting to RGB.
+         with BytesIO() as buffer:
+             rgb_img = resized_img.convert('RGB')
+             rgb_img.save(buffer, format="JPEG")
+             return base64.b64encode(buffer.getvalue()).decode('utf-8')


  @spaces.GPU(duration=60)
- def generate_slidingcaptioning(video):
-     return 'text'
+ def generate_slidingcaptioning(video_path):
+     # Caption the first frame, then describe what happens between each pair of
+     # consecutive frames, then summarize the per-frame descriptions.
+     imgs = load_quota_video(video_path)
+     q = 'This is the first frame of a video, describe it in detail.'
+     query = f'[UNUSED_TOKEN_146]user\n{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
+     img = imgs[0]
+     with torch.cuda.amp.autocast():
+         response = model_gen(model, query, img, hd_num=9)
+     print(response)
+     responses = [response]
+     images = [img]
+     for idx in range(len(imgs) - 1):
+         image1 = imgs[idx]
+         image2 = imgs[idx + 1]
+         prompt = "Here are the Video frame {} at {}.00 Second(s) and Video frame {} at {}.00 Second(s) of a video, describe what happened between them. What happened before is: {}".format(
+             idx, int(idx * 2), idx + 1, int((idx + 1) * 2), response)
+         width, height = image1.size
+         new_img = Image.new('RGB', (width, 2 * height + 50), 'white')
+         new_img.paste(image1, (0, 0))
+         new_img.paste(image2, (0, height + 50))
+         query = f'[UNUSED_TOKEN_146]user\n{prompt}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
+         with torch.cuda.amp.autocast():
+             response = model_gen(model, query, new_img, hd_num=9)
+         responses.append(response)
+         images.append(new_img)
+     prompt = 'Summarize the following per frame descriptions:\n'
+     for idx, txt in enumerate(responses):
+         prompt += 'Video frame {} at {}.00 Second(s) description: {}\n'.format(
+             idx + 1, idx * 2, txt)
+     query = f'[UNUSED_TOKEN_146]user\n{prompt}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
+     print(query)
+     with torch.cuda.amp.autocast():
+         summ = model_gen(model, query, None, hd_num=16)
+     print(summ)
+     return summ
+
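All three modes wrap their prompts in the same chat template before calling model_gen. A hypothetical helper that makes the shared f-string explicit (app.py inlines it at each call site):

    def build_query(user_prompt):
        # [UNUSED_TOKEN_146]/[UNUSED_TOKEN_145] delimit chat turns for the checkpoint's
        # tokenizer; the trailing empty assistant turn is what the model completes.
        return (f'[UNUSED_TOKEN_146]user\n{user_prompt}[UNUSED_TOKEN_145]\n'
                f'[UNUSED_TOKEN_146]assistant\n')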

  @spaces.GPU(duration=60)
- def generate_fastcaptioning(video):
-     return 'text'
+ def generate_fastcaptioning(video_path):
+     # Sample frames every ~2s, stitch them into one labeled sheet, caption in a single pass.
+     q = 'Here are a few key frames of a video, describe this video in detail.'
+     query = f'[UNUSED_TOKEN_146]user\n{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
+     imgs = load_quota_video(video_path)
+     img = img_process(imgs)
+     with torch.cuda.amp.autocast():
+         response = model_gen(model, query, img, hd_num=16,
+                              do_sample=False, beam=3)
+     return response
+

  @spaces.GPU(duration=60)
  def generate_promptrecaptioning(text):
-     return text
-
+     q = f'Translate this brief generation prompt into a detailed caption: {text}'
+     query = f'[UNUSED_TOKEN_146]user\n{q}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
+     with torch.cuda.amp.autocast():
+         response = model_gen(model, query, None)
+     return response
+
+
  def save_video_to_local(video_path):
      filename = os.path.join('temp', next(
          tempfile._get_candidate_names()) + '.mp4')
      shutil.copyfile(video_path, filename)
      return filename

+
  with gr.Blocks(title='ShareCaptioner-Video', theme=gr.themes.Default(), css=block_css) as demo:
      gr.Markdown(title_markdown)
      state = gr.State()

@@ -65,14 +303,13 @@ with gr.Blocks(title='ShareCaptioner-Video', theme=gr.themes.Default(), css=bloc
      first_run = gr.State()

      with gr.Row():
-         gr.Markdown("### The ShareCaptioner-Video is a Four-in-One exceptional video captioning model with the following capabilities:\n1. Fast captioning, 2. Sliding Captioning, 3. Clip Summarizing, 4. Prompt Re-Captioning")
+         gr.Markdown("### The ShareCaptioner-Video is a Four-in-One exceptional video captioning model with the following capabilities:\n1. Fast captioning, 2. Sliding Captioning, 3. Clip Summarizing, 4. Prompt Re-Captioning")
      with gr.Row():
          gr.Markdown("(THE DEMO OF \"Clip Summarizing\" IS COMING SOON...)")
      with gr.Row():
          with gr.Column(scale=6):
              with gr.Row():
                  video = gr.Video(label="Input Video")
-                 cur_dir = os.path.dirname(os.path.abspath(__file__))
              with gr.Row():
                  textbox = gr.Textbox(
                      show_label=False, placeholder="Input Text", container=False
@@ -97,14 +334,8 @@ with gr.Blocks(title='ShareCaptioner-Video', theme=gr.themes.Default(), css=bloc
                  )
      gr.Markdown(learn_more_markdown)

-     submit_btn_sc.click(generate_slidingcaptioning, [video],[textbox_out])
+     submit_btn_sc.click(generate_slidingcaptioning, [video], [textbox_out])
      submit_btn_fc.click(generate_fastcaptioning, [video], [textbox_out])
      submit_btn_pr.click(generate_promptrecaptioning, [textbox], [textbox_out])

- ### for local launch
- # demo.launch(server_name="0.0.0.0",
- #             server_port=28358,
- #             share=True)
-
- ### for huggingface launch
  demo.launch()
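The removed comment block documented a local launch; outside of Spaces the same call still works (host and port taken from the deleted lines):

    demo.launch(server_name="0.0.0.0", server_port=28358, share=True)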