watchtowerss committed
Commit: 5da584d
Parent: 94dd0a9

Upload 65 files

Files changed (2):
  1. app.py +11 -39
  2. requirements.txt +5 -1
app.py CHANGED
@@ -13,13 +13,7 @@ import requests
 import json
 import torchvision
 import torch
-from tools.interact_tools import SamControler
-from tracker.base_tracker import BaseTracker
 from tools.painter import mask_painter
-try:
-    from mmcv.cnn import ConvModule
-except:
-    os.system("mim install mmcv")
 
 # download checkpoints
 def download_checkpoint(url, folder, filename):
@@ -206,7 +200,6 @@ def show_mask(video_state, interactive_state, mask_dropdown):
 
 # tracking vos
 def vos_tracking_video(video_state, interactive_state, mask_dropdown):
-
     model.xmem.clear_memory()
     if interactive_state["track_end_number"]:
         following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
@@ -226,8 +219,6 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
     template_mask = video_state["masks"][video_state["select_frame_number"]]
     fps = video_state["fps"]
     masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
-    # clear GPU memory
-    model.xmem.clear_memory()
 
     if interactive_state["track_end_number"]:
         video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
@@ -267,7 +258,6 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
 
 # inpaint
 def inpaint_video(video_state, interactive_state, mask_dropdown):
-
     frames = np.asarray(video_state["origin_images"])
     fps = video_state["fps"]
     inpaint_masks = np.asarray(video_state["masks"])
@@ -314,44 +304,27 @@ def generate_video_from_frames(frames, output_path, fps=30):
     torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
     return output_path
 
-
-# args, defined in track_anything.py
-args = parse_augment()
-
 # check and download checkpoints if needed
-SAM_checkpoint_dict = {
-    'vit_h': "sam_vit_h_4b8939.pth",
-    'vit_l': "sam_vit_l_0b3195.pth",
-    "vit_b": "sam_vit_b_01ec64.pth"
-}
-SAM_checkpoint_url_dict = {
-    'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
-    'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
-    'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
-}
-sam_checkpoint = SAM_checkpoint_dict[args.sam_model_type]
-sam_checkpoint_url = SAM_checkpoint_url_dict[args.sam_model_type]
+SAM_checkpoint = "sam_vit_h_4b8939.pth"
+sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
 xmem_checkpoint = "XMem-s012.pth"
 xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
 e2fgvi_checkpoint = "E2FGVI-HQ-CVPR22.pth"
 e2fgvi_checkpoint_id = "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"
 
-
 folder = "./checkpoints"
-SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
+SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
 xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
 e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
-
+# args, defined in track_anything.py
+args = parse_augment()
+# args.port = 12315
+# args.device = "cuda:2"
+# args.mask_save = True
 
 # initialize sam, xmem, e2fgvi models
 model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args)
 
-
-title = """<p><h1 align="center">Track-Anything</h1></p>
-"""
-description = """<p>Gradio demo for Track Anything, a flexible and interactive tool for video object tracking, segmentation, and inpainting. To use it, simply upload your video, or click one of the examples to load them. Code: <a href="https://github.com/gaomingqi/Track-Anything">https://github.com/gaomingqi/Track-Anything</a> <a href="https://huggingface.co/spaces/watchtowerss/Track-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
-
-
 with gr.Blocks() as iface:
     """
    state for
@@ -383,8 +356,7 @@ with gr.Blocks() as iface:
             "fps": 30
         }
     )
-    gr.Markdown(title)
-    gr.Markdown(description)
+
     with gr.Row():
 
         # for user video input
@@ -393,7 +365,7 @@ with gr.Blocks() as iface:
                 video_input = gr.Video(autosize=True)
                 with gr.Column():
                     video_info = gr.Textbox()
-                    resize_info = gr.Textbox(value="If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. \
+                    video_info = gr.Textbox(value="If you want to use the inpaint function, it is best to download and use a machine with more VRAM locally. \
                         Alternatively, you can use the resize ratio slider to scale down the original image to around 360P resolution for faster processing.")
                     resize_ratio_slider = gr.Slider(minimum=0.02, maximum=1, step=0.02, value=1, label="Resize ratio", visible=True)
 
@@ -562,7 +534,7 @@ with gr.Blocks() as iface:
         # cache_examples=True,
     )
 iface.queue(concurrency_count=1)
-iface.launch(debug=True, enable_queue=True)
+iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
 
 
 
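For reference, the checkpoint block above leans on the download_checkpoint helper defined near the top of app.py, whose body this commit does not touch. The following is a minimal sketch of what such a helper typically looks like, inferred only from its call sites (url, folder, filename); treat every detail as an assumption rather than the repository's actual code:

    # Hypothetical reconstruction of download_checkpoint, inferred from call sites.
    import os
    import requests

    def download_checkpoint(url, folder, filename):
        # Cache the checkpoint under folder/filename and return the local path.
        os.makedirs(folder, exist_ok=True)
        filepath = os.path.join(folder, filename)
        if not os.path.exists(filepath):
            # Stream to disk so multi-gigabyte weights never sit fully in memory.
            response = requests.get(url, stream=True)
            response.raise_for_status()
            with open(filepath, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        return filepath

download_checkpoint_from_google_drive presumably follows the same cache-then-return pattern for the E2FGVI weights. Note also that iface.launch(..., server_port=args.port) now relies on parse_augment() (defined in track_anything.py) supplying a port default, since the args.port = 12315 override remains commented out.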
 
requirements.txt CHANGED
@@ -10,6 +10,10 @@ gradio==3.25.0
 opencv-python
 pycocotools
 matplotlib
+onnxruntime
+onnx
+metaseg==0.6.1
 pyyaml
 av
-openmim
+mmcv-full
+mmengine
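Dropping openmim pairs with the app.py hunk that deletes the runtime "mim install mmcv" fallback: the mmcv dependency is now declared up front as mmcv-full plus mmengine instead of being installed on first import failure. As a minimal, hypothetical illustration (assuming exactly the import names of the packages pinned here), a fail-fast startup check could replace that old try/except pattern:

    # Hypothetical import-time sanity check; package names mirror requirements.txt.
    import importlib

    for pkg in ("mmcv", "mmengine", "onnxruntime", "onnx", "metaseg"):
        try:
            importlib.import_module(pkg)
        except ImportError as exc:
            raise SystemExit(
                f"Missing dependency '{pkg}'; run: pip install -r requirements.txt"
            ) from exc

This keeps environment setup declarative (one pip install -r requirements.txt) rather than mutating the environment at runtime via os.system.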