Spaces:
Sleeping
Sleeping
from models import MaskDecoderHQ | |
from ppc_decoder import sam_decoder_reg | |
from segment_anything import sam_model_registry | |
import torch.nn as nn | |
import torch | |
import torch.nn.functional as F | |
import matplotlib.pyplot as plt | |
from utils.transforms import ResizeLongestSide | |
from typing import List | |
# Shared resizer used by PPC_SAM.predict for the input image, boxes and point
# coordinates (via apply_image / apply_boxes_torch / apply_coords_torch).
# Presumably rescales so the longest image side becomes 1024 px, matching the
# 1024 target later passed to pad_to_square in predict — TODO confirm against
# utils.transforms.ResizeLongestSide.
trans = ResizeLongestSide(target_length=1024)
def save_prob_visualization(prob, filename="prob_visualization.png", **imshow_kwargs):
    """Render a single 2-D probability map with ``imshow`` and save it as an image.

    Args:
        prob: tensor (or array-like) holding one 2-D map, optionally with up to
            two leading size-1 dimensions (e.g. ``1xWxH`` or ``1x1xWxH``). It may
            live on any device and may require grad.
        filename: output image path; defaults to ``'prob_visualization.png'``.
        **imshow_kwargs: forwarded unchanged to ``Axes.imshow``, e.g.
            ``cmap='gray', vmin=0, vmax=1`` for a fixed grayscale range.
    """
    # .numpy() raises for tensors that require grad or are not on the CPU,
    # so detach and move to the CPU first. squeeze(0) is a no-op on a torch
    # tensor whose leading dim is not 1, so both squeezes are safe.
    prob_np = torch.as_tensor(prob).detach().cpu().squeeze(0).squeeze(0).numpy()
    # Draw on a dedicated figure so any figure the caller already has open is
    # neither drawn over nor closed by this helper.
    fig, ax = plt.subplots()
    ax.imshow(prob_np, **imshow_kwargs)
    ax.axis('off')  # hide the axes so only the map is saved
    fig.savefig(filename, bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    print(f"Probability map saved as {filename}")
def pad_to_square(x: torch.Tensor, target_size: int) -> torch.Tensor:
    """Zero-pad the last two dimensions of ``x`` to ``target_size x target_size``.

    Padding is added only at the bottom and on the right, so the original
    content stays anchored at the top-left corner. Leading dimensions are kept.

    Args:
        x: tensor with at least two dimensions; the last two are (height, width).
        target_size: side length of the square output.

    Returns:
        The padded tensor.

    Raises:
        ValueError: if the height or width already exceeds ``target_size``.
            ``F.pad`` with negative amounts would otherwise crop silently.
    """
    h, w = x.shape[-2:]
    padh = target_size - h
    padw = target_size - w
    if padh < 0 or padw < 0:
        raise ValueError(
            f"pad_to_square: input of size {h}x{w} exceeds target_size={target_size}"
        )
    # F.pad takes amounts starting from the LAST dim: (left, right, top, bottom).
    return F.pad(x, (0, padw, 0, padh))
def remove_none_values(input_dict):
    """Build a shallow copy of ``input_dict`` that omits every entry whose value is None.

    The argument itself is left untouched, and surviving keys keep their original
    insertion order. Falsy-but-not-None values such as ``0`` or ``""`` are kept.

    Args:
        input_dict (dict): mapping to filter.

    Returns:
        dict: a new dictionary containing only the non-None entries.
    """
    kept = {}
    for name, val in input_dict.items():
        if val is None:
            continue
        kept[name] = val
    return kept
class PPC_SAM:
    """Three-stage segmenter: SAM, the SAM-HQ mask decoder, then the PPC decoder.

    ``__init__`` builds the two decoders, loads their state dicts and builds the
    SAM model. ``predict`` runs all three on a single prompt dict and returns
    their masks stacked together.
    """

    def __init__(self, model_type="vit_h",
                 ckpt_vit="pretrained_checkpoint/sam_vit_h_4b8939.pth",
                 ckpt_ppc="pretrained_checkpoint/ppc_decoder.pth",
                 ckpt_hq="pretrained_checkpoint/sam_hq_vit_h_decoder.pth",
                 device="cpu") -> None:
        """Build the models and load their weights.

        Args:
            model_type: key into ``sam_model_registry``; also passed to
                ``MaskDecoderHQ``.
            ckpt_vit: checkpoint path handed to the SAM builder.
            ckpt_ppc: path of the PPC decoder state dict.
            ckpt_hq: path of the HQ decoder state dict.
            device: torch device that all three models are moved to.
        """
        self.device = device
        # Build the decoders.
        self.sam_hq_decoder = MaskDecoderHQ(model_type)
        self.ppc_decoder = sam_decoder_reg['default']()
        # Load the decoder weights. NOTE(review): torch.load unpickles the file,
        # so only point these paths at trusted checkpoints.
        model_state_hq = torch.load(ckpt_hq, map_location=device)
        self.sam_hq_decoder.load_state_dict(model_state_hq)
        print(f"Loaded HQ decoder checkpoint from {ckpt_hq}")
        model_state_ppc = torch.load(ckpt_ppc, map_location=device)
        self.ppc_decoder.load_state_dict(model_state_ppc)
        print(f"Loaded PPC decoder checkpoint from {ckpt_ppc}")
        # Build the SAM model.
        self.sam = sam_model_registry[model_type](checkpoint=ckpt_vit).to(device)
        # load_state_dict copies weights into the decoders' existing parameters.
        # The decoders were constructed without a device argument, so move them
        # explicitly. All three models are used for inference only: switch them
        # to eval mode so dropout / batch-norm layers (if any) behave
        # deterministically.
        self.sam_hq_decoder = self.sam_hq_decoder.to(device).eval()
        self.ppc_decoder = self.ppc_decoder.to(device).eval()
        self.sam = self.sam.eval()

    def predict(self, prompts, multimask_ouput=False, *, multimask_output=None):
        """Segment one image with SAM, SAM-HQ and the PPC decoder.

        Args:
            prompts: sequence whose FIRST element is a prompt dict; any further
                elements are ignored. Entries whose value is None are dropped.
                ``"image"`` is required and is read as an HxWxC image:
                ``shape[:2]`` is taken as the original size, and the resized
                image is permuted with ``(2, 0, 1)``. Optional ``"boxes"`` and
                ``"point_coords"`` are rescaled with the same resizer. They are
                presumably torch tensors in original-image pixel coordinates
                (TODO confirm).
            multimask_ouput: forwarded as ``multimask_output`` to SAM and to the
                HQ decoder. The misspelled name is kept for backward
                compatibility.
            multimask_output: keyword-only, correctly spelled alias. When given,
                it overrides ``multimask_ouput``.

        Returns:
            ``(stacked_masks, None, None)``. ``stacked_masks`` is a CPU tensor
            that stacks, along dim 0:

            * the PPC mask, passed through ``self.sam.postprocess_masks``;
            * the HQ mask, passed through ``self.sam.postprocess_masks``;
            * SAM's own ``"masks"`` output, cast to uint8.

            ``.squeeze(1).squeeze(1)`` is applied afterwards. The two ``None``
            entries are placeholders.

        Raises:
            IndexError: if ``prompts`` is empty.
            KeyError: if the prompt dict has no non-None ``"image"``.
        """
        if multimask_output is not None:
            multimask_ouput = multimask_output
        with torch.no_grad():
            # Honor the current self.device, e.g. if it was changed after
            # construction. .to() is a no-op when a model is already there.
            self.sam = self.sam.to(self.device)
            self.sam_hq_decoder = self.sam_hq_decoder.to(self.device)
            self.ppc_decoder = self.ppc_decoder.to(self.device)

            # remove_none_values returns a new dict, so the rewrites below do
            # not mutate the caller's prompt dict.
            batch_input = remove_none_values(prompts[0])
            original_size = batch_input["image"].shape[:2]
            batch_input["original_size"] = original_size

            # Resize the image and convert it to a channel-first tensor.
            input_image = trans.apply_image(batch_input["image"])
            input_image_torch = torch.as_tensor(input_image, device=self.device)
            input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()
            batch_input["image"] = input_image_torch
            # Map prompt coordinates into the resized image's frame.
            if "boxes" in batch_input:
                batch_input["boxes"] = trans.apply_boxes_torch(batch_input["boxes"], original_size=original_size)
            if "point_coords" in batch_input:
                batch_input["point_coords"] = trans.apply_coords_torch(batch_input["point_coords"], original_size=original_size)

            # Stage 1: SAM. Besides its masks, its per-image outputs carry the
            # embeddings that the HQ decoder consumes below.
            batched_output, interm_embeddings = self.sam([batch_input], multimask_output=multimask_ouput)
            encoder_embedding = torch.cat([out['encoder_embedding'] for out in batched_output], dim=0)
            image_pe = [out['image_pe'] for out in batched_output]
            sparse_embeddings = [out['sparse_embeddings'] for out in batched_output]
            dense_embeddings = [out['dense_embeddings'] for out in batched_output]

            # Stage 2: SAM-HQ decoder. Only its second output (the HQ mask) is used.
            _, masks_hq = self.sam_hq_decoder(
                image_embeddings=encoder_embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_ouput,
                hq_token_only=False,
                interm_embeddings=interm_embeddings,
            )
            masks_sam = batched_output[0]["masks"]

            # Stage 3: PPC decoder. It is fed the resized image zero-padded to
            # 1024x1024 (as float), SAM's interm_embeddings and the HQ mask.
            input_images_ppc = pad_to_square(input_image_torch[None, :, :, :], target_size=1024).float()
            mask_ppc = self.ppc_decoder(x_img=input_images_ppc, hidden_states_out=interm_embeddings, low_res_mask=masks_hq)

            # Bring the HQ and PPC masks back to the original image size.
            input_size = input_image_torch.shape[-2:]
            rescaled_masks_hq = self.sam.postprocess_masks(masks_hq, input_size=input_size, original_size=original_size)
            rescaled_masks_ppc = self.sam.postprocess_masks(mask_ppc, input_size=input_size, original_size=original_size)
            stacked_masks = torch.stack(
                [rescaled_masks_ppc, rescaled_masks_hq, masks_sam.to(torch.uint8)], dim=0
            ).cpu().squeeze(1).squeeze(1)
            return stacked_masks, None, None