from typing import List, Tuple
import os
import math

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor, nn

from mmcv.cnn import build_activation_layer, build_norm_layer
from mmengine.dist import get_dist_info
from mmengine.model.weight_init import trunc_normal_

from mmdet.registry import MODELS, TASK_UTILS
from mmdet.models.dense_heads import AnchorFreeHead, Mask2FormerHead
from mmdet.structures import SampleList


@MODELS.register_module()
class YOSOHead(Mask2FormerHead):

    def __init__(self,
                 num_cls_fcs=1,
                 num_mask_fcs=1,
                 sphere_cls=False,
                 ov_classifier_name=None,
                 use_kernel_updator=False,
                 num_stages=3,
                 feat_channels=256,
                 out_channels=256,
                 num_things_classes=80,
                 num_stuff_classes=53,
                 num_classes=133,
                 num_queries=100,
                 temperature=0.1,
                 loss_cls=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=2.0,
                     reduction='mean',
                     class_weight=[1.0] * 133 + [0.1]),
                 loss_mask=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='mean',
                     loss_weight=5.0),
                 loss_dice=dict(
                     type='DiceLoss',
                     use_sigmoid=True,
                     activate=True,
                     reduction='mean',
                     naive_dice=True,
                     eps=1.0,
                     loss_weight=5.0),
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None):
        # Bypass the Mask2FormerHead/AnchorFreeHead initialisers so that only
        # the modules defined below are created.
        super(AnchorFreeHead, self).__init__(init_cfg=init_cfg)
        self.num_stages = num_stages
        self.feat_channels = feat_channels
        self.out_channels = out_channels
        self.num_things_classes = num_things_classes
        self.num_stuff_classes = num_stuff_classes
        self.num_classes = num_classes
        self.num_queries = num_queries
        self.temperature = temperature

        self.test_cfg = test_cfg
        self.train_cfg = train_cfg
        if train_cfg:
            self.assigner = TASK_UTILS.build(self.train_cfg['assigner'])
            self.sampler = TASK_UTILS.build(
                self.train_cfg['sampler'], default_args=dict(context=self))
            self.num_points = self.train_cfg.get('num_points', 12544)
            self.oversample_ratio = self.train_cfg.get('oversample_ratio', 3.0)
            self.importance_sample_ratio = self.train_cfg.get(
                'importance_sample_ratio', 0.75)

        self.class_weight = loss_cls['class_weight']
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_mask = MODELS.build(loss_mask)
        self.loss_dice = MODELS.build(loss_dice)

        self.kernels = nn.Embedding(self.num_queries, self.feat_channels)
        self.mask_heads = nn.ModuleList()
        for _ in range(self.num_stages):
            self.mask_heads.append(
                CrossAttenHead(
                    self.num_classes,
                    self.feat_channels,
                    self.num_queries,
                    use_kernel_updator=use_kernel_updator,
                    sphere_cls=sphere_cls,
                    ov_classifier_name=ov_classifier_name,
                    num_cls_fcs=num_cls_fcs,
                    num_mask_fcs=num_mask_fcs))

    def init_weights(self) -> None:
        super(AnchorFreeHead, self).init_weights()

    def forward(self, x: Tensor,
                batch_data_samples: SampleList
                ) -> Tuple[List[Tensor], List[Tensor]]:
        all_cls_scores = []
        all_masks_preds = []
        proposal_kernels = self.kernels.weight
        object_kernels = proposal_kernels[None].repeat(x.shape[0], 1, 1)
        # Initial mask predictions from the learned query kernels.
        mask_preds = torch.einsum('bnc,bchw->bnhw', object_kernels, x)

        for stage in range(self.num_stages):
            mask_head = self.mask_heads[stage]
            cls_scores, mask_preds, iou_pred, object_kernels = mask_head(
                x, object_kernels, mask_preds)
            cls_scores = cls_scores / self.temperature
            all_cls_scores.append(cls_scores)
            all_masks_preds.append(mask_preds)
        return all_cls_scores, all_masks_preds

    def predict(self, x: Tensor,
                batch_data_samples: SampleList) -> Tuple[Tensor, Tensor]:
        batch_img_metas = [
            data_sample.metainfo for data_sample in batch_data_samples
        ]
        all_cls_scores, all_mask_preds = self(x, batch_data_samples)
        mask_cls_results = all_cls_scores[-1]
        mask_pred_results = all_mask_preds[-1]
        # upsample masks
        img_shape = batch_img_metas[0]['batch_input_shape']
        mask_pred_results = F.interpolate(
            mask_pred_results,
            size=(img_shape[0], img_shape[1]),
            mode='bilinear',
            align_corners=False)
        return mask_cls_results, mask_pred_results


class FFN(nn.Module):

    def __init__(self,
                 embed_dims=256,
                 feedforward_channels=1024,
                 num_fcs=2,
                 add_identity=True):
        super(FFN, self).__init__()
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs

        layers = []
        in_channels = embed_dims
        for _ in range(num_fcs - 1):
            layers.append(
                nn.Sequential(
                    nn.Linear(in_channels, feedforward_channels),
                    nn.ReLU(True),
                    nn.Dropout(0.0)))
            in_channels = feedforward_channels
        layers.append(nn.Linear(feedforward_channels, embed_dims))
        layers.append(nn.Dropout(0.0))
        self.layers = nn.Sequential(*layers)
        self.add_identity = add_identity
        self.dropout_layer = nn.Dropout(0.0)

    def forward(self, x, identity=None):
        out = self.layers(x)
        if not self.add_identity:
            return self.dropout_layer(out)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)


class DySepConvAtten(nn.Module):

    def __init__(self, hidden_dim, num_proposals, conv_kernel_size_1d):
        super(DySepConvAtten, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_proposals = num_proposals
        self.kernel_size = conv_kernel_size_1d

        self.weight_linear = nn.Linear(self.hidden_dim,
                                       self.num_proposals + self.kernel_size)
        self.norm = nn.LayerNorm(self.hidden_dim)

    def forward(self, query, value):
        assert query.shape == value.shape
        B, N, C = query.shape
        dy_conv_weight = self.weight_linear(query)
        dy_depth_conv_weight = dy_conv_weight[:, :, :self.kernel_size].view(
            B, self.num_proposals, 1, self.kernel_size)
        dy_point_conv_weight = dy_conv_weight[:, :, self.kernel_size:].view(
            B, self.num_proposals, self.num_proposals, 1)

        res = []
        value = value.unsqueeze(1)
        for i in range(B):
            out = F.relu(
                F.conv1d(
                    input=value[i],
                    weight=dy_depth_conv_weight[i],
                    groups=N,
                    padding='same'))
            out = F.conv1d(
                input=out, weight=dy_point_conv_weight[i], padding='same')
            res.append(out)
        point_out = torch.cat(res, dim=0)
        point_out = self.norm(point_out)
        return point_out


class KernelUpdator(nn.Module):

    def __init__(self,
                 in_channels=256,
                 feat_channels=64,
                 out_channels=None,
                 input_feat_shape=3,
                 gate_sigmoid=True,
                 gate_norm_act=False,
                 activate_out=False,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN')):
        super(KernelUpdator, self).__init__()
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.out_channels_raw = out_channels
        self.gate_sigmoid = gate_sigmoid
        self.gate_norm_act = gate_norm_act
        self.activate_out = activate_out
        if isinstance(input_feat_shape, int):
            input_feat_shape = [input_feat_shape] * 2
        self.input_feat_shape = input_feat_shape
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.out_channels = out_channels if out_channels else in_channels

        self.num_params_in = self.feat_channels
        self.num_params_out = self.feat_channels
        self.dynamic_layer = nn.Linear(
            self.in_channels, self.num_params_in + self.num_params_out)
        self.input_layer = nn.Linear(
            self.in_channels, self.num_params_in + self.num_params_out, 1)
        self.input_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
        self.update_gate = nn.Linear(self.in_channels, self.feat_channels, 1)
        if self.gate_norm_act:
            self.gate_norm = build_norm_layer(norm_cfg, self.feat_channels)[1]

        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.input_norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.input_norm_out = build_norm_layer(norm_cfg, self.feat_channels)[1]

        self.activation = build_activation_layer(act_cfg)

        self.fc_layer = nn.Linear(self.feat_channels, self.out_channels, 1)
        self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]

    def forward(self, update_feature, input_feature):
        """
        Args:
            update_feature (torch.Tensor): [bs, num_proposals, in_channels]
            input_feature (torch.Tensor): [bs, num_proposals, in_channels]
        """
        bs, num_proposals, _ = update_feature.shape
        parameters = self.dynamic_layer(update_feature)
        param_in = parameters[..., :self.num_params_in]
        param_out = parameters[..., -self.num_params_out:]

        input_feats = self.input_layer(input_feature)
        input_in = input_feats[..., :self.num_params_in]
        input_out = input_feats[..., -self.num_params_out:]

        gate_feats = input_in * param_in
        if self.gate_norm_act:
            gate_feats = self.activation(self.gate_norm(gate_feats))

        input_gate = self.input_norm_in(self.input_gate(gate_feats))
        update_gate = self.norm_in(self.update_gate(gate_feats))
        if self.gate_sigmoid:
            input_gate = input_gate.sigmoid()
            update_gate = update_gate.sigmoid()
        param_out = self.norm_out(param_out)
        input_out = self.input_norm_out(input_out)

        if self.activate_out:
            param_out = self.activation(param_out)
            input_out = self.activation(input_out)

        # param_out has shape (batch_size, feat_channels, out_channels)
        features = update_gate * param_out + input_gate * input_out

        features = self.fc_layer(features)
        features = self.fc_norm(features)
        features = self.activation(features)

        return features


class CrossAttenHead(nn.Module):

    def __init__(self,
                 num_classes,
                 in_channels,
                 num_proposals,
                 frozen_head=False,
                 frozen_pred=False,
                 with_iou_pred=False,
                 sphere_cls=False,
                 ov_classifier_name=None,
                 num_cls_fcs=1,
                 num_mask_fcs=1,
                 conv_kernel_size_1d=3,
                 conv_kernel_size_2d=1,
                 use_kernel_updator=False):
        super(CrossAttenHead, self).__init__()
        self.sphere_cls = sphere_cls
        self.with_iou_pred = with_iou_pred
        self.frozen_head = frozen_head
        self.frozen_pred = frozen_pred
        self.num_cls_fcs = num_cls_fcs
        self.num_mask_fcs = num_mask_fcs
        self.num_classes = num_classes
        self.conv_kernel_size_2d = conv_kernel_size_2d

        self.hidden_dim = in_channels
        self.feat_channels = in_channels
        self.num_proposals = num_proposals
        self.hard_mask_thr = 0.5
        self.use_kernel_updator = use_kernel_updator
        # assert use_kernel_updator
        if use_kernel_updator:
            self.kernel_update = KernelUpdator(
                in_channels=256,
                feat_channels=256,
                out_channels=256,
                input_feat_shape=3,
                act_cfg=dict(type='ReLU', inplace=True),
                norm_cfg=dict(type='LN'))
        else:
            self.f_atten = DySepConvAtten(self.feat_channels,
                                          self.num_proposals,
                                          conv_kernel_size_1d)
            self.f_dropout = nn.Dropout(0.0)
            self.f_atten_norm = nn.LayerNorm(
                self.hidden_dim * self.conv_kernel_size_2d**2)
            self.k_atten = DySepConvAtten(self.feat_channels,
                                          self.num_proposals,
                                          conv_kernel_size_1d)
            self.k_dropout = nn.Dropout(0.0)
            self.k_atten_norm = nn.LayerNorm(
                self.hidden_dim * self.conv_kernel_size_2d**2)

        self.s_atten = nn.MultiheadAttention(
            embed_dim=self.hidden_dim * self.conv_kernel_size_2d**2,
            num_heads=8,
            dropout=0.0)
        self.s_dropout = nn.Dropout(0.0)
        self.s_atten_norm = nn.LayerNorm(
            self.hidden_dim * self.conv_kernel_size_2d**2)

        self.ffn = FFN(self.hidden_dim, feedforward_channels=2048, num_fcs=2)
        self.ffn_norm = nn.LayerNorm(self.hidden_dim)

        self.cls_fcs = nn.ModuleList()
        for _ in range(self.num_cls_fcs):
            self.cls_fcs.append(
                nn.Linear(self.hidden_dim, self.hidden_dim, bias=False))
            self.cls_fcs.append(nn.LayerNorm(self.hidden_dim))
            self.cls_fcs.append(nn.ReLU(True))

        if sphere_cls:
            rank, world_size = get_dist_info()
            if ov_classifier_name is None:
                _dim = 1024  # temporarily hard-coded embedding dim
                cls_embed = torch.empty(self.num_classes, _dim)
                torch.nn.init.orthogonal_(cls_embed)
                cls_embed = cls_embed[:, None]
            else:
                ov_path = os.path.join(
                    os.path.expanduser('~/.cache/embd'),
                    f'{ov_classifier_name}.pth')
                cls_embed = torch.load(ov_path)
                cls_embed_norm = cls_embed.norm(p=2, dim=-1)
                assert torch.allclose(cls_embed_norm,
                                      torch.ones_like(cls_embed_norm))

            # Background class embedding, shared across ranks.
            _dim = cls_embed.size(2)
            _prototypes = cls_embed.size(1)
            if rank == 0:
                back_token = torch.zeros(
                    1, _dim, dtype=torch.float32, device='cuda')
            else:
                back_token = torch.empty(
                    1, _dim, dtype=torch.float32, device='cuda')
            if world_size > 1:
                dist.broadcast(back_token, src=0)
            back_token = back_token.to(device='cpu')
            cls_embed = torch.cat(
                [cls_embed, back_token.repeat(_prototypes, 1)[None]], dim=0)
            self.register_buffer(
                'fc_cls',
                cls_embed.permute(2, 0, 1).contiguous(),
                persistent=False)

            # Projection from query features into the class embedding space.
            cls_embed_dim = self.fc_cls.size(0)
            self.cls_proj = nn.Sequential(
                nn.Linear(self.hidden_dim, self.hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.hidden_dim, self.hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.hidden_dim, cls_embed_dim))

            logit_scale = torch.tensor(4.6052, dtype=torch.float32)
            self.register_buffer('logit_scale', logit_scale, persistent=False)
        else:
            self.fc_cls = nn.Linear(self.hidden_dim, self.num_classes + 1)

        self.mask_fcs = nn.ModuleList()
        for _ in range(self.num_mask_fcs):
            self.mask_fcs.append(
                nn.Linear(self.hidden_dim, self.hidden_dim, bias=False))
            self.mask_fcs.append(nn.LayerNorm(self.hidden_dim))
            self.mask_fcs.append(nn.ReLU(True))
        self.fc_mask = nn.Linear(self.hidden_dim, self.hidden_dim)

        if self.with_iou_pred:
            self.iou_embed = nn.Sequential(
                nn.Linear(self.hidden_dim, self.hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.hidden_dim, self.hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(self.hidden_dim, 1),
            )
        prior_prob = 0.01
        self.bias_value = -math.log((1 - prior_prob) / prior_prob)
        self.apply(self._init_weights)
        if not sphere_cls:
            nn.init.constant_(self.fc_cls.bias, self.bias_value)

        if self.frozen_head:
            self._frozen_head()
        if self.frozen_pred:
            self._frozen_pred()

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _frozen_head(self):
        for n, p in self.kernel_update.named_parameters():
            p.requires_grad = False
        for n, p in self.s_atten.named_parameters():
            p.requires_grad = False
        for n, p in self.s_dropout.named_parameters():
            p.requires_grad = False
        for n, p in self.s_atten_norm.named_parameters():
            p.requires_grad = False
        for n, p in self.ffn.named_parameters():
            p.requires_grad = False
        for n, p in self.ffn_norm.named_parameters():
            p.requires_grad = False

    def _frozen_pred(self):
        # Freeze cls_fcs, fc_cls, mask_fcs and fc_mask.
        for n, p in self.cls_fcs.named_parameters():
            p.requires_grad = False
        for n, p in self.fc_cls.named_parameters():
            p.requires_grad = False
        for n, p in self.mask_fcs.named_parameters():
            p.requires_grad = False
        for n, p in self.fc_mask.named_parameters():
            p.requires_grad = False

    def train(self, mode=True):
        super().train(mode)
        if self.frozen_head:
            self.kernel_update.eval()
            self.s_atten.eval()
            self.s_dropout.eval()
            self.s_atten_norm.eval()
            self.ffn.eval()
            self.ffn_norm.eval()
        if self.frozen_pred:
            self.cls_fcs.eval()
            self.fc_cls.eval()
            self.mask_fcs.eval()
            self.fc_mask.eval()

    def forward(self, features, proposal_kernels, mask_preds,
                self_attn_mask=None):
        B, C, H, W = features.shape

        soft_sigmoid_masks = mask_preds.sigmoid()
        nonzero_inds = soft_sigmoid_masks > self.hard_mask_thr
        hard_sigmoid_masks = nonzero_inds.float()

        # Mask-pooled features, [B, N, C]
        f = torch.einsum('bnhw,bchw->bnc', hard_sigmoid_masks, features)
        # [B, N, C, K, K] -> [B, N, C * K * K]
        num_proposals = proposal_kernels.shape[1]
        k = proposal_kernels.view(B, num_proposals, -1)

        if self.use_kernel_updator:
            k = self.kernel_update(f, k)
        else:
            f_tmp = self.f_atten(k, f)
            f = f + self.f_dropout(f_tmp)
            f = self.f_atten_norm(f)

            f_tmp = self.k_atten(k, f)
            f = f + self.k_dropout(f_tmp)
            k = self.k_atten_norm(f)

        # [N, B, C]
        k = k.permute(1, 0, 2)

        k_tmp = self.s_atten(
            query=k, key=k, value=k, attn_mask=self_attn_mask)[0]
        k = k + self.s_dropout(k_tmp)
        k = self.s_atten_norm(k.permute(1, 0, 2))

        obj_feat = self.ffn_norm(self.ffn(k))

        cls_feat = obj_feat
        mask_feat = obj_feat

        for cls_layer in self.cls_fcs:
            cls_feat = cls_layer(cls_feat)

        if self.sphere_cls:
            # FIXME: too many cls linears (cls_fcs + cls_proj).
            cls_embd = self.cls_proj(cls_feat)
            cls_score = torch.einsum('bnc,ckp->bnkp',
                                     F.normalize(cls_embd, dim=-1),
                                     self.fc_cls)
            cls_score = cls_score.max(-1).values
            cls_score = self.logit_scale.exp() * cls_score
        else:
            cls_score = self.fc_cls(cls_feat)

        for reg_layer in self.mask_fcs:
            mask_feat = reg_layer(mask_feat)
        # [B, N, K * K, C] -> [B, N, C]
        mask_kernels = self.fc_mask(mask_feat)

        new_mask_preds = torch.einsum('bqc,bchw->bqhw', mask_kernels, features)
        iou_pred = self.iou_embed(mask_feat) if self.with_iou_pred else None
        return cls_score, new_mask_preds, iou_pred, obj_feat
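

# Minimal smoke-test sketch: exercises a single CrossAttenHead stage with
# random tensors so the shapes noted in the comments above can be checked.
# The sizes used here (batch 2, 256-channel features at 64x64, 100 proposals,
# 133 classes) are illustrative assumptions only; in practice the head is
# instantiated from an mmdet config via MODELS.build and driven by
# YOSOHead.forward / YOSOHead.predict.
if __name__ == '__main__':
    head = CrossAttenHead(
        num_classes=133,
        in_channels=256,
        num_proposals=100,
        use_kernel_updator=True)
    feats = torch.randn(2, 256, 64, 64)   # fused backbone/neck feature map
    kernels = torch.randn(2, 100, 256)    # per-query kernels
    masks = torch.randn(2, 100, 64, 64)   # initial mask logits
    cls_score, mask_preds, iou_pred, obj_feat = head(feats, kernels, masks)
    # Expected shapes: cls_score [2, 100, 134] (133 classes + background),
    # mask_preds [2, 100, 64, 64], obj_feat [2, 100, 256]; iou_pred is None.
    print(cls_score.shape, mask_preds.shape, obj_feat.shape)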