updated inference, testing batch size
- app.py (+1, -9)
- dino_sam.py (+10, -7)
app.py CHANGED

@@ -7,8 +7,6 @@ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:5000'
 subprocess.run(['pip', 'install', '-e', 'GroundingDINO'])
 sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
 sys.path.append(os.path.join(os.getcwd(), "segment_anything"))
-# os.system("wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth")
-# os.system("wget https://huggingface.co/spaces/mrtlive/segment-anything-model/resolve/main/sam_vit_h_4b8939.pth")

 import gradio as gr
 from dino_sam import sam_dino_vid
@@ -43,12 +41,6 @@ with gr.Blocks() as demo:
         """
     )

-    gr.HTML(
-        """
-        <p="left">
-        The csv contains frame numbers and timestamps, bounding box coordinates, and number of detections per frame.</p>
-        """
-    )
     with gr.Row():
         with gr.Column():
             input = gr.Video(label="Input Video", interactive=True)
@@ -74,7 +66,7 @@ with gr.Blocks() as demo:
                                   step=1)
             video_options = gr.CheckboxGroup(choices=["Bounding boxes", "Masks"],
                                              label="Video Output Options",
-                                             info="Select the options to display in the output video.",
+                                             info="Select the options to display in the output video. Note: if masks are selected, runtime will increase.",
                                              value=["Bounding boxes"],
                                              interactive=True)

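The selected checkbox labels are passed straight through to sam_dino_vid as video_options; inside dino_sam.py they gate the masks_needed / boxes_needed branches. A minimal sketch of that mapping, assuming the flags are derived by simple membership tests (only the names video_options, masks_needed, and boxes_needed appear in this diff):

    # Hypothetical derivation of the flags dino_sam.py branches on.
    video_options = ["Bounding boxes", "Masks"]   # value from the CheckboxGroup
    boxes_needed = "Bounding boxes" in video_options
    masks_needed = "Masks" in video_options       # SAM only runs when True,
                                                  # hence the runtime warning in the UI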
dino_sam.py CHANGED

@@ -8,7 +8,7 @@ import torch
 import csv
 # import pstats
 import warnings
-
+from memory_profiler import profile
 # from pstats import SortKey
 from tqdm import tqdm
 from torchvision.ops import box_convert
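memory_profiler is imported so sam_dino_vid can be decorated with @profile for a line-by-line memory report. For reference, a sketch of the standard usage (placeholder function; the decorator and module invocation are the documented memory_profiler API):

    from memory_profiler import profile

    @profile  # prints per-line memory increments when the function runs
    def build_batch():
        frames = [bytearray(10 ** 6) for _ in range(100)]  # ~100 MB dummy workload
        return len(frames)

    if __name__ == '__main__':
        build_batch()

Running the script normally is enough once the decorator is active; python -m memory_profiler script.py works as well.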
@@ -26,6 +26,7 @@ def prepare_image(image, transform, device):
     image = torch.as_tensor(image, device=device.device)
     return image.permute(2, 0, 1).contiguous()

+# @profile
 def sam_dino_vid(
     vid_path: str,
     text_prompt: str,
@@ -36,7 +37,7 @@ def sam_dino_vid(
     config_path: str = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
     weights_path: str = "weights/groundingdino_swint_ogc.pth",
     device: str = 'cuda',
-    batch_size: int =
+    batch_size: int = 10
 ) -> (str, str):
     """ Args:
        Returns:
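batch_size now defaults to 10 frames per pass, matching the commit's stated goal of testing batch sizes. A sketch of the chunking this parameter implies, under the assumption that frame paths are sliced into fixed-size groups (the helper below is hypothetical; the diff only shows batch_paths being consumed downstream):

    def chunk_paths(frame_paths, batch_size=10):
        # Hypothetical helper: yield successive fixed-size batches of frame paths.
        for i in range(0, len(frame_paths), batch_size):
            yield frame_paths[i:i + batch_size]

    # e.g. 95 extracted frames with batch_size=10 -> nine batches of 10, one of 5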
@@ -101,13 +102,13 @@ def sam_dino_vid(

         annotated_frame_paths = [os.path.join(frames_dir, os.path.basename(frame_path)) for frame_path in batch_paths]
         # convert images_orig to rgb from bgr
-
+        images_orig_rgb = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in images_orig]

         if masks_needed:
             # run SAM in batches on boxes from dino
             batched_input = []
             sam_boxes = []
-            for image, box in zip(
+            for image, box in zip(images_orig_rgb, boxes_i):
                 height, width = image.shape[:2]
                 # convert the boxes from groundingDINO format to SAM format
                 box = box * torch.Tensor([width, height, width, height])
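GroundingDINO emits boxes as normalized (cx, cy, w, h) while SAM expects absolute-pixel (x1, y1, x2, y2), so the scaling by [width, height, width, height] above is the first half of the conversion. The second half presumably goes through the box_convert imported at the top of the file; a self-contained sketch:

    import torch
    from torchvision.ops import box_convert

    def dino_boxes_to_sam(boxes, height, width):
        # Normalized cxcywh -> pixel cxcywh -> pixel xyxy.
        boxes = boxes * torch.tensor([width, height, width, height])
        return box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy")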
@@ -123,7 +124,7 @@ def sam_dino_vid(
                 # write to annotated_frames_dir for stitching
                 mask = prediction["masks"].cpu().numpy()
                 box = sam_boxes[i].cpu().numpy()
-                annotated_frame = plot_sam(
+                annotated_frame = plot_sam(images_orig_rgb[i], mask, box, boxes_shown=boxes_needed)
                 cv2.imwrite(annotated_frame_paths[i], annotated_frame)

         elif boxes_needed and not masks_needed:
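The batched_input / prediction["masks"] pair matches segment_anything's batched forward pass: a list of per-image dicts in, a list of per-image dicts out. A sketch of one entry, assuming sam is the loaded model and transform is its ResizeLongestSide (both are constructed outside the hunks shown here):

    batched_input.append({
        'image': prepare_image(image, transform, sam),               # 3xHxW tensor on the model's device
        'boxes': transform.apply_boxes_torch(box, image.shape[:2]),  # xyxy boxes, resized to model input
        'original_size': image.shape[:2],
    })
    predictions = sam(batched_input, multimask_output=False)
    mask = predictions[i]['masks']  # one output dict per input image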
@@ -215,6 +216,8 @@ def plot_sam(
     return image

 # if __name__ == '__main__':
+#     def run_sam_dino_vid():
+#         sam_dino_vid("baboon_15s.mp4", "baboon", box_threshold=0.3, text_threshold=0.3, fps_processed=30, video_options=['Bounding boxes', 'Masks'])
 #     start_time = datetime.datetime.now()
-#
-#     print("elapsed: " + str(datetime.datetime.now() - start_time))
+#     stats = run_sam_dino_vid()
+#     print("elapsed: " + str(datetime.datetime.now() - start_time))
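When uncommented, the new block wraps the test invocation in a function so it can carry the @profile decorator, then times a full run. A runnable version of the same harness, using the exact values from the comments:

    import datetime

    # @profile  # uncomment together with the import above for the memory report
    def run_sam_dino_vid():
        return sam_dino_vid("baboon_15s.mp4", "baboon",
                            box_threshold=0.3, text_threshold=0.3,
                            fps_processed=30,
                            video_options=['Bounding boxes', 'Masks'])

    if __name__ == '__main__':
        start_time = datetime.datetime.now()
        run_sam_dino_vid()
        print("elapsed: " + str(datetime.datetime.now() - start_time))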