wondervictor commited on
Commit
95bdec8
·
1 Parent(s): 5e769e6

update sam2

Browse files
Files changed (2) hide show
  1. app.py +110 -38
  2. app_video.py → app.py.bak +38 -110
app.py CHANGED
@@ -7,15 +7,18 @@ import timm
7
  print("installed", timm.__version__)
8
  import gradio as gr
9
  from inference import sam_preprocess, beit3_preprocess
10
- from model.evf_sam import EvfSamModel
 
11
  from transformers import AutoTokenizer
12
  import torch
 
13
  import numpy as np
14
  import sys
15
  import os
 
16
 
17
- version = "YxZhang/evf-sam"
18
- model_type = "ori"
19
 
20
  tokenizer = AutoTokenizer.from_pretrained(
21
  version,
@@ -26,27 +29,40 @@ tokenizer = AutoTokenizer.from_pretrained(
26
  kwargs = {
27
  "torch_dtype": torch.half,
28
  }
29
- model = EvfSamModel.from_pretrained(version, low_cpu_mem_usage=True,
30
- **kwargs).eval()
31
- model.to('cuda')
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
 
34
  @spaces.GPU
35
  @torch.no_grad()
36
- def pred(image_np, prompt):
37
  original_size_list = [image_np.shape[:2]]
38
 
39
- image_beit = beit3_preprocess(image_np, 224).to(dtype=model.dtype,
40
- device=model.device)
41
 
42
  image_sam, resize_shape = sam_preprocess(image_np, model_type=model_type)
43
- image_sam = image_sam.to(dtype=model.dtype, device=model.device)
 
44
 
45
  input_ids = tokenizer(
46
- prompt, return_tensors="pt")["input_ids"].to(device=model.device)
47
 
48
  # infer
49
- pred_mask = model.inference(
50
  image_sam.unsqueeze(0),
51
  image_beit.unsqueeze(0),
52
  input_ids,
@@ -61,7 +77,50 @@ def pred(image_np, prompt):
61
  pred_mask[:, :, None].astype(np.uint8) *
62
  np.array([50, 120, 220]) * 0.5)[pred_mask]
63
 
64
- return visualization / 255.0, pred_mask.astype(np.float16)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
 
67
  desc = """
@@ -73,28 +132,41 @@ desc = """
73
  # desc_title_str = '<div align ="center"><img src="assets/logo.jpg" width="20%"><h3> Early Vision-Language Fusion for Text-Prompted Segment Anything Model</h3></div>'
74
  # desc_link_str = '[![arxiv paper](https://img.shields.io/badge/arXiv-Paper-red)](https://arxiv.org/abs/2406.20076)'
75
 
76
- demo = gr.Interface(
77
- fn=pred,
78
- inputs=[
79
- gr.components.Image(type="numpy", label="Image", image_mode="RGB"),
80
- gr.components.Textbox(
81
- label="Prompt",
82
- info=
83
- "Use a phrase or sentence to describe the object you want to segment. Currently we only support English"
84
- )
85
- ],
86
- outputs=[
87
- gr.components.Image(type="numpy", label="visulization"),
88
- gr.components.Image(type="numpy", label="mask")
89
- ],
90
- examples=[["assets/zebra.jpg", "zebra top left"],
91
- ["assets/bus.jpg", "bus going to south common"],
92
- [
93
- "assets/carrots.jpg",
94
- "3carrots in center with ice and greenn leaves"
95
- ]],
96
- title="📷 EVF-SAM: Referring Expression Segmentation",
97
- description=desc,
98
- allow_flagging="never")
99
- # demo.launch()
100
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  print("installed", timm.__version__)
8
  import gradio as gr
9
  from inference import sam_preprocess, beit3_preprocess
10
+ from model.evf_sam2 import EvfSam2Model
11
+ from model.evf_sam2_video import EvfSam2Model as EvfSam2VideoModel
12
  from transformers import AutoTokenizer
13
  import torch
14
+ import cv2
15
  import numpy as np
16
  import sys
17
  import os
18
+ import tqdm
19
 
20
+ version = "YxZhang/evf-sam2"
21
+ model_type = "sam2"
22
 
23
  tokenizer = AutoTokenizer.from_pretrained(
24
  version,
 
29
  kwargs = {
30
  "torch_dtype": torch.half,
31
  }
32
+
33
+ image_model = EvfSam2Model.from_pretrained(version,
34
+ low_cpu_mem_usage=True,
35
+ **kwargs)
36
+ del image_model.visual_model.memory_encoder
37
+ del image_model.visual_model.memory_attention
38
+ image_model = image_model.eval()
39
+ image_model.to('cuda')
40
+
41
+ video_model = EvfSam2VideoModel.from_pretrained(version,
42
+ low_cpu_mem_usage=True,
43
+ **kwargs)
44
+ video_model = video_model.eval()
45
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
46
+ video_model.to('cuda')
47
 
48
 
49
  @spaces.GPU
50
  @torch.no_grad()
51
+ def inference_image(image_np, prompt):
52
  original_size_list = [image_np.shape[:2]]
53
 
54
+ image_beit = beit3_preprocess(image_np, 224).to(dtype=image_model.dtype,
55
+ device=image_model.device)
56
 
57
  image_sam, resize_shape = sam_preprocess(image_np, model_type=model_type)
58
+ image_sam = image_sam.to(dtype=image_model.dtype,
59
+ device=image_model.device)
60
 
61
  input_ids = tokenizer(
62
+ prompt, return_tensors="pt")["input_ids"].to(device=image_model.device)
63
 
64
  # infer
65
+ pred_mask = image_model.inference(
66
  image_sam.unsqueeze(0),
67
  image_beit.unsqueeze(0),
68
  input_ids,
 
77
  pred_mask[:, :, None].astype(np.uint8) *
78
  np.array([50, 120, 220]) * 0.5)[pred_mask]
79
 
80
+ return visualization / 255.0
81
+
82
+
83
+ @spaces.GPU
84
+ @torch.no_grad()
85
+ @torch.autocast(device_type="cuda", dtype=torch.float16)
86
+ def inference_video(video_path, prompt):
87
+
88
+ os.system("rm -rf demo_temp")
89
+ os.makedirs("demo_temp/input_frames", exist_ok=True)
90
+ os.system(
91
+ "ffmpeg -i {} -q:v 2 -start_number 0 demo_temp/input_frames/'%05d.jpg'"
92
+ .format(video_path))
93
+ input_frames = sorted(os.listdir("demo_temp/input_frames"))
94
+ image_np = cv2.imread("demo_temp/input_frames/00000.jpg")
95
+ image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
96
+
97
+ height, width, channels = image_np.shape
98
+
99
+ image_beit = beit3_preprocess(image_np, 224).to(dtype=video_model.dtype,
100
+ device=video_model.device)
101
+
102
+ input_ids = tokenizer(
103
+ prompt, return_tensors="pt")["input_ids"].to(device=video_model.device)
104
+
105
+ # infer
106
+ output = video_model.inference(
107
+ "demo_temp/input_frames",
108
+ image_beit.unsqueeze(0),
109
+ input_ids,
110
+ )
111
+ # save visualization
112
+ video_writer = cv2.VideoWriter("demo_temp/out.mp4", fourcc, 30,
113
+ (width, height))
114
+ pbar = tqdm(input_frames)
115
+ pbar.set_description("generating video: ")
116
+ for i, file in enumerate(pbar):
117
+ img = cv2.imread(os.path.join("demo_temp/input_frames", file))
118
+ vis = img + np.array([0, 0, 128]) * output[i][1].transpose(1, 2, 0)
119
+ vis = np.clip(vis, 0, 255)
120
+ vis = np.uint8(vis)
121
+ video_writer.write(vis)
122
+ video_writer.release()
123
+ return "demo_temp/out.mp4"
124
 
125
 
126
  desc = """
 
132
  # desc_title_str = '<div align ="center"><img src="assets/logo.jpg" width="20%"><h3> Early Vision-Language Fusion for Text-Prompted Segment Anything Model</h3></div>'
133
  # desc_link_str = '[![arxiv paper](https://img.shields.io/badge/arXiv-Paper-red)](https://arxiv.org/abs/2406.20076)'
134
 
135
+ with gr.Blocks() as demo:
136
+ gr.Markdown(desc)
137
+ with gr.Tab(label="EVF-SAM-2-Image"):
138
+ with gr.Row():
139
+ input_image = gr.Image(type='numpy',
140
+ label='Input Image',
141
+ image_mode='RGB')
142
+ output_image = gr.Image(type='numpy', label='Output Image')
143
+ with gr.Row():
144
+ image_prompt = gr.Textbox(
145
+ label="Prompt",
146
+ info=
147
+ "Use a phrase or sentence to describe the object you want to segment. Currently we only support English"
148
+ )
149
+ submit_image = gr.Button(value='Submit',
150
+ scale=1,
151
+ variant='primary')
152
+ with gr.Tab(label="EVF-SAM-2-Video"):
153
+ with gr.Row():
154
+ input_video = gr.Video(label='Input Video')
155
+ output_video = gr.Video(label='Output Video')
156
+ with gr.Row():
157
+ video_prompt = gr.Textbox(
158
+ label="Prompt",
159
+ info=
160
+ "Use a phrase or sentence to describe the object you want to segment. Currently we only support English"
161
+ )
162
+ submit_video = gr.Button(value='Submit',
163
+ scale=1,
164
+ variant='primary')
165
+
166
+ submit_image.click(fn=inference_image,
167
+ inputs=[input_image, image_prompt],
168
+ outputs=output_image)
169
+ submit_video.click(fn=inference_video,
170
+ inputs=[input_video, video_prompt],
171
+ outputs=output_video)
172
+ demo.launch(show_error=True)
app_video.py → app.py.bak RENAMED
@@ -7,18 +7,15 @@ import timm
7
  print("installed", timm.__version__)
8
  import gradio as gr
9
  from inference import sam_preprocess, beit3_preprocess
10
- from model.evf_sam2 import EvfSam2Model
11
- from model.evf_sam2_video import EvfSam2Model as EvfSam2VideoModel
12
  from transformers import AutoTokenizer
13
  import torch
14
- import cv2
15
  import numpy as np
16
  import sys
17
  import os
18
- import tqdm
19
 
20
- version = "YxZhang/evf-sam2"
21
- model_type = "sam2"
22
 
23
  tokenizer = AutoTokenizer.from_pretrained(
24
  version,
@@ -29,40 +26,27 @@ tokenizer = AutoTokenizer.from_pretrained(
29
  kwargs = {
30
  "torch_dtype": torch.half,
31
  }
32
-
33
- image_model = EvfSam2Model.from_pretrained(version,
34
- low_cpu_mem_usage=True,
35
- **kwargs)
36
- del image_model.visual_model.memory_encoder
37
- del image_model.visual_model.memory_attention
38
- image_model = image_model.eval()
39
- image_model.to('cuda')
40
-
41
- video_model = EvfSam2VideoModel.from_pretrained(version,
42
- low_cpu_mem_usage=True,
43
- **kwargs)
44
- video_model = video_model.eval()
45
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
46
- video_model.to('cuda')
47
 
48
 
49
  @spaces.GPU
50
  @torch.no_grad()
51
- def inference_image(image_np, prompt):
52
  original_size_list = [image_np.shape[:2]]
53
 
54
- image_beit = beit3_preprocess(image_np, 224).to(dtype=image_model.dtype,
55
- device=image_model.device)
56
 
57
  image_sam, resize_shape = sam_preprocess(image_np, model_type=model_type)
58
- image_sam = image_sam.to(dtype=image_model.dtype,
59
- device=image_model.device)
60
 
61
  input_ids = tokenizer(
62
- prompt, return_tensors="pt")["input_ids"].to(device=image_model.device)
63
 
64
  # infer
65
- pred_mask = image_model.inference(
66
  image_sam.unsqueeze(0),
67
  image_beit.unsqueeze(0),
68
  input_ids,
@@ -77,50 +61,7 @@ def inference_image(image_np, prompt):
77
  pred_mask[:, :, None].astype(np.uint8) *
78
  np.array([50, 120, 220]) * 0.5)[pred_mask]
79
 
80
- return visualization / 255.0
81
-
82
-
83
- @spaces.GPU
84
- @torch.no_grad()
85
- @torch.autocast(device_type="cuda", dtype=torch.float16)
86
- def inference_video(video_path, prompt):
87
-
88
- os.system("rm -rf demo_temp")
89
- os.makedirs("demo_temp/input_frames", exist_ok=True)
90
- os.system(
91
- "ffmpeg -i {} -q:v 2 -start_number 0 demo_temp/input_frames/'%05d.jpg'"
92
- .format(video_path))
93
- input_frames = sorted(os.listdir("demo_temp/input_frames"))
94
- image_np = cv2.imread("demo_temp/input_frames/00000.jpg")
95
- image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
96
-
97
- height, width, channels = image_np.shape
98
-
99
- image_beit = beit3_preprocess(image_np, 224).to(dtype=video_model.dtype,
100
- device=video_model.device)
101
-
102
- input_ids = tokenizer(
103
- prompt, return_tensors="pt")["input_ids"].to(device=video_model.device)
104
-
105
- # infer
106
- output = video_model.inference(
107
- "demo_temp/input_frames",
108
- image_beit.unsqueeze(0),
109
- input_ids,
110
- )
111
- # save visualization
112
- video_writer = cv2.VideoWriter("demo_temp/out.mp4", fourcc, 30,
113
- (width, height))
114
- pbar = tqdm(input_frames)
115
- pbar.set_description("generating video: ")
116
- for i, file in enumerate(pbar):
117
- img = cv2.imread(os.path.join("demo_temp/input_frames", file))
118
- vis = img + np.array([0, 0, 128]) * output[i][1].transpose(1, 2, 0)
119
- vis = np.clip(vis, 0, 255)
120
- vis = np.uint8(vis)
121
- video_writer.write(vis)
122
- video_writer.release()
123
- return "demo_temp/out.mp4"
124
 
125
 
126
  desc = """
@@ -132,41 +73,28 @@ desc = """
132
  # desc_title_str = '<div align ="center"><img src="assets/logo.jpg" width="20%"><h3> Early Vision-Language Fusion for Text-Prompted Segment Anything Model</h3></div>'
133
  # desc_link_str = '[![arxiv paper](https://img.shields.io/badge/arXiv-Paper-red)](https://arxiv.org/abs/2406.20076)'
134
 
135
- with gr.Blocks() as demo:
136
- gr.Markdown(desc)
137
- with gr.Tab(label="EVF-SAM-2-Image"):
138
- with gr.Row():
139
- input_image = gr.Image(type='numpy',
140
- label='Input Image',
141
- image_mode='RGB')
142
- output_image = gr.Image(type='numpy', label='Output Image')
143
- with gr.Row():
144
- image_prompt = gr.Textbox(
145
- label="Prompt",
146
- info=
147
- "Use a phrase or sentence to describe the object you want to segment. Currently we only support English"
148
- )
149
- submit_image = gr.Button(value='Submit',
150
- scale=1,
151
- variant='primary')
152
- with gr.Tab(label="EVF-SAM-2-Video"):
153
- with gr.Row():
154
- input_video = gr.Video(label='Input Video')
155
- output_video = gr.Video(label='Output Video')
156
- with gr.Row():
157
- video_prompt = gr.Textbox(
158
- label="Prompt",
159
- info=
160
- "Use a phrase or sentence to describe the object you want to segment. Currently we only support English"
161
- )
162
- submit_video = gr.Button(value='Submit',
163
- scale=1,
164
- variant='primary')
165
-
166
- submit_image.click(fn=inference_image,
167
- inputs=[input_image, image_prompt],
168
- outputs=output_image)
169
- submit_video.click(fn=inference_video,
170
- inputs=[input_video, video_prompt],
171
- outputs=output_video)
172
- demo.launch(show_error=True)
 
7
  print("installed", timm.__version__)
8
  import gradio as gr
9
  from inference import sam_preprocess, beit3_preprocess
10
+ from model.evf_sam import EvfSamModel
 
11
  from transformers import AutoTokenizer
12
  import torch
 
13
  import numpy as np
14
  import sys
15
  import os
 
16
 
17
+ version = "YxZhang/evf-sam"
18
+ model_type = "ori"
19
 
20
  tokenizer = AutoTokenizer.from_pretrained(
21
  version,
 
26
  kwargs = {
27
  "torch_dtype": torch.half,
28
  }
29
+ model = EvfSamModel.from_pretrained(version, low_cpu_mem_usage=True,
30
+ **kwargs).eval()
31
+ model.to('cuda')
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
 
34
  @spaces.GPU
35
  @torch.no_grad()
36
+ def pred(image_np, prompt):
37
  original_size_list = [image_np.shape[:2]]
38
 
39
+ image_beit = beit3_preprocess(image_np, 224).to(dtype=model.dtype,
40
+ device=model.device)
41
 
42
  image_sam, resize_shape = sam_preprocess(image_np, model_type=model_type)
43
+ image_sam = image_sam.to(dtype=model.dtype, device=model.device)
 
44
 
45
  input_ids = tokenizer(
46
+ prompt, return_tensors="pt")["input_ids"].to(device=model.device)
47
 
48
  # infer
49
+ pred_mask = model.inference(
50
  image_sam.unsqueeze(0),
51
  image_beit.unsqueeze(0),
52
  input_ids,
 
61
  pred_mask[:, :, None].astype(np.uint8) *
62
  np.array([50, 120, 220]) * 0.5)[pred_mask]
63
 
64
+ return visualization / 255.0, pred_mask.astype(np.float16)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
 
67
  desc = """
 
73
  # desc_title_str = '<div align ="center"><img src="assets/logo.jpg" width="20%"><h3> Early Vision-Language Fusion for Text-Prompted Segment Anything Model</h3></div>'
74
  # desc_link_str = '[![arxiv paper](https://img.shields.io/badge/arXiv-Paper-red)](https://arxiv.org/abs/2406.20076)'
75
 
76
+ demo = gr.Interface(
77
+ fn=pred,
78
+ inputs=[
79
+ gr.components.Image(type="numpy", label="Image", image_mode="RGB"),
80
+ gr.components.Textbox(
81
+ label="Prompt",
82
+ info=
83
+ "Use a phrase or sentence to describe the object you want to segment. Currently we only support English"
84
+ )
85
+ ],
86
+ outputs=[
87
+ gr.components.Image(type="numpy", label="visulization"),
88
+ gr.components.Image(type="numpy", label="mask")
89
+ ],
90
+ examples=[["assets/zebra.jpg", "zebra top left"],
91
+ ["assets/bus.jpg", "bus going to south common"],
92
+ [
93
+ "assets/carrots.jpg",
94
+ "3carrots in center with ice and greenn leaves"
95
+ ]],
96
+ title="📷 EVF-SAM: Referring Expression Segmentation",
97
+ description=desc,
98
+ allow_flagging="never")
99
+ # demo.launch()
100
+ demo.launch()