File size: 9,652 Bytes
82ee3e2 4785a21 82ee3e2 4785a21 82ee3e2 adab5b0 4785a21 82ee3e2 adab5b0 82ee3e2 dd74532 82ee3e2 080fa88 82ee3e2 0f726c9 82ee3e2 080fa88 82ee3e2 080fa88 82ee3e2 080fa88 82ee3e2 dd74532 82ee3e2 adab5b0 82ee3e2 adab5b0 82ee3e2 adab5b0 82ee3e2 ec72b00 adab5b0 ec72b00 adab5b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 |
import csv
import datetime
import os
import warnings
# import io
# import cProfile
# import pstats
# from pstats import SortKey
from typing import Optional, Tuple

import cv2
import numpy as np
import torch
from memory_profiler import profile
from tqdm import tqdm
from torchvision.ops import box_convert

from GroundingDINO.groundingdino.util.inference import load_model, load_image, annotate, preprocess_caption
from GroundingDINO.groundingdino.util.utils import get_phrases_from_posmap
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from video_utils import mp4_to_png, frame_to_timestamp, vid_stitcher
# Suppress every Python warning for the whole process as soon as this module is imported.
# NOTE(review): presumably meant to hide noisy library warnings from the model
# packages; this also hides genuinely useful ones — consider narrowing the filter
# (by category or module) if warnings need to be debugged.
warnings.filterwarnings("ignore")
def prepare_image(image, transform, device):
    """Resize a channel-last image with ``transform`` and return it as a CxHxW tensor.

    Args:
        image: HxWxC image array (channel-last, as implied by the permute below).
        transform: object exposing ``apply_image`` (e.g. SAM's ``ResizeLongestSide``).
        device: where to place the tensor. Either a ``torch.device`` / device
            string, or any object exposing a ``.device`` attribute (such as a
            SAM model or a tensor), in which case that attribute is used.

    Returns:
        A contiguous tensor of shape (C, H', W') on the resolved device.
    """
    resized = transform.apply_image(image)
    # Callers historically passed the SAM model itself and relied on its `.device`;
    # plain devices/strings have no `.device` attribute and are used directly.
    target_device = getattr(device, "device", device)
    tensor = torch.as_tensor(resized, device=target_device)
    return tensor.permute(2, 0, 1).contiguous()
# @profile
def _select_frame_paths(frames_dir: str, stride: int) -> list[str]:
    """Return the path of every ``stride``-th file in ``frames_dir`` (sorted by name; stride >= 1)."""
    # os.listdir() order is arbitrary, so sort before applying the stride.
    # NOTE(review): assumes mp4_to_png zero-pads frame numbers so name order is
    # frame order (the CSV code parses an 8-char numeric suffix) — TODO confirm.
    frame_filenames = sorted(os.listdir(frames_dir))
    return [os.path.join(frames_dir, name) for name in frame_filenames[::stride]]


def _annotate_batch_with_sam(sam, resize_transform, images, boxes_batch, out_paths, device, boxes_shown):
    """Prompt SAM with one batch of DINO boxes and write mask (and optionally box) overlays to ``out_paths``."""
    batched_input = []
    sam_boxes = []
    for image, box in zip(images, boxes_batch):
        height, width = image.shape[:2]
        # DINO boxes are normalized cxcywh; scale to pixels and convert to xyxy for SAM.
        box = box * torch.Tensor([width, height, width, height])
        box = box_convert(box, in_fmt="cxcywh", out_fmt="xyxy").to(device)
        sam_boxes.append(box)
        batched_input.append({
            "image": prepare_image(image, resize_transform, sam),
            "boxes": resize_transform.apply_boxes_torch(box, image.shape[:2]),
            "original_size": image.shape[:2],
        })
    batched_output = sam(batched_input, multimask_output=False)
    for image, box, prediction, out_path in zip(images, sam_boxes, batched_output, out_paths):
        mask = prediction["masks"].cpu().numpy()
        annotated_frame = plot_sam(image, mask, box.cpu().numpy(), boxes_shown=boxes_shown)
        cv2.imwrite(out_path, annotated_frame)


def _annotate_batch_with_boxes(images, boxes_batch, logits_batch, phrases_batch, out_paths):
    """Draw GroundingDINO boxes/labels with ``annotate`` and write each frame to ``out_paths``."""
    for image, boxes, logits, phrases, out_path in zip(images, boxes_batch, logits_batch, phrases_batch, out_paths):
        annotated_frame = annotate(image_source=image, boxes=boxes, logits=logits, phrases=phrases)
        cv2.imwrite(out_path, annotated_frame)


