Spaces:
Runtime error
Runtime error
"""Run a demo of the CaR model on a single image.""" | |
import numpy as np | |
import os | |
import argparse | |
from functools import reduce | |
import PIL.Image as Image | |
import torch | |
from modeling.model import CaR | |
from utils.utils import Config, load_yaml | |
import matplotlib.pyplot as plt | |
import colorsys | |
from modeling.post_process.post_process import ( | |
match_masks, | |
generate_masks_from_sam, | |
) | |
from sam.sam import SAMPipeline | |
from sam.utils import build_sam_config | |
import random | |
import time | |
def generate_distinct_colors(n):
    """Return ``n`` fully saturated RGB colors with evenly spaced hues.

    One random offset is drawn per call and added to every hue, so the
    palette is rotated differently on each invocation while the hues remain
    evenly spread around the color wheel.

    Parameters:
    - n: number of colors to generate.
    Returns:
    - list of ``n`` tuples ``(r, g, b)`` with integer channels in [0, 255].
    """
    # A single shared offset keeps the spacing between hues intact.
    hue_offset = random.random()
    hues = [(float(step) / n + hue_offset) % 1.0 for step in range(n)]
    # hsv_to_rgb yields floats in [0, 1]; scale and truncate to 8-bit ints.
    return [
        tuple(
            int(channel * 255)
            for channel in colorsys.hsv_to_rgb(hue, 1.0, 1.0)
        )
        for hue in hues
    ]
def overlap_masks(masks):
    """Resolve overlaps between per-class masks for visualization.

    Each pixel is assigned to at most one class. When several masks cover
    the same pixel, the mask with the highest index wins because later
    masks overwrite earlier ones in the label map.

    Parameters:
    - masks: non-empty torch tensor of stacked masks, or a non-empty sequence
        of equally shaped torch tensors, one per class. Entries > 0 are
        treated as foreground.
    Returns:
    - clean_masks: boolean torch tensor with one mutually exclusive mask per
        input mask, stacked along dim 0.
    """
    reference = masks[0]
    # Build the label map with an explicit integer dtype. Deriving it from
    # the input dtype (zeros_like) breaks for bool masks: every label would
    # be stored as True, so class 0 would receive the union of all masks and
    # every other class would come out empty.
    label_map = torch.zeros(
        reference.shape, dtype=torch.long, device=reference.device
    )
    # Label 0 means "no class"; class i is stored as i + 1.
    for mask_idx, mask in enumerate(masks):
        label_map[mask > 0] = mask_idx + 1
    clean_masks = [
        label_map == mask_idx + 1 for mask_idx in range(len(masks))
    ]
    return torch.stack(clean_masks, dim=0)
def visualize_segmentation(
    image, masks, class_names, alpha=0.45, y_list=None, x_list=None
):
    """
    Visualize segmentation masks on an image.
    Parameters:
    - image: np.array of shape (H, W, 3) representing the RGB image
    - masks: np.array of shape (N, H, W) holding one binary mask per class
        (indexed as masks.shape[1] x masks.shape[2] for the canvas size).
    - class_names: list of strings representing names of each class
    - alpha: float, transparency level of masks on the image
    - y_list, x_list: optional per-class label coordinates. When either is
        None, each label is placed at the centroid of its mask.
    Returns:
    - visualization: plt.figure object
    """
    fig, ax = plt.subplots(1, figsize=(12, 9))

    # Color canvas that accumulates every class mask in its own color.
    final_mask = np.zeros(
        (masks.shape[1], masks.shape[2], 3), dtype=np.float32
    )
    colors = generate_distinct_colors(len(class_names))
    use_centroids = y_list is None or x_list is None

    for idx, (mask, color, class_name) in enumerate(
        zip(masks, colors, class_names)
    ):
        # Paint the mask with this class' color, channel by channel.
        final_mask += np.dstack([mask * c for c in color])

        if use_centroids:
            foreground = np.argwhere(mask)
            if foreground.size == 0:
                # An empty mask has no centroid; the mean of zero rows is
                # NaN, which would emit a warning and place the label at a
                # non-finite position. Skip the label instead.
                continue
            y, x = foreground.mean(axis=0)
        else:
            y, x = y_list[idx], x_list[idx]

        ax.text(
            x,
            y,
            class_name,
            color="white",
            fontsize=36,
            va="center",
            ha="center",
            bbox=dict(facecolor="black", alpha=0.7, edgecolor="none"),
        )

    # Alpha-blend the colored masks over the image. Clip before the uint8
    # cast so overlapping masks cannot push values out of range.
    final_image = image * (1 - alpha) + final_mask * alpha
    final_image = np.clip(final_image, 0, 255).astype(np.uint8)
    ax.imshow(final_image)
    # Remove axis ticks and labels
    ax.axis("off")
    return fig
