# Primate-Detection-GPU / dino_sam.py
import datetime
import cv2
import os
import numpy as np
import torch
import csv
import warnings
from tqdm import tqdm
from torchvision.ops import box_convert
from typing import Tuple
from GroundingDINO.groundingdino.util.inference import load_model, load_image, annotate, preprocess_caption
from GroundingDINO.groundingdino.util.utils import get_phrases_from_posmap
from segment_anything import sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from video_utils import mp4_to_png, frame_to_timestamp, vid_stitcher
warnings.filterwarnings("ignore")
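# Pipeline: extract frames from a video, detect objects matching a text prompt with GroundingDINO,
# optionally refine the detections into masks with SAM, draw the results on the frames, and
# stitch them back into an annotated video alongside a CSV of per-frame detections.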
def prepare_image(image, transform, model):
    """Resize an HxWx3 uint8 image with SAM's transform and return a CHW tensor on the model's device."""
    image = transform.apply_image(image)
    image = torch.as_tensor(image, device=model.device)
    return image.permute(2, 0, 1).contiguous()
def sam_dino_vid(
vid_path: str,
text_prompt: str,
box_threshold: float = 0.35,
text_threshold: float = 0.25,
fps_processed: int = 1,
video_options: list[str] = ["Bounding boxes"],
config_path: str = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
weights_path: str = "groundingdino_swint_ogc.pth",
device: str = 'cuda',
batch_size: int = 10
) -> Tuple[str, str]:
    """
    Run GroundingDINO (and optionally SAM) on a video, writing annotated frames and a CSV of detections.

    Returns:
        (csv_path, save_path): path to the detections CSV and to the stitched, annotated output video.
    """
    boxes_needed = "Bounding boxes" in video_options
    masks_needed = "Masks" in video_options
    # if masks are selected, load the SAM model
    if masks_needed:
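        # SAM ViT-H weights are expected as a local checkpoint file in the working directory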
checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device=device)
resize_transform = ResizeLongestSide(sam.image_encoder.img_size)
# create new dirs and paths for results
filename = os.path.splitext(os.path.basename(vid_path))[0]
    results_dir = os.path.join("../processed", filename + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
    os.makedirs(results_dir)  # makedirs also creates ../processed if it does not exist yet
frames_dir = os.path.join(results_dir, "frames")
os.mkdir(frames_dir)
csv_path = os.path.join(results_dir, "detections.csv")
# load the groundingDINO model
gd_model = load_model(config_path, weights_path, device=device)
# process video and create a directory of video frames
fps = mp4_to_png(vid_path, frames_dir)
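    # mp4_to_png (video_utils) extracts every frame of the video into frames_dir and returns the source FPS,
    # which is reused below for timestamps and for stitching the output video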
# get the frame paths for the images to process
    frame_filenames = sorted(os.listdir(frames_dir))  # sort so frame order matches the video
    frame_paths = []  # frames to run detection on (every `fps_processed`-th frame)
    other_paths = []  # remaining frames, kept only so the full video can be stitched back together
    for i, frame in enumerate(frame_filenames):
        if i % fps_processed == 0:
            frame_paths.append(os.path.join(frames_dir, frame))
        else:
            other_paths.append(os.path.join(frames_dir, frame))
# TODO: rename vars to be more clear
# run dino_predict_batch and sam_predict_batch in batches of frames
# write the results to a csv
with open(csv_path, 'w', newline='') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(["Frame", "Timestamp (hh:mm:ss)", "Boxes (cxcywh)", "# Boxes"])
# run groundingDINO in batches
for i in tqdm(range(0, len(frame_paths), batch_size), desc="Running batches"):
batch_paths = frame_paths[i:i+batch_size] # paths for this batch
            # load each frame once; load_image returns (image as RGB ndarray, preprocessed tensor)
            loaded = [load_image(img) for img in batch_paths]
            images_orig = [img for img, _ in loaded]
            image_stack = torch.stack([tensor for _, tensor in loaded])
boxes_i, logits_i, phrases_i = dino_predict_batch(
model=gd_model,
images=image_stack,
caption=text_prompt,
box_threshold=box_threshold,
text_threshold=text_threshold
)
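            # annotated frames are written back over the extracted frames so vid_stitcher can pick them up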
annotated_frame_paths = [os.path.join(frames_dir, os.path.basename(frame_path)) for frame_path in batch_paths]
            # load_image returns RGB; swap channels to BGR so OpenCV drawing and imwrite behave as expected
            images_orig = [cv2.cvtColor(image, cv2.COLOR_RGB2BGR) for image in images_orig]
if masks_needed:
# run SAM in batches on boxes from dino
batched_input = []
sam_boxes = []
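                # each entry follows SAM's batched-inference input format:
                # {"image": CHW tensor on the model device, "boxes": transformed prompt boxes, "original_size": (H, W)}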
for image, box in zip(images_orig, boxes_i):
height, width = image.shape[:2]
                    # scale GroundingDINO's normalized cxcywh boxes to absolute pixels and convert to xyxy for SAM
                    box = box * torch.Tensor([width, height, width, height])
                    box = box_convert(box, in_fmt="cxcywh", out_fmt="xyxy").to(device)
sam_boxes.append(box)
batched_input.append({
"image": prepare_image(image, resize_transform, sam),
"boxes": resize_transform.apply_boxes_torch(box, image.shape[:2]),
"original_size": image.shape[:2]
})
                batched_output = sam(batched_input, multimask_output=False)
                for j, prediction in enumerate(batched_output):
                    # draw masks (and boxes, if requested) and write the annotated frame back to disk for stitching
                    mask = prediction["masks"].cpu().numpy()
                    box = sam_boxes[j].cpu().numpy()
                    annotated_frame = plot_sam(images_orig[j], mask, box, boxes_shown=boxes_needed)
                    cv2.imwrite(annotated_frame_paths[j], annotated_frame)
            elif boxes_needed and not masks_needed:
                # draw GroundingDINO boxes only and write the annotated frame back to disk for stitching
                for j, (image, box, logit, phrase) in enumerate(zip(images_orig, boxes_i, logits_i, phrases_i)):
                    annotated_frame = annotate(image_source=image, boxes=box, logits=logit, phrases=phrase)
                    cv2.imwrite(annotated_frame_paths[j], annotated_frame)
            # write this batch's detections to the CSV
            # TODO: convert boxes to SAM format for clearer understanding
            frame_names = [os.path.basename(frame_path).split(".")[0] for frame_path in batch_paths]
            for j, frame in enumerate(frame_names):
                writer.writerow([frame, frame_to_timestamp(int(frame[-8:]), fps), boxes_i[j], len(boxes_i[j])])
# stitch the frames
save_path = vid_stitcher(frames_dir, output_path=os.path.join(results_dir, "output.mp4"), fps=fps)
print("Results saved to: " + save_path)
return csv_path, save_path
def dino_predict_batch(
model,
images: torch.Tensor,
caption: str,
box_threshold: float,
text_threshold: float,
device: str = "cuda"
) -> Tuple[list[torch.Tensor], list[torch.Tensor], list[list[str]]]:
    '''
    Batched version of GroundingDINO inference: run one forward pass over a stack of images.

    Returns:
        bboxes_batch: list of tensors of shape (n, 4), normalized cxcywh boxes per image
        predicts_batch: list of tensors of shape (n,), confidence score per box
        phrases_batch: list of lists of strings, matched phrase per box
    '''
caption = preprocess_caption(caption=caption)
model = model.to(device)
image = images.to(device)
with torch.no_grad():
outputs = model(image, captions=[caption for _ in range(len(images))])
prediction_logits = outputs["pred_logits"].cpu().sigmoid() # prediction_logits.shape = (num_batch, nq, 256)
prediction_boxes = outputs["pred_boxes"].cpu() # prediction_boxes.shape = (num_batch, nq, 4)
    mask = prediction_logits.max(dim=2)[0] > box_threshold  # mask.shape = (num_batch, nq): True where a query's best token score exceeds box_threshold
bboxes_batch = []
predicts_batch = []
phrases_batch = [] # list of lists
tokenizer = model.tokenizer
tokenized = tokenizer(caption)
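    # the tokenized caption lets get_phrases_from_posmap map per-token logits back to the words they matched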
for i in range(prediction_logits.shape[0]):
        logits = prediction_logits[i][mask[i]]  # logits.shape = (n, 256): token scores for each kept query
        phrases = [
            get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
            for logit in logits  # each logit is a (256,) token-score vector for one kept box
        ]
boxes = prediction_boxes[i][mask[i]] # boxes.shape = (n, 4)
phrases_batch.append(phrases)
bboxes_batch.append(boxes)
predicts_batch.append(logits.max(dim=1)[0])
return bboxes_batch, predicts_batch, phrases_batch
def plot_sam(
    image: np.ndarray,
    masks: np.ndarray,
    boxes: np.ndarray,
    boxes_shown: bool = True,
    masks_shown: bool = True,
) -> np.ndarray:
    """
    Overlay SAM masks and/or bounding boxes on a BGR image and return the annotated result.
    """
# Use cv2 to plot the boxes and masks if they exist
if boxes_shown:
for box in boxes:
# red bbox
cv2.rectangle(image, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 255), 2)
    if masks_shown:
        # overlay color in BGR (light blue, since frames are in OpenCV's BGR order at this point)
        color = np.array([255, 144, 30], dtype=np.uint8)
        for mask in masks:
            # broadcast the binary mask into a colored HxWx3 overlay and alpha-blend it onto the image
            h, w = mask.shape[-2:]
            mask = (mask.reshape(h, w, 1) * color.reshape(1, 1, -1)).astype(np.uint8)
            image = cv2.addWeighted(image, 1, mask, 0.5, 0)
return image
if __name__ == '__main__':
start_time = datetime.datetime.now()
sam_dino_vid("baboon_15s.mp4", "baboon", box_threshold=0.3, text_threshold=0.3, fps_processed=30, video_options=['Bounding boxes', 'Masks'])
print("elapsed: " + str(datetime.datetime.now() - start_time))