def sam_dino_vid(
    vid_path: str,
    text_prompt: str,
    box_threshold: float = 0.35,
    text_threshold: float = 0.25,
    fps_processed: int = 1,
    scaling_factor: float = 1.0,
    video_options: Optional[list[str]] = None,
    config_path: str = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    weights_path: str = "weights/groundingdino_swint_ogc.pth",
    device: str = 'cuda',
    batch_size: int = 5
) -> Tuple[str, str]:
    """Detect ``text_prompt`` in a video with GroundingDINO (and optionally SAM) and save results.

    Frames are extracted into ``../processed/<video name><timestamp>/frames``.
    Sampled frames are overwritten with their annotated versions, and the whole
    directory is stitched back into ``output.mp4``. One CSV row per sampled
    frame is written to ``detections.csv``.

    Args:
        vid_path: path to the input video.
        text_prompt: caption passed to GroundingDINO.
        box_threshold: a DINO query is kept when its best token score exceeds this.
        text_threshold: token-score cutoff used to build each detection's phrase.
        fps_processed: sampling stride. Every ``fps_processed``-th extracted
            frame (in name order) is processed; the remaining frames are left
            unannotated but still appear in the stitched video. Must be >= 1.
        scaling_factor: forwarded to ``mp4_to_png`` (presumably a frame resize
            factor — TODO confirm).
        video_options: overlays to draw. May contain "Bounding boxes" and/or
            "Masks". Defaults to ``["Bounding boxes"]``.
        config_path: GroundingDINO config file.
        weights_path: GroundingDINO checkpoint.
        device: torch device used for GroundingDINO and, if loaded, SAM.
        batch_size: number of sampled frames per inference batch.

    Returns:
        ``(csv_path, save_path)``: the detections CSV and the path returned by ``vid_stitcher``.

    Raises:
        ValueError: if ``fps_processed`` < 1.
    """
    if fps_processed < 1:
        raise ValueError(f"fps_processed must be >= 1, got {fps_processed}")
    if video_options is None:
        video_options = ["Bounding boxes"]
    boxes_needed = "Bounding boxes" in video_options
    masks_needed = "Masks" in video_options

    # SAM is only loaded when masks are requested.
    sam = None
    resize_transform = None
    if masks_needed:
        checkpoint = "weights/sam_vit_h_4b8939.pth"
        model_type = "vit_h"
        sam = sam_model_registry[model_type](checkpoint=checkpoint)
        sam.to(device=device)
        resize_transform = ResizeLongestSide(sam.image_encoder.img_size)

    # Create a fresh, timestamped results directory for this run.
    filename = os.path.splitext(os.path.basename(vid_path))[0]
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    results_dir = os.path.join("..", "processed", filename + timestamp)
    frames_dir = os.path.join(results_dir, "frames")
    os.makedirs(frames_dir, exist_ok=True)  # also creates results_dir
    csv_path = os.path.join(results_dir, "detections.csv")

    gd_model = load_model(config_path, weights_path, device=device)
    fps = mp4_to_png(vid_path, frames_dir, scaling_factor)
    frame_paths = _select_frame_paths(frames_dir, fps_processed)

    with open(csv_path, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["Frame", "Timestamp (hh:mm:ss)", "Boxes (cxcywh)", "# Boxes"])
        for start in tqdm(range(0, len(frame_paths), batch_size), desc="Running batches"):
            batch_paths = frame_paths[start:start + batch_size]
            # load_image yields (source image, transformed tensor); load each frame only once.
            loaded = [load_image(path) for path in batch_paths]
            images_orig = [source for source, _ in loaded]
            image_stack = torch.stack([transformed for _, transformed in loaded])
            boxes_i, logits_i, phrases_i = dino_predict_batch(
                model=gd_model,
                images=image_stack,
                caption=text_prompt,
                box_threshold=box_threshold,
                text_threshold=text_threshold,
                device=device,
            )
            # Annotated frames overwrite the extracted ones in frames_dir so vid_stitcher uses them.
            if masks_needed:
                # Swap the R and B channels before SAM/plotting/cv2.imwrite.
                # NOTE(review): if load_image returns RGB (GroundingDINO's loader presumably
                # does), this yields BGR — correct for cv2.imwrite, but SAM then receives
                # swapped channels. TODO confirm the intended color order for SAM.
                images_swapped = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in images_orig]
                _annotate_batch_with_sam(
                    sam, resize_transform, images_swapped, boxes_i, batch_paths, device, boxes_needed
                )
            elif boxes_needed:
                _annotate_batch_with_boxes(images_orig, boxes_i, logits_i, phrases_i, batch_paths)
            # TODO: convert boxes to SAM format for clearer understanding
            for frame_path, boxes in zip(batch_paths, boxes_i):
                frame = os.path.basename(frame_path).split(".")[0]
                # The last 8 characters of the frame name are parsed as the frame index.
                writer.writerow([frame, frame_to_timestamp(int(frame[-8:]), fps), boxes, len(boxes)])

    save_path = vid_stitcher(frames_dir, output_path=os.path.join(results_dir, "output.mp4"), fps=fps)
    print("Results saved to: " + save_path)
    return csv_path, save_path
