Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2024 The Google Research Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# Copyright 2024 The Google Research Authors. | |
# This file is based on the SAM (Segment Anything) and HQ-SAM. | |
# | |
# https://github.com/facebookresearch/segment-anything | |
# https://github.com/SysCV/sam-hq/tree/main | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""SAM Utilities.""" | |
# pylint: disable=all | |
# pylint: disable=g-importing-member | |
import json | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from scipy.spatial.distance import cdist | |
def show_mask(mask, ax, random_color=False):
  """Draws `mask` on `ax` as a semi-transparent coloured overlay.

  Args:
    mask: Array whose last two dimensions are (height, width) and which holds
      exactly height * width elements.
    ax: Object exposing `imshow`, e.g. a matplotlib Axes.
    random_color: If True, the RGB part of the overlay colour is drawn from
      `np.random.random`; otherwise a fixed blue is used. Alpha is 0.6 in
      both cases.
  """
  alpha = np.array([0.6])
  if random_color:
    rgba = np.concatenate([np.random.random(3), alpha], axis=0)
  else:
    rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
  height, width = mask.shape[-2:]
  # Broadcasting (H, W, 1) against (1, 1, 4) yields an (H, W, 4) RGBA image
  # that is fully transparent wherever the mask is zero.
  overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
  ax.imshow(overlay)
def show_points(coords, labels, ax, marker_size=375):
  """Scatters labelled points on `ax` as star markers.

  Points whose label equals 1 are drawn in green first, then points whose
  label equals 0 are drawn in red.

  Args:
    coords: Array of shape (N, 2); column 0 is passed as x, column 1 as y.
    labels: Array of length N used to build boolean selections on `coords`.
    ax: Object exposing `scatter`, e.g. a matplotlib Axes.
    marker_size: Value forwarded as the `s` argument of `scatter`.
  """
  star_style = dict(
      marker='*', s=marker_size, edgecolor='white', linewidth=1.25
  )
  for label_value, colour in ((1, 'green'), (0, 'red')):
    selected = coords[labels == label_value]
    ax.scatter(selected[:, 0], selected[:, 1], color=colour, **star_style)
def show_box(box, ax):
  """Adds a red, unfilled rectangle for `box` to `ax`.

  Args:
    box: Four values unpacked as (x0, y0, x1, y1).
    ax: Object exposing `add_patch`, e.g. a matplotlib Axes.
  """
  left, top, right, bottom = box
  outline = plt.Rectangle(
      (left, top),
      right - left,
      bottom - top,
      edgecolor='red',
      facecolor=(0, 0, 0, 0),  # Fully transparent fill: outline only.
      lw=2,
  )
  ax.add_patch(outline)
def show_anns(anns):
  """Overlays every annotation mask, labelled with its index, on the current axes.

  Each mask is painted in a random colour at 35% opacity and its index in
  `anns` is written at the centroid of the mask's non-zero pixels. Masks with
  no non-zero pixel are painted but not labelled, because they have no
  centroid (the mean of an empty array is NaN).

  Args:
    anns: Sequence of dicts. Each dict must provide 'segmentation' (a 2-D mask
      array) and 'area' (sort key). Every dict gains an 'id' key holding its
      position in `anns`; this in-place update is kept so existing callers
      that read 'id' afterwards keep working.
  """
  if len(anns) == 0:
    return
  for index, ann in enumerate(anns):
    ann['id'] = index
  # Paint in order of decreasing area so smaller masks are drawn after (and
  # therefore over) larger ones.
  by_area = sorted(anns, key=lambda ann: ann['area'], reverse=True)
  ax = plt.gca()
  ax.set_autoscale_on(False)
  for ann in by_area:
    mask = ann['segmentation']
    rgb = np.random.random(3)
    colour_plane = np.ones((mask.shape[0], mask.shape[1], 3)) * rgb
    # The mask itself becomes the alpha channel of the RGBA overlay.
    ax.imshow(np.dstack((colour_plane, mask * 0.35)))
    rows, cols = np.nonzero(mask)
    if rows.size == 0:
      # Empty mask: no centroid to label.
      continue
    ax.text(
        np.mean(cols),
        np.mean(rows),
        str(ann['id']),
        color='black',
        fontsize=48,
        weight='bold',
    )
