# Spaces: Runtime error
import cv2 | |
import numpy as np | |
import supervision as sv | |
import argparse | |
import torch | |
import torchvision | |
from groundingdino.util.inference import Model | |
from segment_anything import SamPredictor | |
from MobileSAM.setup_mobile_sam import setup_model | |
def parse_args():
    """Parse the command-line options of the Grounded-MobileSAM demo.

    Options are read from ``sys.argv``. Every option has a default, so the
    script can be run without arguments.

    Returns:
        argparse.Namespace: attributes ``MOBILE_SAM_CHECKPOINT_PATH``,
        ``SOURCE_IMAGE_PATH``, ``CAPTION``, ``OUT_FILE_BOX``,
        ``OUT_FILE_SEG``, ``OUT_FILE_BIN_MASK``, ``BOX_THRESHOLD``,
        ``TEXT_THRESHOLD``, ``NMS_THRESHOLD`` and ``DEVICE``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--MOBILE_SAM_CHECKPOINT_PATH", type=str, default="./EfficientSAM/mobile_sam.pt", help="model"
    )
    parser.add_argument(
        "--SOURCE_IMAGE_PATH", type=str, default="./assets/demo2.jpg", help="path to image file"
    )
    parser.add_argument(
        "--CAPTION", type=str, default="The running dog", help="text prompt for GroundingDINO"
    )
    parser.add_argument(
        "--OUT_FILE_BOX", type=str, default="groundingdino_annotated_image.jpg", help="the output filename"
    )
    parser.add_argument(
        "--OUT_FILE_SEG", type=str, default="grounded_mobile_sam_annotated_image.jpg", help="the output filename"
    )
    parser.add_argument(
        "--OUT_FILE_BIN_MASK", type=str, default="grounded_mobile_sam_bin_mask.jpg", help="the output filename"
    )
    parser.add_argument(
        "--BOX_THRESHOLD", type=float, default=0.25, help="box threshold passed to GroundingDINO"
    )
    parser.add_argument(
        "--TEXT_THRESHOLD", type=float, default=0.25, help="text threshold passed to GroundingDINO"
    )
    parser.add_argument(
        "--NMS_THRESHOLD", type=float, default=0.8, help="IoU threshold used for NMS on the detected boxes"
    )
    # Keep the default a plain string (not a torch.device object) so that
    # args.DEVICE has the same type whether or not the flag is given,
    # matching the declared type=str.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser.add_argument(
        "--DEVICE", type=str, default=default_device, help="cuda:[0,1,2,3,4] or cpu"
    )
    return parser.parse_args()
def _format_labels(detections: "sv.Detections", classes: list) -> list:
    """Return one "<class name> <confidence>" caption per detection."""
    # Read the attribute arrays directly instead of unpacking the tuples
    # yielded by iterating the Detections object, so the code does not depend
    # on how many fields that iterator yields.
    return [
        f"{classes[class_id]} {confidence:0.2f}"
        for confidence, class_id in zip(detections.confidence, detections.class_id)
    ]


def _segment_boxes(sam_predictor: "SamPredictor", image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    """Prompt SAM with each box and keep the highest-scoring mask per box."""
    sam_predictor.set_image(image)
    best_masks = []
    for box in xyxy:
        masks, scores, _ = sam_predictor.predict(box=box, multimask_output=True)
        # multimask_output=True yields several candidates; keep the best one.
        best_masks.append(masks[np.argmax(scores)])
    return np.array(best_masks)


def main(args):
    """Run GroundingDINO detection followed by MobileSAM segmentation.

    Reads ``args.SOURCE_IMAGE_PATH``, detects boxes for ``args.CAPTION``,
    applies NMS, segments every remaining box and writes three images:
    ``args.OUT_FILE_BOX`` (boxes only), ``args.OUT_FILE_BIN_MASK`` (binary
    mask of the first remaining detection) and ``args.OUT_FILE_SEG``
    (masks plus boxes).

    Args:
        args: namespace produced by ``parse_args``.

    Raises:
        FileNotFoundError: if the source image cannot be read.
    """
    device = args.DEVICE

    # GroundingDINO config and checkpoint
    grounding_dino_config_path = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
    grounding_dino_checkpoint_path = "./groundingdino_swint_ogc.pth"

    # Building GroundingDINO inference model on the requested device so that
    # --DEVICE is honoured by both models (str() accepts a torch.device too).
    grounding_dino_model = Model(
        model_config_path=grounding_dino_config_path,
        model_checkpoint_path=grounding_dino_checkpoint_path,
        device=str(device),
    )

    # Building MobileSAM predictor.
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    # map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only
    # host; the model is moved to the target device right after.
    checkpoint = torch.load(args.MOBILE_SAM_CHECKPOINT_PATH, map_location="cpu")
    mobile_sam = setup_model()
    mobile_sam.load_state_dict(checkpoint, strict=True)
    mobile_sam.to(device=device)
    sam_predictor = SamPredictor(mobile_sam)

    # Predict classes and hyper-param for GroundingDINO
    classes = [args.CAPTION]

    # load image (cv2.imread signals failure by returning None, not raising)
    image = cv2.imread(args.SOURCE_IMAGE_PATH)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {args.SOURCE_IMAGE_PATH}")

    # detect objects
    detections = grounding_dino_model.predict_with_classes(
        image=image,
        classes=classes,
        box_threshold=args.BOX_THRESHOLD,
        text_threshold=args.TEXT_THRESHOLD,
    )

    # annotate image with detections and save the grounding dino image
    box_annotator = sv.BoxAnnotator()
    annotated_frame = box_annotator.annotate(
        scene=image.copy(),
        detections=detections,
        labels=_format_labels(detections, classes),
    )
    cv2.imwrite(args.OUT_FILE_BOX, annotated_frame)

    # NMS post process
    print(f"Before NMS: {len(detections.xyxy)} boxes")
    nms_idx = torchvision.ops.nms(
        torch.from_numpy(detections.xyxy),
        torch.from_numpy(detections.confidence),
        args.NMS_THRESHOLD,
    ).numpy().tolist()
    detections.xyxy = detections.xyxy[nms_idx]
    detections.confidence = detections.confidence[nms_idx]
    detections.class_id = detections.class_id[nms_idx]
    print(f"After NMS: {len(detections.xyxy)} boxes")

    # Without any box there is no mask to index below; stop gracefully
    # instead of failing with an IndexError on detections.mask[0].
    if len(detections.xyxy) == 0:
        print("No detections; skipping segmentation outputs")
        return

    # convert detections to masks (SAM expects RGB, OpenCV loads BGR)
    detections.mask = _segment_boxes(
        sam_predictor=sam_predictor,
        image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        xyxy=detections.xyxy,
    )

    # Only the first remaining detection is exported as a binary mask.
    binary_mask = detections.mask[0].astype(np.uint8) * 255
    cv2.imwrite(args.OUT_FILE_BIN_MASK, binary_mask)

    # annotate image with detections (labels rebuilt: NMS changed the set)
    mask_annotator = sv.MaskAnnotator()
    annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
    annotated_image = box_annotator.annotate(
        scene=annotated_image,
        detections=detections,
        labels=_format_labels(detections, classes),
    )

    # save the annotated grounded-sam image
    cv2.imwrite(args.OUT_FILE_SEG, annotated_image)
if __name__ == "__main__":
    # Script entry point: parse the CLI options and run the pipeline.
    main(parse_args())