def dino_predict_batch(
    model,
    images: torch.Tensor,
    caption: str,
    box_threshold: float,
    text_threshold: float,
    device: str = "cuda"
) -> Tuple[list[torch.Tensor], list[torch.Tensor], list[list[str]]]:
    """Run GroundingDINO on a batch of images that share one caption.

    Args:
        model: GroundingDINO model. It is moved to ``device`` and must expose ``.tokenizer``.
        images: stacked image batch with the batch on dim 0 (presumably the
            transformed tensors from ``load_image`` — TODO confirm).
        caption: text prompt, normalized with ``preprocess_caption`` before use.
        box_threshold: a query is kept when its highest token score exceeds this.
        text_threshold: token-score cutoff when building each kept query's phrase.
        device: device for the model and the batch.

    Returns:
        bboxes_batch: per image, a CPU tensor of shape (n, 4) with the kept boxes.
        predicts_batch: per image, a tensor of shape (n,) with each kept query's best token score.
        phrases_batch: per image, a list of n phrase strings with '.' removed.
    """
    prompt = preprocess_caption(caption=caption)
    model = model.to(device)
    batch_on_device = images.to(device)
    with torch.no_grad():
        raw = model(batch_on_device, captions=[prompt] * len(images))
    scores = raw["pred_logits"].cpu().sigmoid()  # (num_batch, nq, 256)
    coords = raw["pred_boxes"].cpu()  # (num_batch, nq, 4)
    keep = scores.max(dim=2)[0] > box_threshold  # (num_batch, nq)

    tokenizer = model.tokenizer
    tokenized = tokenizer(prompt)

    bboxes_batch, predicts_batch, phrases_batch = [], [], []
    for frame_scores, frame_coords, frame_keep in zip(scores, coords, keep):
        kept_scores = frame_scores[frame_keep]  # (n, 256)
        phrases_batch.append([
            get_phrases_from_posmap(token_scores > text_threshold, tokenized, tokenizer).replace('.', '')
            for token_scores in kept_scores  # one (256,) row per kept query
        ])
        bboxes_batch.append(frame_coords[frame_keep])  # (n, 4)
        predicts_batch.append(kept_scores.max(dim=1)[0])
    return bboxes_batch, predicts_batch, phrases_batch
def plot_sam(
    image: np.ndarray,
    masks: list[np.ndarray],
    boxes: np.ndarray,
    boxes_shown: bool = True,
    masks_shown: bool = True,
    box_color: Tuple[int, int, int] = (0, 0, 255),
    mask_color: Tuple[int, int, int] = (255, 144, 30),
    mask_alpha: float = 0.5,
) -> np.ndarray:
    """Draw boxes and/or colored mask overlays on a copy of ``image``.

    Args:
        image: HxWx3 uint8 image. It is not modified; drawing happens on a copy.
        masks: iterable of masks whose last two dims are (H, W), e.g. an array
            of shape (n, 1, H, W). Any nonzero pixel counts as foreground.
        boxes: iterable of (x0, y0, x1, y1) pixel-coordinate boxes.
        boxes_shown: draw a 2-px outline for each box.
        masks_shown: blend each mask onto the image.
        box_color: outline color; the default (0, 0, 255) is red in OpenCV's BGR order.
        mask_color: overlay color; the default (255, 144, 30) is blue in BGR order.
        mask_alpha: weight of the overlay in ``cv2.addWeighted``.

    Returns:
        A new annotated image array.
    """
    # cv2.rectangle draws in place; copy so the caller's array is left untouched.
    canvas = image.copy()
    if boxes_shown:
        for box in boxes:
            cv2.rectangle(canvas, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), box_color, 2)
    if masks_shown:
        color = np.asarray(mask_color, dtype=np.uint8).reshape(1, 1, -1)
        for mask in masks:
            h, w = mask.shape[-2:]
            # bool * uint8 -> uint8, so the overlay matches the canvas dtype whatever
            # the mask's dtype (cv2.addWeighted needs matching input types).
            overlay = mask.reshape(h, w, 1).astype(bool) * color
            canvas = cv2.addWeighted(canvas, 1, overlay, mask_alpha, 0)
    return canvas
# if __name__ == '__main__':
# def run_sam_dino_vid():
# sam_dino_vid("baboon_15s.mp4", "baboon", box_threshold=0.3, text_threshold=0.3, fps_processed=30, video_options=['Bounding boxes', 'Masks'])
# start_time = datetime.datetime.now()
# stats = run_sam_dino_vid()
# print("elapsed: " + str(datetime.datetime.now() - start_time))
|