# Convert CAM results into SAM prompts.
def aggregate_RGB_channel(activation_mask, how='max'):
  """Collapses the channel axis of a (B, C, H, W) array and flattens H x W.

  Args:
    activation_mask: Array of shape (B, C, H, W).
    how: 'max' for the channel-wise maximum, 'avr' for the channel-wise mean.

  Returns:
    Array of shape (B, H * W).

  Raises:
    ValueError: If `how` is neither 'max' nor 'avr'.
  """
  batch, _, height, width = activation_mask.shape
  if how == 'max':
    reduced = np.amax(activation_mask, axis=1)
  elif how == 'avr':
    reduced = np.mean(activation_mask, axis=1)
  else:
    # Fail loudly instead of hitting an unbound local further down.
    raise ValueError(f"how must be 'max' or 'avr', got {how!r}.")
  return reduced.reshape(batch, height * width)
def find_k_points(arr, k, order='max', how_filter='median'):
  """Picks up to `k` of the highest-valued positions of a map as points.

  Args:
    arr: Array with a leading singleton axis, e.g. (1, H, W); that axis is
      squeezed away before searching.
    k: Number of top-valued candidates to start from. Must be at least 1;
      values larger than the number of elements make `np.argpartition` raise.
    order: Accepted for interface compatibility but currently unused; the
      highest values are always selected.
    how_filter: How the k candidates are thinned out.
      'median' keeps candidates whose summed Euclidean distance to all
      candidates is strictly below the median of those sums; if no candidate
      qualifies (all sums tie, which always happens for k = 1 or 2) the
      unfiltered candidates are returned instead of an empty result.
      'random' keeps int(round(k / 16)) candidates chosen uniformly at random
      without replacement.
      Any other value returns all k candidates.

  Returns:
    Integer array of shape (n, ndim) holding the selected indices with the
    axis order reversed, i.e. (x, y) = (column, row) for a 2-D map.

  Raises:
    ValueError: If `k` is smaller than 1.
  """
  if k < 1:
    # With k == 0 the slice `[-0:]` would silently select every element.
    raise ValueError(f'k must be a positive integer, got {k}.')
  arr = arr.squeeze(0)
  flat_topk = np.argpartition(arr.flatten(), -k)[-k:]
  unravelled = np.unravel_index(flat_topk, arr.shape)
  # Rows of `candidates` are indices; reverse the axis order -> (x, y).
  candidates = np.array(unravelled).transpose()[:, ::-1]
  if how_filter == 'random':
    chosen_rows = np.random.choice(
        candidates.shape[0], size=int(round(k / 16)), replace=False
    )
    candidates = candidates[chosen_rows]
  elif how_filter == 'median':
    total_distance = np.sum(cdist(candidates, candidates), axis=1)
    is_central = total_distance < np.median(total_distance)
    # Keep the unfiltered set when the strict comparison rejects everything.
    if is_central.any():
      candidates = candidates[is_central]
  return candidates
def max_sum_submatrix(matrix):
  """Finds the contiguous sub-rectangle of a 2-D array with the largest sum.

  For every pair of (left, right) columns the rows are collapsed into one
  value per row using row-wise prefix sums, and a 1-D maximum-subarray scan
  (Kadane's algorithm) finds the best run of rows.

  Args:
    matrix: 2-D array-like with at least one row. It is copied, so the
      caller's data is not modified.

  Returns:
    A tuple `(max_sum, (top, left, bottom, right))` where the rectangle
    bounds are inclusive indices. For an input with zero columns the result
    is `(-inf, (0, 0, 0, 0))`.
  """
  matrix = np.array(matrix)
  num_rows, num_cols = matrix.shape
  # prefix[:, j] is the sum of matrix[:, :j + 1], so the sum of columns
  # left..right of each row is a single subtraction.
  prefix = np.cumsum(matrix, axis=1)
  max_sum = float('-inf')
  max_rect = (0, 0, 0, 0)  # (top, left, bottom, right)
  for left in range(num_cols):
    for right in range(left, num_cols):
      strip = prefix[:, right] - (prefix[:, left - 1] if left > 0 else 0)
      # Kadane's algorithm over the rows of the current column strip.
      running = best = strip[0]
      run_start = best_start = best_end = 0
      for row in range(1, num_rows):
        val = strip[row]
        if running > 0:
          running += val
        else:
          running = val
          run_start = row
        if running > best:
          best = running
          # Remember where the *best* run began, not merely the latest one.
          best_start, best_end = run_start, row
      if best > max_sum:
        max_sum = best
        max_rect = (best_start, left, best_end, right)
  return max_sum, max_rect
