# PPC-SAM / load_models.py
# Last commit by forSubAnony: "v1" (57abc33)
from models import MaskDecoderHQ
from ppc_decoder import sam_decoder_reg
from segment_anything import sam_model_registry
import torch.nn as nn
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from utils.transforms import ResizeLongestSide
from typing import List
# Module-wide resizer (target_length=1024) shared by PPC_SAM.predict: it resizes
# the input image and maps box / point prompts into the resized frame.
trans = ResizeLongestSide(target_length=1024)
def save_prob_visualization(prob, filename="prob_visualization.png"):
    """
    Render a single-channel probability map with ``imshow`` and save it to disk.

    :param prob: tensor of shape 1xWxH (or 1x1xWxH); up to two leading size-1
        dimensions are squeezed away before plotting. The tensor may live on
        any device and may require grad.
    :param filename: output image path, defaults to 'prob_visualization.png'
    """
    # Detach and move to CPU first: Tensor.numpy() raises for CUDA tensors and
    # for tensors that require grad.
    prob_np = prob.detach().cpu().squeeze(0).squeeze(0).numpy()  # 1xWxH -> WxH
    # Draw on a dedicated figure instead of pyplot's implicit "current" one, so
    # a figure the caller already has open is neither drawn on nor closed.
    fig, ax = plt.subplots()
    try:
        # Pass cmap='gray', vmin=0, vmax=1 for a fixed-range grayscale rendering.
        ax.imshow(prob_np)
        ax.axis('off')  # hide the axes
        fig.savefig(filename, bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)
    print(f"Probability map saved as {filename}")
def pad_to_square(x: torch.Tensor, target_size: int) -> torch.Tensor:
    """Zero-pad the last two dims of ``x`` up to ``target_size`` x ``target_size``.

    Padding is added only on the bottom and on the right, so the original
    content stays anchored at the top-left corner. Leading dimensions are left
    untouched.

    Args:
        x: Tensor whose last two dimensions are (height, width).
        target_size: Side length of the square output.

    Returns:
        Tensor with the same leading dims as ``x`` and last two dims equal to
        ``target_size``.

    Raises:
        ValueError: If the height or width of ``x`` exceeds ``target_size``.
    """
    h, w = x.shape[-2:]
    padh = target_size - h
    padw = target_size - w
    # F.pad interprets negative amounts as cropping, which would silently drop
    # image content rather than pad it; reject oversized inputs explicitly.
    if padh < 0 or padw < 0:
        raise ValueError(
            f"pad_to_square: input of size {h}x{w} exceeds target_size {target_size}"
        )
    # F.pad's tuple starts from the last dim: (left, right, top, bottom).
    return F.pad(x, (0, padw, 0, padh))
def remove_none_values(input_dict):
    """
    Build a copy of ``input_dict`` that omits every entry whose value is None.

    The input dictionary itself is left untouched, and surviving entries keep
    their original key order. Falsy values other than None (0, "", []) are kept.

    Args:
        input_dict (dict): Mapping to filter.
    Returns:
        dict: A new dictionary holding only the non-None entries.
    """
    kept = {}
    for name, item in input_dict.items():
        if item is None:
            continue
        kept[name] = item
    return kept
