from models import MaskDecoderHQ
from ppc_decoder import sam_decoder_reg
from segment_anything import sam_model_registry
import torch.nn as nn
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from utils.transforms import ResizeLongestSide
from typing import List

# Shared resizer used by PPC_SAM.predict for the image and its prompts.
# NOTE(review): presumably mirrors segment_anything's ResizeLongestSide
# (longest side scaled to 1024); utils.transforms is not visible here.
trans = ResizeLongestSide(target_length=1024)


def save_prob_visualization(prob, filename="prob_visualization.png"):
    """Render a single-channel probability map with imshow and save it to disk.

    :param prob: tensor shaped 1xHxW (a 1x1xHxW tensor also works, since up to
        two leading singleton dims are squeezed away). May live on any device
        and may require grad; it is detached and copied to the CPU first.
    :param filename: output file name, default 'prob_visualization.png'.
    """
    # Detach + move to CPU so .numpy() also works for CUDA / autograd tensors.
    prob_np = prob.detach().cpu().squeeze(0).squeeze(0).numpy()

    # Use a dedicated figure instead of pyplot's "current" figure so we never
    # draw on top of (or close) a figure that some other code has open.
    fig, ax = plt.subplots()
    ax.imshow(prob_np)
    ax.axis('off')
    fig.savefig(filename, bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    print(f"Probability map saved as {filename}")


def pad_to_square(x: torch.Tensor, target_size: int) -> torch.Tensor:
    """Zero-pad the last two dims of ``x`` on the bottom/right to ``target_size``.

    :param x: tensor whose last two dimensions are (height, width).
    :param target_size: desired height and width after padding.
    :returns: tensor with the same leading dims and last two dims
        ``(target_size, target_size)``.
    :raises ValueError: if the height or width already exceeds
        ``target_size``. ``F.pad`` would otherwise treat the negative padding
        as a silent crop.
    """
    h, w = x.shape[-2:]
    if h > target_size or w > target_size:
        raise ValueError(
            f"pad_to_square: input spatial size {h}x{w} exceeds "
            f"target_size {target_size}"
        )
    # F.pad takes (left, right, top, bottom) for the last two dims.
    return F.pad(x, (0, target_size - w, 0, target_size - h))


def remove_none_values(input_dict):
    """
    Remove all items with None as their value from the dictionary.

    Args:
        input_dict (dict): The dictionary from which to remove None values.

    Returns:
        dict: A new dictionary with None values removed; the input is not
        modified.
    """
    return {key: value for key, value in input_dict.items() if value is not None}


class PPC_SAM:
    """Inference wrapper combining SAM, a SAM-HQ mask decoder and a PPC decoder.

    ``predict`` runs SAM on one prompt set and then feeds its embeddings
    through the HQ decoder. The HQ decoder output then goes through the PPC
    decoder. All three mask predictions are returned stacked together.
    """

    def __init__(self, model_type="vit_h",
                 ckpt_vit="pretrained_checkpoint/sam_vit_h_4b8939.pth",
                 ckpt_ppc="pretrained_checkpoint/ppc_decoder.pth",
                 ckpt_hq="pretrained_checkpoint/sam_hq_vit_h_decoder.pth",
                 device="cpu") -> None:
        """Build the three models and load their checkpoints.

        :param model_type: key into ``sam_model_registry``; also passed to
            ``MaskDecoderHQ``.
        :param ckpt_vit: SAM checkpoint path.
        :param ckpt_ppc: PPC decoder state-dict path.
        :param ckpt_hq: HQ decoder state-dict path.
        :param device: device the models (and prompt tensors) are placed on.
        """
        self.device = device

        # Build the two extra decoders; their weights are loaded below.
        self.sam_hq_decoder = MaskDecoderHQ(model_type)
        self.ppc_decoder = sam_decoder_reg['default']()

        # NOTE(review): torch.load unpickles the file; only use trusted
        # checkpoint paths.
        model_state_hq = torch.load(ckpt_hq, map_location=device)
        self.sam_hq_decoder.load_state_dict(model_state_hq)
        print(f"Loaded HQ decoder checkpoint from {ckpt_hq}")

        model_state_ppc = torch.load(ckpt_ppc, map_location=device)
        self.ppc_decoder.load_state_dict(model_state_ppc)
        print(f"Loaded PPC decoder checkpoint from {ckpt_ppc}")

        self.sam = sam_model_registry[model_type](checkpoint=ckpt_vit).to(device)

        # This class is inference-only. A freshly constructed nn.Module starts
        # in training mode, so switch every model to eval mode; otherwise any
        # dropout / batch-norm layers in the decoders would run in training
        # mode.
        self.sam.eval()
        self.sam_hq_decoder.eval()
        self.ppc_decoder.eval()

    def _prepare_input(self, prompt):
        """Resize one prompt dict's image and prompts into the model input frame.

        :param prompt: dict with an ``"image"`` entry (presumably an HxWxC
            numpy array, as expected by ``trans.apply_image`` — confirm with
            callers) and optional ``"boxes"``, ``"point_coords"``,
            ``"point_labels"`` entries. Entries whose value is None are
            dropped.
        :returns: ``(batch_input, input_image_torch, original_size)``.
            ``batch_input`` is a new dict ready for ``self.sam``.
            ``input_image_torch`` is the resized CxHxW image tensor.
            ``original_size`` is ``image.shape[:2]``.
        """
        batch_input = remove_none_values(prompt)
        original_size = batch_input["image"].shape[:2]
        batch_input["original_size"] = original_size

        input_image = trans.apply_image(batch_input["image"])
        # HxWxC -> CxHxW, contiguous for the image encoder.
        input_image_torch = torch.as_tensor(input_image, device=self.device)
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()
        batch_input["image"] = input_image_torch

        # Put prompts on the model device as tensors so CPU tensors or numpy
        # arrays from the caller work when self.device is a GPU.
        if "boxes" in batch_input:
            boxes = torch.as_tensor(batch_input["boxes"], device=self.device)
            batch_input["boxes"] = trans.apply_boxes_torch(boxes, original_size=original_size)
        if "point_coords" in batch_input:
            coords = torch.as_tensor(batch_input["point_coords"], device=self.device)
            batch_input["point_coords"] = trans.apply_coords_torch(coords, original_size=original_size)
        if "point_labels" in batch_input:
            batch_input["point_labels"] = torch.as_tensor(batch_input["point_labels"], device=self.device)

        return batch_input, input_image_torch, original_size

    def predict(self, prompts, multimask_ouput=False):
        """Predict masks for the first prompt dict in ``prompts``.

        :param prompts: sequence of prompt dicts; only ``prompts[0]`` is used
            (see ``_prepare_input`` for the expected keys).
        :param multimask_ouput: forwarded as ``multimask_output`` to SAM and
            the HQ decoder. The misspelled name is kept for backward
            compatibility with existing callers.
        :returns: ``(stacked_masks, None, None)``. ``stacked_masks`` stacks,
            along dim 0 and in this order:

            - the PPC mask after ``sam.postprocess_masks``;
            - the HQ mask after ``sam.postprocess_masks``;
            - SAM's own ``"masks"`` output cast to uint8.

            The result is moved to the CPU, and dims 1 and 2 are squeezed
            when they are singleton. The PPC/HQ entries are not thresholded
            here. The result dtype follows torch's type promotion across the
            three inputs.
        """
        with torch.no_grad():
            # No-ops when already on self.device; kept so a changed
            # self.device is honoured.
            self.sam = self.sam.to(self.device)
            self.sam_hq_decoder = self.sam_hq_decoder.to(self.device)
            self.ppc_decoder = self.ppc_decoder.to(self.device)

            batch_input, input_image_torch, original_size = self._prepare_input(prompts[0])

            batched_output, interm_embeddings = self.sam([batch_input], multimask_output=multimask_ouput)

            # Gather SAM's intermediate outputs for the HQ decoder.
            encoder_embedding = torch.cat([out['encoder_embedding'] for out in batched_output], dim=0)
            image_pe = [out['image_pe'] for out in batched_output]
            sparse_embeddings = [out['sparse_embeddings'] for out in batched_output]
            dense_embeddings = [out['dense_embeddings'] for out in batched_output]

            # The decoder also returns SAM-style masks, which are not needed here.
            _, masks_hq = self.sam_hq_decoder(
                image_embeddings=encoder_embedding,
                image_pe=image_pe,
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_ouput,
                hq_token_only=False,
                interm_embeddings=interm_embeddings,
            )
            masks_sam = batched_output[0]["masks"]

            # The PPC decoder consumes the resized image zero-padded to
            # 1024x1024, plus SAM's intermediate features and the HQ mask.
            input_images_ppc = pad_to_square(input_image_torch[None, :, :, :], target_size=1024).float()
            mask_ppc = self.ppc_decoder(x_img=input_images_ppc, hidden_states_out=interm_embeddings, low_res_mask=masks_hq)

            input_size = input_image_torch.shape[-2:]
            rescaled_masks_hq = self.sam.postprocess_masks(masks_hq, input_size=input_size, original_size=original_size)
            rescaled_masks_ppc = self.sam.postprocess_masks(mask_ppc, input_size=input_size, original_size=original_size)

            stacked_masks = torch.stack(
                [rescaled_masks_ppc, rescaled_masks_hq, masks_sam.to(torch.uint8)], dim=0
            ).cpu().squeeze(1).squeeze(1)
            return stacked_masks, None, None