def CAM2SAMClick(activation_map, k=5, order='max', how_filter='median'):
  """Turns an activation map into point prompts via `find_k_points`.

  Args:
    activation_map: Array of shape (H, W, C). It is reshaped to
      (1, 1, H, W), which only succeeds when it holds H * W elements
      (i.e. C == 1).
    k: Forwarded to `find_k_points`.
    order: Forwarded to `find_k_points`.
    how_filter: Forwarded to `find_k_points`.

  Returns:
    A list with one entry per leading element of the reshaped map (a single
    entry), each being the array returned by `find_k_points`.
  """
  height, width, _ = activation_map.shape
  batched = activation_map.reshape((1, 1, height, width))
  return [
      find_k_points(single_map, k, order, how_filter)
      for single_map in batched
  ]
def CAM2SAMBox(activation_map):
  """Turns an activation map into a box prompt via `max_sum_submatrix`.

  The map is min-max scaled to [-1, 1] so that low activations count against
  a candidate rectangle, then the maximum-sum rectangle is taken as the box.

  Args:
    activation_map: Array of shape (H, W, C). It is reshaped to (1, H, W),
      which only succeeds when it holds H * W elements (i.e. C == 1).

  Returns:
    A list with one box per leading element of the reshaped map (a single
    entry). Each box is the `(top, left, bottom, right)` tuple produced by
    `max_sum_submatrix`; a constant map yields `(0, 0, 0, 0)`.
  """
  height, width, _ = activation_map.shape
  batched = activation_map.reshape((1, height, width))
  box_coordinates = []
  for arr in batched:
    low, high = np.min(arr), np.max(arr)
    if high == low:
      # Min-max scaling would divide by zero and fill the map with NaN; a
      # constant map has no preferred region, so emit the degenerate box.
      box_coordinates.append((0, 0, 0, 0))
      continue
    norm_arr = 2 * ((arr - low) / (high - low)) - 1
    _, box_coordinate = max_sum_submatrix(norm_arr)
    box_coordinates.append(box_coordinate)
  return box_coordinates
# Visualization helpers.
def visualize_attention(arr, filename):
  """Renders `arr` as an image with a colorbar and saves it to `filename`.

  Args:
    arr: Array-like accepted by `Axes.imshow`.
    filename: Destination passed to `Figure.savefig`.
  """
  fig, ax = plt.subplots()
  try:
    image = ax.imshow(arr)
    fig.colorbar(image, ax=ax)
    fig.savefig(filename)
  finally:
    # pyplot keeps every figure it creates alive until it is closed, so
    # repeated calls would otherwise accumulate open figures.
    plt.close(fig)
# Build config | |
# def build_sam_config(config_path): | |
# with open(config_path, 'r') as infile: | |
# config = json.load(infile) | |
# sam_checkpoint = config['model']['sam_checkpoint'] | |
# model_type = config['model']['model_type'] | |
# return sam_checkpoint, model_type | |
def build_sam_config(config):
  """Reads the SAM checkpoint and model type from a config object.

  Args:
    config: Object with a `sam` attribute exposing `sam_checkpoint` and
      `model_type`.

  Returns:
    The tuple `(sam_checkpoint, model_type)`.
  """
  return config.sam.sam_checkpoint, config.sam.model_type