def get_sam_masks(config, image_path, masks, img_sam=None, pipeline=None):
    """Replace each non-empty mask with the union of its matched SAM masks.

    SAM masks are generated once for the image, then every mask in ``masks``
    whose sum is positive is matched against them via ``match_masks``. The
    "segmentation" entries of the matched items are OR-ed together and
    written back into ``masks`` at the same index.

    Parameters:
    - config: configuration object; reads ``config.car.iom_thres`` and
        ``config.sam.min_pred_threshold``.
    - image_path: path forwarded to ``generate_masks_from_sam``.
    - masks: torch tensor of per-class masks; modified in place.
    - img_sam: optional image forwarded to ``generate_masks_from_sam``.
    - pipeline: optional SAM pipeline forwarded to
        ``generate_masks_from_sam``.
    Returns:
    - masks: the same tensor object that was passed in, after the update.
    """
    print("generating sam masks online")
    sam_tensor, sam_records = generate_masks_from_sam(
        image_path,
        save_path="./",
        pipeline=pipeline,
        img_sam=img_sam,
        visualize=False,
    )
    sam_tensor = sam_tensor.to(masks.device)

    # Masks that are entirely zero are left untouched.
    active_ids = [
        mask_id for mask_id, mask in enumerate(masks) if torch.sum(mask) > 0
    ]
    attn_maps = [masks[mask_id].unsqueeze(0) for mask_id in active_ids]

    # Run all matching before writing anything back into ``masks``.
    matches = []
    for attn in attn_maps:
        matches.append(
            match_masks(
                sam_tensor,
                attn,
                sam_records,
                iom_thres=config.car.iom_thres,
                min_pred_threshold=config.sam.min_pred_threshold,
            )
        )

    for mask_id, matched in zip(active_ids, matches):
        # NOTE(review): if ``matched`` can be empty, np.any returns a scalar
        # and torch.from_numpy would fail -- confirm against match_masks.
        segmentations = np.array(
            [record["segmentation"] for record in matched]
        )
        merged = np.any(segmentations, axis=0)
        masks[mask_id] = torch.from_numpy(merged).to(masks.device)
    return masks
def load_sam(config, sam_device):
    """Build a ``SAMPipeline`` from the configuration.

    Parameters:
    - config: configuration object; passed to ``build_sam_config`` and read
        for the ``config.sam`` mask-generation settings below.
    - sam_device: device the pipeline is created on.
    Returns:
    - the constructed ``SAMPipeline`` instance.
    """
    sam_checkpoint, model_type = build_sam_config(config)
    # Mask-generation settings forwarded unchanged from config.sam.
    generator_options = {
        "points_per_side": config.sam.points_per_side,
        "pred_iou_thresh": config.sam.pred_iou_thresh,
        "stability_score_thresh": config.sam.stability_score_thresh,
        "box_nms_thresh": config.sam.box_nms_thresh,
    }
    return SAMPipeline(
        sam_checkpoint,
        model_type,
        device=sam_device,
        **generator_options,
    )
if __name__ == "__main__":
    # Demo entry point: segment one image with CaR, optionally refine the
    # masks with SAM, then save an overlay figure as demo.png.
    parser = argparse.ArgumentParser("CaR")
    parser.add_argument(
        "--output_path", type=str, default="", help="path to save outputs"
    )
    parser.add_argument(
        "--cfg-path",
        default="configs/voc_test.yaml",
        help="path to configuration file.",
    )
    args = parser.parse_args()
    cfg = Config(**load_yaml(args.cfg_path))
    device = "cuda" if torch.cuda.is_available() else "cpu"

    label_space = cfg.image_caption
    # Folder name: captions joined by "_" with every space replaced, capped
    # at 20 characters. A plain join/replace also handles the last (or only)
    # caption and an empty caption list.
    folder_name = "_".join(label_space).replace(" ", "_")[:20]

    car_model = CaR(
        cfg, visualize=True, seg_mode=cfg.test.seg_mode, device=device
    )
    img = Image.open(cfg.image_path).convert("RGB")
    # Downscale by a factor of 3 when the image is wider than 1000 pixels.
    if img.size[0] > 1000:
        img = img.resize((img.size[0] // 3, img.size[1] // 3))

    pseudo_masks, scores, _ = car_model(img, label_space)
    if not cfg.test.use_pseudo:
        # SAM is only needed for refinement, so load it lazily and keep the
        # load outside the timed section.
        sam_pipeline = load_sam(cfg, device)
        t1 = time.time()
        pseudo_masks = get_sam_masks(
            cfg,
            cfg.image_path,
            pseudo_masks,
            img_sam=np.array(img),
            pipeline=sam_pipeline,
        )
        pseudo_masks = overlap_masks(pseudo_masks)
        t2 = time.time()
        print(f"sam time: {t2 - t1}")

    # visualize segmentation masks
    demo_fig = visualize_segmentation(
        np.array(img),
        pseudo_masks.detach().cpu().numpy(),
        label_space,
    )
    # Use --output_path when given; otherwise fall back to "vis_results".
    output_root = args.output_path or "vis_results"
    save_path = os.path.join(output_root, folder_name)
    os.makedirs(save_path, exist_ok=True)
    demo_fig.savefig(os.path.join(save_path, "demo.png"), bbox_inches="tight")
    print(f"results saved to {save_path}.")