class PPC_SAM():
    """Inference wrapper chaining SAM, an HQ mask decoder and a PPC decoder.

    ``predict`` runs SAM on a single prompt dict. It refines SAM's embeddings
    with the HQ decoder, feeds the HQ mask to the PPC decoder, and returns the
    PPC, HQ and plain-SAM masks stacked together.
    """

    def __init__(self, model_type="vit_h",
                 ckpt_vit="pretrained_checkpoint/sam_vit_h_4b8939.pth",
                 ckpt_ppc="pretrained_checkpoint/ppc_decoder.pth",
                 ckpt_hq="pretrained_checkpoint/sam_hq_vit_h_decoder.pth",
                 device="cpu") -> None:
        """Build the three models and load their checkpoints.

        Args:
            model_type: key into ``sam_model_registry``; also passed to
                ``MaskDecoderHQ``.
            ckpt_vit: path of the SAM checkpoint.
            ckpt_ppc: path of the PPC decoder state dict.
            ckpt_hq: path of the HQ decoder state dict.
            device: device the checkpoints are mapped onto and the models run on.
        """
        self.device = device
        # Initialize the decoders.
        self.sam_hq_decoder = MaskDecoderHQ(model_type)
        self.ppc_decoder = sam_decoder_reg['default']()
        # Load state dictionaries.
        model_state_hq = torch.load(ckpt_hq, map_location=device)
        self.sam_hq_decoder.load_state_dict(model_state_hq)
        print(f"Loaded HQ decoder checkpoint from {ckpt_hq}")
        model_state_ppc = torch.load(ckpt_ppc, map_location=device)
        self.ppc_decoder.load_state_dict(model_state_ppc)
        print(f"Loaded PPC decoder checkpoint from {ckpt_ppc}")
        # Initialize the SAM model.
        self.sam = sam_model_registry[model_type](checkpoint=ckpt_vit).to(device)
        # Freshly constructed nn.Modules start in training mode. This class only
        # runs inference, so switch every model to eval mode. That way any
        # dropout / batch-norm layers they contain use inference behaviour.
        for model in (self.sam, self.sam_hq_decoder, self.ppc_decoder):
            model.eval()

    def _prepare_input(self, prompt):
        """Filter one prompt dict and map its image and prompts into SAM's frame.

        Returns a tuple ``(batch_input, original_size, input_image_torch)``:
        - ``batch_input``: a new dict without None entries, with the resized
          CxHxW image tensor, transformed ``"boxes"`` / ``"point_coords"`` (when
          present) and ``"original_size"``;
        - ``original_size``: ``shape[:2]`` of the incoming image;
        - ``input_image_torch``: the resized image tensor on ``self.device``.
        """
        batch_input = remove_none_values(prompt)  # new dict; caller's is untouched
        original_size = batch_input["image"].shape[:2]
        batch_input["original_size"] = original_size
        input_image = trans.apply_image(batch_input["image"])
        input_image_torch = torch.as_tensor(input_image, device=self.device)
        # HWC -> CHW.
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()
        batch_input["image"] = input_image_torch
        if "boxes" in batch_input:
            batch_input["boxes"] = trans.apply_boxes_torch(batch_input["boxes"], original_size=original_size)
        if "point_coords" in batch_input:
            batch_input["point_coords"] = trans.apply_coords_torch(batch_input["point_coords"], original_size=original_size)
        return batch_input, original_size, input_image_torch

    def predict(self, prompts, multimask_ouput=False):
        """Segment one image with SAM, the HQ decoder and the PPC decoder.

        Only ``prompts[0]`` is used. That dict must contain an ``"image"``
        (HxWxC; it is permuted to CxHxW after resizing). It may contain
        ``"boxes"`` / ``"point_coords"`` and other keys the SAM forward pass
        accepts. Entries whose value is None are ignored, and the caller's dict
        is not modified.

        Args:
            prompts: sequence of prompt dicts; only the first one is used.
            multimask_ouput: forwarded to SAM and the HQ decoder as
                ``multimask_output`` (parameter spelling kept for compatibility).

        Returns:
            tuple: ``(stacked_masks, None, None)``. ``stacked_masks`` is a CPU
            tensor that stacks, in order:
            - the PPC mask and the HQ mask, both passed through
              ``sam.postprocess_masks`` to the original image size;
            - SAM's own ``"masks"`` output cast to uint8.
            Size-1 dims 1 and 2 are squeezed away.
        """
        with torch.no_grad():
            # Re-sync model placement with self.device (no-op when unchanged).
            self.sam = self.sam.to(self.device)
            self.sam_hq_decoder = self.sam_hq_decoder.to(self.device)
            self.ppc_decoder = self.ppc_decoder.to(self.device)

            batch_input, original_size, input_image_torch = self._prepare_input(prompts[0])

            batched_output, interm_embeddings = self.sam([batch_input], multimask_output=multimask_ouput)
            encoder_embedding = torch.cat([out['encoder_embedding'] for out in batched_output], dim=0)
            image_pe = [out['image_pe'] for out in batched_output]
            sparse_embeddings = [out['sparse_embeddings'] for out in batched_output]
            dense_embeddings = [out['dense_embeddings'] for out in batched_output]
            # The first output of the HQ decoder (its SAM-token masks) is not used here.
            _, masks_hq = self.sam_hq_decoder(
                image_embeddings=encoder_embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_ouput,
                hq_token_only=False,
                interm_embeddings=interm_embeddings,
            )
            masks_sam = batched_output[0]["masks"]
            # The PPC decoder gets the resized image zero-padded to 1024x1024
            # (the resizer's target length), as float.
            input_images_ppc = pad_to_square(input_image_torch[None, :, :, :], target_size=1024).float()
            mask_ppc = self.ppc_decoder(x_img=input_images_ppc, hidden_states_out=interm_embeddings, low_res_mask=masks_hq)
            input_size = input_image_torch.shape[-2:]
            rescaled_masks_hq = self.sam.postprocess_masks(masks_hq, input_size=input_size, original_size=original_size)
            rescaled_masks_ppc = self.sam.postprocess_masks(mask_ppc, input_size=input_size, original_size=original_size)
            stacked_masks = torch.stack(
                [rescaled_masks_ppc, rescaled_masks_hq, masks_sam.to(torch.uint8)], dim=0
            ).cpu().squeeze(1).squeeze(1)
        return stacked_masks, None, None