import logging
from functools import partial
from typing import Any, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from models import register
from .iou_loss import IOU
from .mmseg.models.sam import ImageEncoderViT, MaskDecoder, PromptEncoder, TwoWayTransformer

logger = logging.getLogger(__name__)
def init_weights(layer):
    if type(layer) == nn.Conv2d:
        nn.init.normal_(layer.weight, mean=0.0, std=0.02)
        nn.init.constant_(layer.bias, 0.0)
    elif type(layer) == nn.Linear:
        nn.init.normal_(layer.weight, mean=0.0, std=0.02)
        nn.init.constant_(layer.bias, 0.0)
    elif type(layer) == nn.BatchNorm2d:
        nn.init.normal_(layer.weight, mean=1.0, std=0.02)
        nn.init.constant_(layer.bias, 0.0)
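# Usage sketch (hypothetical head module, purely illustrative): nn.Module.apply
# walks every submodule, so the initializer above can be applied in one call.
#   head = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
#   head.apply(init_weights)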
class BBCEWithLogitLoss(nn.Module):
    '''
    Balanced BCEWithLogitsLoss: weights the positive class by the
    negative/positive pixel ratio so sparse foregrounds are not swamped.
    '''
    def __init__(self):
        super(BBCEWithLogitLoss, self).__init__()

    def forward(self, pred, gt):
        eps = 1e-10
        count_pos = torch.sum(gt) + eps
        count_neg = torch.sum(1. - gt)
        ratio = count_neg / count_pos
        w_neg = count_pos / (count_pos + count_neg)

        bce1 = nn.BCEWithLogitsLoss(pos_weight=ratio)
        loss = w_neg * bce1(pred, gt)
        return loss
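# Usage sketch (toy shapes, purely illustrative): with ~10% foreground pixels,
# pos_weight lands near 9, so errors on the sparse positives are up-weighted.
#   pred = torch.randn(2, 1, 64, 64)               # raw logits
#   gt = (torch.rand(2, 1, 64, 64) > 0.9).float()  # sparse binary mask
#   loss = BBCEWithLogitLoss()(pred, gt)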
def _iou_loss(pred, target):
    """Soft IoU loss (1 - IoU) on sigmoid probabilities, averaged over batch and channels."""
    pred = torch.sigmoid(pred)
    inter = (pred * target).sum(dim=(2, 3))
    union = (pred + target).sum(dim=(2, 3)) - inter
    iou = 1 - (inter / union)
    return iou.mean()
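# Usage sketch (illustrative): the IoU term is normally added on top of a BCE
# term, mirroring the 'iou' loss mode configured in SAM.__init__ below.
#   pred = torch.randn(2, 1, 64, 64)               # raw logits
#   gt = (torch.rand(2, 1, 64, 64) > 0.5).float()  # binary mask
#   total = nn.BCEWithLogitsLoss()(pred, gt) + _iou_loss(pred, gt)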
class PositionEmbeddingRandom(nn.Module):
    """
    Positional encoding using random spatial frequencies.
    """
    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
        """Positionally encode points that are normalized to [0,1]."""
        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
        coords = 2 * coords - 1
        coords = coords @ self.positional_encoding_gaussian_matrix
        coords = 2 * np.pi * coords
        # outputs d_1 x ... x d_n x C shape
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self, size: int) -> torch.Tensor:
        """Generate positional encoding for a square grid of the specified size."""
        h, w = size, size
        device: Any = self.positional_encoding_gaussian_matrix.device
        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        y_embed = grid.cumsum(dim=0) - 0.5
        x_embed = grid.cumsum(dim=1) - 0.5
        y_embed = y_embed / h
        x_embed = x_embed / w

        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
        return pe.permute(2, 0, 1)  # C x H x W
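# Shape sketch (illustrative): with num_pos_feats=128 the sin/cos halves
# concatenate to 256 channels, matching a prompt_embed_dim of 256.
#   pe_layer = PositionEmbeddingRandom(num_pos_feats=128)
#   pe = pe_layer(64)   # torch.Size([256, 64, 64])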
class SAM(nn.Module):
    def __init__(self, inp_size=None, encoder_mode=None, loss=None):
        super().__init__()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.embed_dim = encoder_mode['embed_dim']
        self.image_encoder = ImageEncoderViT(
            img_size=inp_size,
            patch_size=encoder_mode['patch_size'],
            in_chans=3,
            embed_dim=encoder_mode['embed_dim'],
            depth=encoder_mode['depth'],
            num_heads=encoder_mode['num_heads'],
            mlp_ratio=encoder_mode['mlp_ratio'],
            out_chans=encoder_mode['out_chans'],
            qkv_bias=encoder_mode['qkv_bias'],
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            act_layer=nn.GELU,
            use_rel_pos=encoder_mode['use_rel_pos'],
            rel_pos_zero_init=True,
            window_size=encoder_mode['window_size'],
            global_attn_indexes=encoder_mode['global_attn_indexes'],
        )
        self.prompt_embed_dim = encoder_mode['prompt_embed_dim']  # typically 256
        image_embedding_size = inp_size // encoder_mode['patch_size']  # e.g. 1024 // 16
        self.prompt_encoder = PromptEncoder(
            embed_dim=self.prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(inp_size, inp_size),
            mask_in_chans=16,
        )
        self.mask_decoder = MaskDecoder(
            num_multimask_outputs=26,  # number of output classes (earlier runs tried 3/5/12/14/15/25)
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=self.prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=self.prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        )
        # A second decoder (mask_decoder_diwu, num_multimask_outputs=12) was
        # previously instantiated here with the same transformer settings; that
        # experiment is kept disabled.
        if 'evp' in encoder_mode['name']:
            # Freeze everything except prompt-, mask_decoder-, and prompt_encoder-related parameters.
            for k, p in self.named_parameters():
                if "prompt" not in k and "mask_decoder" not in k and "prompt_encoder" not in k:
                    p.requires_grad = False

        self.loss_mode = loss
        if self.loss_mode == 'bce':
            self.criterionBCE = torch.nn.BCEWithLogitsLoss()
        elif self.loss_mode == 'bbce':
            self.criterionBCE = BBCEWithLogitLoss()
        elif self.loss_mode == 'iou':
            self.criterionBCE = torch.nn.BCEWithLogitsLoss()
            self.criterionIOU = IOU()
        elif self.loss_mode == 'cr':
            # Class index 25 (background) does not contribute to the loss.
            self.criterionCR = torch.nn.CrossEntropyLoss(ignore_index=25, reduction='mean')
            self.criterionIOU = IOU()

        self.pe_layer = PositionEmbeddingRandom(encoder_mode['prompt_embed_dim'] // 2)
        self.inp_size = inp_size
        self.image_embedding_size = inp_size // encoder_mode['patch_size']  # e.g. 1024 // 16
        self.no_mask_embed = nn.Embedding(1, encoder_mode['prompt_embed_dim'])
    def set_input(self, input, gt_mask):
        self.input = input.to(self.device)
        self.gt_mask = gt_mask.to(self.device)

    def get_dense_pe(self) -> torch.Tensor:
        """
        Returns the positional encoding used to encode point prompts,
        applied to a dense set of points the shape of the image encoding.

        Returns:
          torch.Tensor: Positional encoding with shape
            1x(embed_dim)x(embedding_h)x(embedding_w)
        """
        return self.pe_layer(self.image_embedding_size).unsqueeze(0)
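    # Shape sketch (illustrative): with prompt_embed_dim=256, inp_size=1024 and
    # patch_size=16, get_dense_pe() returns a tensor of shape [1, 256, 64, 64].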
    def forward(self):
        bs = 1  # prompt batch size; no sparse prompts are used

        # Embed prompts: an empty sparse-prompt tensor plus the learned
        # "no mask" embedding broadcast over the dense grid.
        sparse_embeddings = torch.empty((bs, 0, self.prompt_embed_dim), device=self.input.device)
        dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            bs, -1, self.image_embedding_size, self.image_embedding_size
        )

        # Extract the image embedding (the encoder's last-layer output).
        self.features = self.image_encoder(self.input)

        # Predict masks
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=self.features,
            image_pe=self.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
        )  # B x (C+1) x H x W
        # (A second decoder, mask_decoder_diwu, was previously run here and its
        # outputs concatenated along the channel dim; that path is disabled.)

        # Upscale the masks to the original image resolution
        masks = self.postprocess_masks(low_res_masks, self.inp_size, self.inp_size)
        self.pred_mask = masks
    def infer(self, input):
        bs = 1  # prompt batch size; no sparse prompts are used

        # Embed prompts
        sparse_embeddings = torch.empty((bs, 0, self.prompt_embed_dim), device=input.device)
        dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            bs, -1, self.image_embedding_size, self.image_embedding_size
        )

        self.features = self.image_encoder(input)

        # Predict masks
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=self.features,
            image_pe=self.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
        )  # low-res masks, B x C x 256 x 256

        # Upsample the low-res masks to the original image size (e.g. B x C x 1024 x 1024).
        masks = self.postprocess_masks(low_res_masks, self.inp_size, self.inp_size)
        # masks = masks.sigmoid()
        return masks
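    # Usage sketch (hypothetical config; the real values come from the config file):
    #   model = SAM(inp_size=1024, encoder_mode=cfg['encoder_mode'], loss='cr')
    #   logits = model.infer(img)        # B x num_classes x 1024 x 1024 logits
    #   pred = logits.argmax(dim=1)      # per-pixel class indices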
    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: int,
        original_size: int,
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (int): The size of the (square) image input to the
            model. Used to remove padding.
          original_size (int): The original (square) size of the image
            before resizing for input to the model.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., :input_size, :input_size]
        masks = F.interpolate(masks, (original_size, original_size), mode="bilinear", align_corners=False)
        return masks
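    # Shape sketch (illustrative): a 256x256 decoder output is resized to the
    # encoder's img_size, cropped to input_size (a no-op when the two match),
    # then resized to original_size.
    #   out = self.postprocess_masks(torch.randn(1, 26, 256, 256), 1024, 1024)
    #   # out.shape == torch.Size([1, 26, 1024, 1024])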
    def backward_G(self):
        """Compute the cross-entropy segmentation loss and backpropagate."""
        # pred_mask: B x num_classes x H x W logits; gt_mask: B x 1 x H x W class indices.
        self.loss_G = self.criterionCR(self.pred_mask, self.gt_mask.squeeze(1).long())
        self.loss_G.backward()

    def _backward_(self, pred_mask, gt_mask):
        self.loss_G = self.criterionCR(pred_mask, gt_mask.squeeze(1).long())
        self.loss_G.backward()
    def optimize_parameters(self):
        self.forward()
        self.optimizer.zero_grad()  # set G's gradients to zero
        self.backward_G()           # calculate gradients for G
        self.optimizer.step()       # update G's weights

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad to the encoder's square input size
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
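    # Note (assumption): preprocess expects pixel_mean / pixel_std attributes
    # that are not defined in this class; a sketch of how they could be
    # registered, using SAM's published normalization constants:
    #   self.register_buffer("pixel_mean", torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1))
    #   self.register_buffer("pixel_std", torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1))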
    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad=False for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
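    # Usage sketch (illustrative): freeze the image encoder while keeping the
    # decoder trainable.
    #   self.set_requires_grad(self.image_encoder, False)
    #   self.set_requires_grad([self.mask_decoder, self.prompt_encoder], True)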