import math
import torch
import torch.nn.functional as F
from torch import nn
from collections import OrderedDict, defaultdict

from .inference import make_atss_postprocessor
from .loss import make_atss_loss_evaluator
from .anchor_generator import make_anchor_generator_complex

from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
from maskrcnn_benchmark.layers import Scale, DYReLU, SELayer, ModulatedDeformConv
from maskrcnn_benchmark.layers import NaiveSyncBatchNorm2d, FrozenBatchNorm2d
from maskrcnn_benchmark.modeling.backbone.fbnet import *
from maskrcnn_benchmark.engine.inference import create_positive_map_label_to_token_from_positive_map
from ..utils import cat, concat_box_prediction_layers, permute_and_flatten

from maskrcnn_benchmark.utils.fuse_helper import FeatureResizer, func_attention, _make_mlp, _make_conv, _make_coord, \
    BiAttentionBlock, AttentionT2I, BiAttentionBlockForCheckpoint, BertLMPredictionHead
from transformers.models.bert.modeling_bert import BertConfig, BertAttention, BertIntermediate, BertOutput, \
    BertPreTrainedModel
from transformers.modeling_utils import apply_chunking_to_forward
import torch.utils.checkpoint as checkpoint
import pdb

from maskrcnn_benchmark.modeling.language_backbone.clip_model import QuickGELU, LayerNorm, DropPath
from timm.models.layers import DropPath, trunc_normal_


class h_sigmoid(nn.Module):
    """Hard sigmoid: ``relu6(x + 3) * h_max / 6``."""

    def __init__(self, inplace=True, h_max=1):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)
        self.h_max = h_max

    def forward(self, x):
        return self.relu(x + 3) * self.h_max / 6


class BoxCoder(object):
    """Encode/decode boxes as (dx, dy, dw, dh) deltas relative to anchors.

    Uses fixed weights (10, 10, 5, 5) and the legacy "+1" width/height convention.
    """

    def __init__(self, cfg):
        self.cfg = cfg

    def encode(self, gt_boxes, anchors):
        """Return an (N, 4) tensor of regression targets for ``gt_boxes`` w.r.t. ``anchors``.

        Both inputs are indexed as ``[:, 0..3]`` = (x1, y1, x2, y2).
        """
        TO_REMOVE = 1  # TODO remove
        ex_widths = anchors[:, 2] - anchors[:, 0] + TO_REMOVE
        ex_heights = anchors[:, 3] - anchors[:, 1] + TO_REMOVE
        ex_ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2
        ex_ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2

        gt_widths = gt_boxes[:, 2] - gt_boxes[:, 0] + TO_REMOVE
        gt_heights = gt_boxes[:, 3] - gt_boxes[:, 1] + TO_REMOVE
        gt_ctr_x = (gt_boxes[:, 2] + gt_boxes[:, 0]) / 2
        gt_ctr_y = (gt_boxes[:, 3] + gt_boxes[:, 1]) / 2

        wx, wy, ww, wh = (10., 10., 5., 5.)
        targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths
        targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights
        targets_dw = ww * torch.log(gt_widths / ex_widths)
        targets_dh = wh * torch.log(gt_heights / ex_heights)
        targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1)

        return targets

    def decode(self, preds, anchors):
        """Apply deltas ``preds`` (columns interleaved as dx, dy, dw, dh) to ``anchors``.

        Returns a tensor shaped like ``preds`` holding (x1, y1, x2, y2) per group of 4 columns.
        """
        anchors = anchors.to(preds.dtype)

        TO_REMOVE = 1  # TODO remove
        widths = anchors[:, 2] - anchors[:, 0] + TO_REMOVE
        heights = anchors[:, 3] - anchors[:, 1] + TO_REMOVE
        ctr_x = (anchors[:, 2] + anchors[:, 0]) / 2
        ctr_y = (anchors[:, 3] + anchors[:, 1]) / 2

        wx, wy, ww, wh = (10., 10., 5., 5.)
        dx = preds[:, 0::4] / wx
        dy = preds[:, 1::4] / wy
        dw = preds[:, 2::4] / ww
        dh = preds[:, 3::4] / wh

        # Prevent sending too large values into torch.exp()
        dw = torch.clamp(dw, max=math.log(1000. / 16))
        dh = torch.clamp(dh, max=math.log(1000. / 16))

        pred_ctr_x = dx * widths[:, None] + ctr_x[:, None]
        pred_ctr_y = dy * heights[:, None] + ctr_y[:, None]
        pred_w = torch.exp(dw) * widths[:, None]
        pred_h = torch.exp(dh) * heights[:, None]

        pred_boxes = torch.zeros_like(preds)
        pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * (pred_w - 1)
        pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * (pred_h - 1)
        pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * (pred_w - 1)
        pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * (pred_h - 1)

        return pred_boxes


class Conv3x3Norm(torch.nn.Module):
    """3x3 convolution (optionally modulated-deformable) followed by an optional norm layer.

    ``bn_type`` may be None, one of "bn", "sbn", "nsbn", "af", or a 2-item
    ``("gn", num_groups)`` list/tuple for GroupNorm.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 groups=1,
                 deformable=False,
                 bn_type=None):
        super(Conv3x3Norm, self).__init__()

        if deformable:
            self.conv = ModulatedDeformConv(in_channels, out_channels, kernel_size=3, stride=stride, padding=1,
                                            groups=groups)
        else:
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, groups=groups)

        gn_group = None
        if isinstance(bn_type, (list, tuple)):
            if len(bn_type) != 2 or bn_type[0] != "gn":
                raise ValueError("list/tuple bn_type must be ('gn', num_groups), got {!r}".format(bn_type))
            bn_type, gn_group = bn_type[0], bn_type[1]

        if bn_type is None:
            self.bn = None
        elif bn_type == "bn":
            self.bn = nn.BatchNorm2d(out_channels)
        elif bn_type == "sbn":
            self.bn = nn.SyncBatchNorm(out_channels)
        elif bn_type == "nsbn":
            self.bn = NaiveSyncBatchNorm2d(out_channels)
        elif bn_type == "gn":
            # A bare "gn" string carries no group count; previously this crashed with UnboundLocalError.
            if gn_group is None:
                raise ValueError("bn_type 'gn' requires a group count: pass ('gn', num_groups)")
            self.bn = nn.GroupNorm(num_groups=gn_group, num_channels=out_channels)
        elif bn_type == "af":
            self.bn = FrozenBatchNorm2d(out_channels)
        else:
            raise ValueError("Unsupported bn_type: {!r}".format(bn_type))

    def forward(self, input, **kwargs):
        # kwargs (e.g. offset/mask) are forwarded to the conv; used by the deformable variant.
        x = self.conv(input, **kwargs)
        if self.bn is not None:
            x = self.bn(x)
        return x


class DyConv(torch.nn.Module):
    """Vision path of one DyHead stage: fuses each level with its neighbours.

    For every level, the level itself, the previous level (via a stride-2 conv) and the
    next level (conv then bilinear upsample to this level's size) are averaged, optionally
    weighted by a per-branch spatial-pyramid attention (``use_dyfuse``), then activated.
    """

    def __init__(self,
                 in_channels=256,
                 out_channels=256,
                 conv_func=nn.Conv2d,
                 use_dyfuse=True,
                 use_dyrelu=False,
                 use_deform=False
                 ):
        super(DyConv, self).__init__()

        # [0]: applied to level + 1, [1]: applied to the level itself, [2]: stride-2 conv applied to level - 1
        self.DyConv = nn.ModuleList()
        self.DyConv.append(conv_func(in_channels, out_channels, 1))
        self.DyConv.append(conv_func(in_channels, out_channels, 1))
        self.DyConv.append(conv_func(in_channels, out_channels, 2))

        if use_dyfuse:
            self.AttnConv = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(in_channels, 1, kernel_size=1),
                nn.ReLU(inplace=True))
            self.h_sigmoid = h_sigmoid()
        else:
            self.AttnConv = None

        if use_dyrelu:
            self.relu = DYReLU(in_channels, out_channels)
        else:
            self.relu = nn.ReLU()

        if use_deform:
            # 27 channels = 18 offset channels + 9 mask channels (split in forward)
            self.offset = nn.Conv2d(in_channels, 27, kernel_size=3, stride=1, padding=1)
        else:
            self.offset = None

        self.init_weights()

    def init_weights(self):
        """Normal(0, 0.01) weights and zero biases for every Conv2d in the branch/attention convs."""
        for m in self.DyConv.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, 0, 0.01)
                if m.bias is not None:
                    m.bias.data.zero_()
        if self.AttnConv is not None:
            for m in self.AttnConv.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.normal_(m.weight.data, 0, 0.01)
                    if m.bias is not None:
                        m.bias.data.zero_()

    def forward(self, inputs):
        """Args: inputs: dict with "visual" (list of per-level feature maps) and "lang" (passed through).

        Returns a dict with the fused per-level "visual" features and the unchanged "lang" entry.
        """
        visual_feats = inputs["visual"]
        language_dict_features = inputs["lang"]

        next_x = []
        for level, feature in enumerate(visual_feats):

            conv_args = dict()
            if self.offset is not None:
                offset_mask = self.offset(feature)
                offset = offset_mask[:, :18, :, :]
                mask = offset_mask[:, 18:, :, :].sigmoid()
                conv_args = dict(offset=offset, mask=mask)
                # NOTE(review): offset/mask are predicted at this level's resolution but are also passed
                # to the level+1 branch below; confirm the deformable conv accepts that.

            temp_fea = [self.DyConv[1](feature, **conv_args)]

            if level > 0:
                temp_fea.append(self.DyConv[2](visual_feats[level - 1], **conv_args))
            if level < len(visual_feats) - 1:
                # F.upsample_bilinear is deprecated; this interpolate call is its documented equivalent.
                temp_fea.append(F.interpolate(self.DyConv[0](visual_feats[level + 1], **conv_args),
                                              size=[feature.size(2), feature.size(3)],
                                              mode="bilinear", align_corners=True))

            if self.AttnConv is not None:
                res_fea = torch.stack(temp_fea)
                spa_pyr_attn = self.h_sigmoid(torch.stack([self.AttnConv(fea) for fea in temp_fea]))
                mean_fea = torch.mean(res_fea * spa_pyr_attn, dim=0, keepdim=False)
            else:
                mean_fea = torch.mean(torch.stack(temp_fea), dim=0, keepdim=False)

            next_x.append(mean_fea)

        next_x = [self.relu(item) for item in next_x]

        features_dict = {"visual": next_x,
                         "lang": language_dict_features}

        return features_dict


class BertEncoderLayer(BertPreTrainedModel):
    """One BERT encoder layer applied to the language features of a fusion-stage dict."""

    def __init__(self, config, clamp_min_for_underflow=False, clamp_max_for_overflow=False):
        super().__init__(config)
        self.config = config

        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1

        # Project-local variants that accept the clamp flags (shadow the transformers imports here).
        from maskrcnn_benchmark.modeling.rpn.modeling_bert import BertAttention, BertIntermediate, BertOutput

        self.attention = BertAttention(config, clamp_min_for_underflow, clamp_max_for_overflow)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, inputs):
        """Update ``inputs["lang"]["hidden"]`` in place and return ``{"visual", "lang"}``."""
        language_dict_features = inputs["lang"]
        hidden_states = language_dict_features["hidden"]
        attention_mask = language_dict_features["masks"]

        device = hidden_states.device
        input_shape = hidden_states.size()[:-1]
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)

        self_attention_outputs = self.attention(
            hidden_states,
            extended_attention_mask,
            None,
            output_attentions=False,
            past_key_value=None,
        )
        attention_output = self_attention_outputs[0]
        hidden_states = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )

        language_dict_features["hidden"] = hidden_states

        features_dict = {"visual": inputs["visual"],
                         "lang": language_dict_features
                         }

        return features_dict

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class CLIPTransformerLayer(nn.Module):
    """Pre-norm CLIP-style transformer block applied to ``inputs["lang"]["hidden"]``."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        d_model = self.config.MODEL.CLIP.WIDTH
        n_head = self.config.MODEL.CLIP.HEADS
        drop_path = self.config.MODEL.CLIP.DROP_PATH
        self.context_length = self.config.MODEL.CLIP.CONTEXT_LENGTH
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)

    def attention(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) \
            if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask, key_padding_mask=key_padding_mask)[0]

    def forward(self, inputs):
        language_dict_features = inputs["lang"]
        x = language_dict_features["hidden"]
        mask = language_dict_features["masks"]
        # nn.MultiheadAttention expects True at padded positions, i.e. where masks == 0.
        key_padding_mask = (1.0 - mask).to(torch.bool)

        # nn.MultiheadAttention (batch_first=False) wants the sequence dim first.
        x = x.permute(1, 0, 2)
        x = x + self.drop_path(self.attention(self.ln_1(x), key_padding_mask=key_padding_mask))
        x = x + self.drop_path(self.mlp(self.ln_2(x)))
        x = x.permute(1, 0, 2)

        language_dict_features["hidden"] = x
        features_dict = {"visual": inputs["visual"],
                         "lang": language_dict_features
                         }
        return features_dict


class DummyLayer(nn.Module):
    """Identity placeholder used where a stage has no language self-attention."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs):
        return inputs


class VLFuse(torch.nn.Module):
    """
    Early Fusion Module: cross-modal fusion between visual levels and language features.

    The fusion variant is selected by cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE
    ("MHA-S", "MHA-B", "SCAN", "FILM"; anything else is a pass-through).
    """

    def __init__(self, cfg):
        super(VLFuse, self).__init__()
        self.init_configs(cfg)
        self.cfg = cfg

        self.use_checkpoint = False
        if hasattr(cfg.MODEL.DYHEAD, 'USE_CHECKPOINT'):
            self.use_checkpoint = cfg.MODEL.DYHEAD.USE_CHECKPOINT
        # Always defined: the non-checkpoint MHA-B path passes it too, so it must exist even
        # when the config has no USE_CHECKPOINT entry.
        self.dummy_tensor = torch.ones(1, dtype=torch.float32, requires_grad=True)

        # early fusion module
        print("EARLY FUSION ON, USING {}".format(cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE))
        if cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "MHA-S":
            # single-direction (text->image)
            # text -> image
            self.t2i_attn = AttentionT2I(q_dim=self.joint_embedding_size,
                                         k_dim=self.lang_dim,
                                         embed_dim=self.embed_dim,
                                         num_heads=self.n_head,
                                         hidden_dim=self.t2i_hidden_dim,
                                         dropout=0.1,
                                         drop_path=.0,
                                         init_values=1.0 / cfg.MODEL.DYHEAD.NUM_CONVS,
                                         mode="t2i",
                                         use_layer_scale=cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_LAYER_SCALE,
                                         clamp_min_for_underflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MIN_FOR_UNDERFLOW,
                                         clamp_max_for_overflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_MAX_FOR_OVERFLOW
                                         )

        elif cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "MHA-B":
            # bi-direction (text->image, image->text)
            self.b_attn = BiAttentionBlockForCheckpoint(v_dim=self.joint_embedding_size,
                                                        l_dim=self.lang_dim,
                                                        embed_dim=self.embed_dim,
                                                        num_heads=self.n_head,
                                                        hidden_dim=self.i2t_hidden_dim,
                                                        dropout=0.1,
                                                        drop_path=.0,
                                                        init_values=1.0 / cfg.MODEL.DYHEAD.NUM_CONVS,
                                                        cfg=cfg
                                                        )
            if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SEPARATE_BIDIRECTIONAL and \
                    self.cfg.MODEL.DYHEAD.FUSE_CONFIG.DO_LANG_PROJ_OUTSIDE_CHECKPOINT:
                # Projects the 5 concatenated per-level language outputs back to lang_dim.
                self.shrink_lang = FeatureResizer(self.lang_dim * 5,
                                                  self.lang_dim, 0.1)

        elif cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "SCAN":
            # single-direction (text->image)
            self.mapping_lang = _make_mlp(self.lang_dim,
                                          self.joint_embedding_size,
                                          self.joint_embedding_dropout)
            self.joint_fusion = nn.ModuleList([_make_conv(self.joint_inp_dim, self.joint_out_dim, 1)
                                               for _ in range(5)])

        elif cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE == "FILM":
            # single-direction (text->image)
            self.mapping_lang = _make_mlp(self.lang_dim,
                                          self.joint_embedding_size,
                                          self.joint_embedding_dropout)
            self.gamma = nn.ModuleList(nn.Linear(self.joint_embedding_size, self.joint_inp_dim) for _ in range(5))
            self.beta = nn.ModuleList(nn.Linear(self.joint_embedding_size, self.joint_inp_dim) for _ in range(5))

            self.joint_fusion = nn.ModuleList([_make_conv(self.joint_inp_dim, self.joint_out_dim, 1)
                                               for _ in range(5)])

        else:
            print("NO FUSION INVOLVED.")

    def init_configs(self, cfg):
        """Read fusion dimensions from ``cfg`` into attributes used by ``__init__``/``forward``."""
        # common params
        self.lang_model = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE
        self.joint_embedding_size = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_SIZE
        self.joint_embedding_dropout = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_DROPOUT
        self.joint_mlp_layers = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_MLP_LAYERS

        self.max_query_len = cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN
        self.n_layers = cfg.MODEL.LANGUAGE_BACKBONE.N_LAYERS
        self.coord_dim = 8
        self.joint_inp_dim = self.coord_dim + self.joint_embedding_size
        self.joint_out_dim = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_OUT_SIZE

        # mha params
        self.n_head = 8
        self.embed_dim = 2048
        self.t2i_hidden_dim = 1024  # 256 * 4
        self.i2t_hidden_dim = 3072  # 768 * 4

        if self.lang_model in ["bert-base-uncased", "roberta-base", "clip"]:
            self.lang_dim = cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM
        else:
            self.lang_dim = 1024

    def forward(self, x):
        """Fuse ``x["visual"]`` (5 per-level maps for the MHA/SCAN/FILM paths) with ``x["lang"]``.

        Returns ``{"visual": fused per-level features, "lang": language dict}``. The MHA-B path
        updates ``x["lang"]["hidden"]`` in place.
        """
        visual_features = x["visual"]
        language_dict_features = x["lang"]

        batch_size = visual_features[0].shape[0]
        device = visual_features[0].device

        fused_visual_features = None
        fused_language_dict_features = None

        fuse_type = self.cfg.MODEL.DYHEAD.FUSE_CONFIG.TYPE
        if fuse_type == "MHA-S":
            language_feature = language_dict_features['hidden']
            mask = language_dict_features['masks']
            # text -> image
            if self.use_checkpoint:
                q0, q1, q2, q3, q4 = checkpoint.checkpoint(
                    self.t2i_attn,
                    visual_features[0], visual_features[1],
                    visual_features[2], visual_features[3],
                    visual_features[4],
                    language_feature, language_feature,
                    mask,
                    self.dummy_tensor
                )
            else:
                q0, q1, q2, q3, q4 = self.t2i_attn(
                    visual_features[0], visual_features[1],
                    visual_features[2], visual_features[3],
                    visual_features[4],
                    language_feature, language_feature,
                    attention_mask=mask
                )

            fused_visual_features = [q0, q1, q2, q3, q4]
            fused_language_dict_features = language_dict_features

        elif fuse_type == "MHA-B":
            b_attn_args = (
                visual_features[0], visual_features[1],
                visual_features[2], visual_features[3],
                visual_features[4],
                language_dict_features['hidden'],
                language_dict_features['masks'],
                self.dummy_tensor,
            )
            if self.use_checkpoint:
                q0, q1, q2, q3, q4, l0, l1, l2, l3, l4 = checkpoint.checkpoint(self.b_attn, *b_attn_args)
            else:
                q0, q1, q2, q3, q4, l0, l1, l2, l3, l4 = self.b_attn(*b_attn_args)

            fused_visual_features = [q0, q1, q2, q3, q4]
            if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SEPARATE_BIDIRECTIONAL and \
                    self.cfg.MODEL.DYHEAD.FUSE_CONFIG.DO_LANG_PROJ_OUTSIDE_CHECKPOINT:
                language_features = self.shrink_lang(torch.cat([l0, l1, l2, l3, l4], dim=-1))
            else:
                language_features = l0

            language_dict_features['hidden'] = language_features
            fused_language_dict_features = language_dict_features

        elif fuse_type == "SCAN":
            # text -> image
            language_feature = self.mapping_lang(language_dict_features['aggregate'])
            visu_feat = [func_attention(feat, language_feature, smooth=1, raw_feature_norm="softmax")
                         for feat in visual_features]

            fused_visual_features = [fusion(feat) for feat, fusion in zip(visu_feat, self.joint_fusion)]
            fused_language_dict_features = language_dict_features

        elif fuse_type == "FILM":
            # text -> image
            # relative position embedding
            coord_feats = [_make_coord(batch_size, feat.shape[2], feat.shape[3]) for feat in visual_features]
            # I only use a global representation of language
            # you can also use more complex modeling using word-level representations
            # Usage: lang_feat = lang_feat['words'] shape [seq_len, dim]
            language_feature = self.mapping_lang(language_dict_features['aggregate'])

            # attention mechanism for fusion (F.tanh is deprecated; torch.tanh is identical)
            gammas = [torch.tanh(gamma_fc(language_feature)) for gamma_fc in self.gamma]
            betas = [torch.tanh(beta_fc(language_feature)) for beta_fc in self.beta]

            visu_feat = []
            for level, feat in enumerate(visual_features):
                coord_feat = coord_feats[level].to(device)
                feat = torch.cat([feat, coord_feat], dim=1)
                b = betas[level].view(batch_size, -1, 1, 1).expand_as(feat)
                g = gammas[level].view(batch_size, -1, 1, 1).expand_as(feat)
                feat = F.relu(g * feat + b)
                visu_feat.append(feat)

            fused_visual_features = [fusion(feat) for feat, fusion in zip(visu_feat, self.joint_fusion)]
            fused_language_dict_features = language_dict_features

        else:
            fused_visual_features = visual_features
            fused_language_dict_features = language_dict_features

        features_dict = {"visual": fused_visual_features,
                         "lang": fused_language_dict_features}

        return features_dict


class VLDyHead(torch.nn.Module):
    """Vision-language DyHead: a tower of [VLFuse, language layer, DyConv] stages plus prediction heads."""

    def __init__(self, cfg):
        super(VLDyHead, self).__init__()
        self.cfg = cfg
        # bert_cfg = BertConfig.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE)
        if cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "bert-base-uncased":
            lang_cfg = BertConfig.from_pretrained(cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE)
        elif cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "clip":
            lang_cfg = cfg
        else:
            raise NotImplementedError(
                "Unsupported language backbone: {}".format(cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE))

        num_classes = cfg.MODEL.DYHEAD.NUM_CLASSES - 1
        num_tokens = cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN
        num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * cfg.MODEL.RPN.SCALES_PER_OCTAVE
        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
        channels = cfg.MODEL.DYHEAD.CHANNELS

        if cfg.MODEL.DYHEAD.USE_GN:
            bn_type = ['gn', cfg.MODEL.GROUP_NORM.NUM_GROUPS]
        elif cfg.MODEL.DYHEAD.USE_NSYNCBN:
            bn_type = 'nsbn'
        elif cfg.MODEL.DYHEAD.USE_SYNCBN:
            bn_type = 'sbn'
        else:
            bn_type = None

        use_dyrelu = cfg.MODEL.DYHEAD.USE_DYRELU
        use_dyfuse = cfg.MODEL.DYHEAD.USE_DYFUSE
        use_deform = cfg.MODEL.DYHEAD.USE_DFCONV

        if cfg.MODEL.DYHEAD.CONV_FUNC:
            # NOTE: the conv constructor name is resolved with eval() on a config string;
            # only load configs from trusted sources.
            def conv_func(i, o, s):
                return eval(cfg.MODEL.DYHEAD.CONV_FUNC)(i, o, s, bn_type=bn_type)
        else:
            def conv_func(i, o, s):
                return Conv3x3Norm(i, o, s, deformable=use_deform, bn_type=bn_type)

        dyhead_tower = []
        for i in range(cfg.MODEL.DYHEAD.NUM_CONVS):
            if cfg.MODEL.DYHEAD.FUSE_CONFIG.EARLY_FUSE_ON:
                # cross-modality fusion
                dyhead_tower.append(
                    VLFuse(cfg)
                )
                # self language path (the last stage skips it unless fused features feed the dot product)
                if i < cfg.MODEL.DYHEAD.NUM_CONVS - 1 or cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_FUSED_FEATURES_DOT_PRODUCT:
                    if cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "bert-base-uncased":
                        dyhead_tower.append(
                            BertEncoderLayer(
                                lang_cfg,
                                clamp_min_for_underflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_BERTATTN_MIN_FOR_UNDERFLOW,
                                clamp_max_for_overflow=cfg.MODEL.DYHEAD.FUSE_CONFIG.CLAMP_BERTATTN_MAX_FOR_OVERFLOW)
                        )
                    elif cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "clip":
                        dyhead_tower.append(
                            CLIPTransformerLayer(lang_cfg)
                        )
                    else:
                        raise NotImplementedError
                else:
                    dyhead_tower.append(
                        DummyLayer()
                    )

            # self vision path; the first stage only enables dyrelu/dyfuse/deform when the
            # backbone already outputs `channels` channels.
            dyhead_tower.append(
                DyConv(
                    in_channels if i == 0 else channels,
                    channels,
                    conv_func=conv_func,
                    use_dyrelu=(use_dyrelu and in_channels == channels) if i == 0 else use_dyrelu,
                    use_dyfuse=(use_dyfuse and in_channels == channels) if i == 0 else use_dyfuse,
                    use_deform=(use_deform and in_channels == channels) if i == 0 else use_deform,
                )
            )

        self.add_module('dyhead_tower', nn.Sequential(*dyhead_tower))

        self.cls_logits = nn.Conv2d(channels, num_anchors * num_classes, kernel_size=1)
        self.bbox_pred = nn.Conv2d(channels, num_anchors * 4, kernel_size=1)
        self.centerness = nn.Conv2d(channels, num_anchors * 1, kernel_size=1)

        # initialize the bias for focal loss
        prior_prob = cfg.MODEL.DYHEAD.PRIOR_PROB
        bias_value = -math.log((1 - prior_prob) / prior_prob)

        log_scale = self.cfg.MODEL.DYHEAD.LOG_SCALE

        # soft token head
        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS:
            self.token_logits = nn.Conv2d(channels, num_anchors * num_tokens, kernel_size=1)

        # contrastive alignment head
        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
            assert self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS == False
            contrastive_hdim = cfg.MODEL.DYHEAD.FUSE_CONFIG.CONTRASTIVE_HIDDEN_DIM
            self.contrastive_align_projection_image = nn.Conv2d(channels, num_anchors * contrastive_hdim,
                                                                kernel_size=1)
            self.contrastive_align_projection_text = nn.Linear(channels, contrastive_hdim, bias=True)
            self.log_scale = nn.Parameter(torch.Tensor([log_scale]), requires_grad=True)

        # dot product soft token head
        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
            assert self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS == False
            self.dot_product_projection_image = nn.Identity()
            self.dot_product_projection_text = nn.Linear(self.cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM,
                                                         num_anchors * channels, bias=True)
            self.log_scale = nn.Parameter(torch.Tensor([log_scale]), requires_grad=True)
            self.bias_lang = nn.Parameter(torch.zeros(self.cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM), requires_grad=True)
            self.bias0 = nn.Parameter(torch.Tensor([bias_value]), requires_grad=True)

        # initialization (order kept so seeded initial weights are unchanged)
        for modules in [self.cls_logits, self.bbox_pred,
                        self.centerness]:
            for l in modules.modules():
                if isinstance(l, nn.Conv2d):
                    torch.nn.init.normal_(l.weight, std=0.01)
                    torch.nn.init.constant_(l.bias, 0)

        self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)])

        torch.nn.init.constant_(self.cls_logits.bias, bias_value)

        # if use soft token loss
        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS:
            for modules in [self.token_logits]:
                for l in modules.modules():
                    if isinstance(l, nn.Conv2d):
                        torch.nn.init.normal_(l.weight, std=0.01)
                        torch.nn.init.constant_(l.bias, 0)

            torch.nn.init.constant_(self.token_logits.bias, bias_value)

        # if use contrastive loss
        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
            for modules in [self.contrastive_align_projection_image]:
                for l in modules.modules():
                    if isinstance(l, nn.Conv2d):
                        torch.nn.init.normal_(l.weight, std=0.01)
                        torch.nn.init.constant_(l.bias, 0)

        # if use dot product token loss
        # (no-op while dot_product_projection_image is nn.Identity; kept in case it becomes a conv)
        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
            for modules in [self.dot_product_projection_image]:
                for l in modules.modules():
                    if isinstance(l, nn.Conv2d):
                        torch.nn.init.normal_(l.weight, std=0.01)
                        torch.nn.init.constant_(l.bias, bias_value)

        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
            if cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE == "clip":
                lang_cfg = BertConfig.from_pretrained("bert-base-uncased")
                lang_cfg.hidden_size = cfg.MODEL.CLIP.WIDTH
                lang_cfg.vocab_size = cfg.MODEL.CLIP.VOCAB_SIZE
            self.mlm_head = BertLMPredictionHead(
                lang_cfg
            )  # nn.Linear(hidden_size, config.vocab_size, bias=False)

    def forward(self, x, language_dict_features=None, embedding=None, swint_feature_c4=None):
        """Run the fusion tower and all prediction heads.

        Args:
            x: list of per-level visual feature maps (the FPN features).
            language_dict_features: language dict ("hidden", "masks", ...) fed through the tower.
            embedding: token embeddings used by the text-side heads (may be replaced by the
                tower's fused "hidden" when USE_FUSED_FEATURES_DOT_PRODUCT is set).
            swint_feature_c4: backbone features, used only for the backbone shallow contrastive loss.

        Returns:
            (logits, bbox_reg, centerness, t_logits, proj_tokens, contrastive_logits,
             dot_product_logits, mlm_logits, shallow_img_emb_feats, fused_visual_features);
            optional entries are None when their head is disabled.
        """
        fuse_cfg = self.cfg.MODEL.DYHEAD.FUSE_CONFIG

        logits = []
        bbox_reg = []
        centerness = []

        feat_inputs = {"visual": x,
                       "lang": language_dict_features}

        dyhead_tower = self.dyhead_tower(feat_inputs)

        # soft token
        t_logits = None
        if fuse_cfg.USE_TOKEN_LOSS:
            t_logits = []

        if fuse_cfg.USE_FUSED_FEATURES_DOT_PRODUCT:
            embedding = dyhead_tower["lang"]["hidden"]

        # MLM loss
        if fuse_cfg.MLM_LOSS:
            mlm_logits = self.mlm_head(embedding)
        else:
            mlm_logits = None

        # contrastive
        contrastive_logits = None
        proj_tokens = None
        if fuse_cfg.USE_CONTRASTIVE_ALIGN_LOSS:
            contrastive_logits = []
            # follow MDETR's way
            proj_tokens = F.normalize(
                self.contrastive_align_projection_text(embedding), p=2, dim=-1
            )

        # dot product soft token
        dot_product_logits = None
        dot_product_proj_tokens = None
        dot_product_proj_tokens_bias = None
        if fuse_cfg.USE_DOT_PRODUCT_TOKEN_LOSS:
            dot_product_logits = []
            # norm
            embedding = F.normalize(embedding, p=2, dim=-1)
            dot_product_proj_tokens = self.dot_product_projection_text(embedding / 2.0)
            # w/o norm
            # dot_product_proj_tokens = self.dot_product_projection_text(embedding / 28.0)

            dot_product_proj_tokens_bias = torch.matmul(embedding, self.bias_lang) + self.bias0

        # shallow contrastive (original feature from image & text encoder)
        shallow_img_emb_feats = None
        shallow_text_emb = None
        if fuse_cfg.USE_SHALLOW_CONTRASTIVE_LOSS \
                or fuse_cfg.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS:
            shallow_img_emb_feats = []
            shallow_text_emb = embedding

        # shallow contrastive: use the feature from swint backbone
        if fuse_cfg.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS:
            shallow_img_emb_feats.extend(swint_feature_c4)

        fused_visual_features = None
        if self.cfg.MODEL.RPN.RETURN_FUSED_FEATURES:
            fused_visual_features = []

        # use the feature from FPN
        for level, feature in enumerate(x):
            tower_feat = dyhead_tower["visual"][level]

            # Collected for every configuration (previously only when the dot-product head was on,
            # which silently returned an empty list otherwise).
            if fused_visual_features is not None:
                fused_visual_features.append(tower_feat)

            logits.append(self.cls_logits(tower_feat))

            bbox_pred = self.scales[level](self.bbox_pred(tower_feat))
            bbox_reg.append(bbox_pred)

            centerness.append(self.centerness(tower_feat))

            if fuse_cfg.USE_TOKEN_LOSS:
                t_logits.append(self.token_logits(tower_feat))

            if fuse_cfg.USE_CONTRASTIVE_ALIGN_LOSS:
                B, _, H, W = tower_feat.shape
                C = proj_tokens.shape[2]
                proj_queries = self.contrastive_align_projection_image(tower_feat)
                proj_queries = permute_and_flatten(proj_queries, B, -1, C, H, W)
                normalized_img_emb = F.normalize(proj_queries, p=2, dim=-1)
                normalized_text_emb = proj_tokens
                contrastive_logit = (
                        torch.matmul(normalized_img_emb, normalized_text_emb.transpose(-1, -2)) / self.log_scale.exp())
                contrastive_logits.append(contrastive_logit)

            if fuse_cfg.USE_DOT_PRODUCT_TOKEN_LOSS:
                B, C, H, W = tower_feat.shape

                dot_product_proj_queries = self.dot_product_projection_image(tower_feat)
                dot_product_proj_queries = permute_and_flatten(dot_product_proj_queries, B, -1, C, H, W)

                # add bias (language); broadcasting over the query dim replaces the former .repeat.
                bias = dot_product_proj_tokens_bias.unsqueeze(1)

                dot_product_logit = (torch.matmul(dot_product_proj_queries,
                                                  dot_product_proj_tokens.transpose(-1, -2))
                                     / self.log_scale.exp()) + bias
                if fuse_cfg.CLAMP_DOT_PRODUCT:
                    dot_product_logit = torch.clamp(dot_product_logit, min=-50000, max=50000)
                dot_product_logits.append(dot_product_logit)

            if fuse_cfg.USE_SHALLOW_CONTRASTIVE_LOSS:
                BF, CF, HF, WF = feature.shape
                shallow_img_emb = permute_and_flatten(feature, BF, -1, CF, HF, WF)
                shallow_img_emb_feats.append(shallow_img_emb)

        # no matter the feature is from backbone or from fpn, we use shallow_img_embs all the time
        if shallow_img_emb_feats is not None and shallow_text_emb is not None:
            proj_tokens = shallow_text_emb
        return logits, bbox_reg, centerness, t_logits, proj_tokens, contrastive_logits, dot_product_logits, \
            mlm_logits, shallow_img_emb_feats, fused_visual_features


class VLDyHeadModule(torch.nn.Module):
    """RPN-style module wrapping VLDyHead with anchors, ATSS losses and box post-processing."""

    def __init__(self, cfg):
        super(VLDyHeadModule, self).__init__()
        self.cfg = cfg
        self.head = VLDyHead(cfg)
        box_coder = BoxCoder(cfg)
        self.loss_evaluator = make_atss_loss_evaluator(cfg, box_coder)
        self.box_selector_train = make_atss_postprocessor(cfg, box_coder, is_train=True)
        self.box_selector_test = make_atss_postprocessor(cfg, box_coder, is_train=False)
        self.anchor_generator = make_anchor_generator_complex(cfg)

        self.lang_model = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE
        self.joint_embedding_size = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_SIZE
        self.joint_embedding_dropout = cfg.MODEL.DYHEAD.FUSE_CONFIG.JOINT_EMB_DROPOUT
        if self.lang_model in ["bert-base-uncased", "roberta-base", "clip"]:
            self.lang_dim = cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM
        else:
            self.lang_dim = 1024

        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
            self.resizer = FeatureResizer(
                input_feat_size=self.lang_dim,
                output_feat_size=self.joint_embedding_size,
                dropout=self.joint_embedding_dropout
            )
        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER:
            # Zero-initialised learnable per-position offsets (up to 1000 positions).
            self.tunable_linear = torch.nn.Linear(self.lang_dim, 1000, bias=False)
            self.tunable_linear.weight.data.fill_(0.0)

    def forward(self, images, features, targets=None,
                language_dict_features=None,
                positive_map=None,
                captions=None,
                swint_feature_c4=None
                ):
        """Run the head; dispatch to ``_forward_train`` or ``_forward_test`` by ``self.training``.

        ``language_dict_features`` is read for "embedded", "masks", "hidden" and (in training)
        "mlm_labels"; the ADD_LINEAR_LAYER path writes "embedded"/"hidden" back into it.
        """
        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
            # resizer needed
            embedding = language_dict_features['embedded']
            embedding = self.resizer(embedding)
        elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
            # no resizer needed
            embedding = language_dict_features['embedded']
        else:
            embedding = None

        text_masks = language_dict_features.get("masks")

        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.ADD_LINEAR_LAYER:
            position_offsets = self.tunable_linear.weight[:embedding.size(1), :].unsqueeze(0)
            embedding = position_offsets + embedding
            language_dict_features['embedded'] = embedding
            language_dict_features['hidden'] = position_offsets + language_dict_features['hidden']

        box_cls, box_regression, centerness, token_logits, \
            proj_tokens, contrastive_logits, dot_product_logits, mlm_logits, shallow_img_emb_feats, \
            fused_visual_features = self.head(features,
                                              language_dict_features,
                                              embedding,
                                              swint_feature_c4
                                              )
        anchors = self.anchor_generator(images, features)

        if self.training:
            return self._forward_train(box_cls, box_regression, centerness, targets, anchors,
                                       captions,
                                       positive_map,
                                       token_logits,
                                       proj_tokens,
                                       contrastive_logits,
                                       dot_product_logits,
                                       text_masks,
                                       mlm_logits=mlm_logits,
                                       # .get: a missing key means "no MLM labels", which
                                       # _forward_train already handles, instead of a KeyError.
                                       mlm_labels=language_dict_features.get("mlm_labels"),
                                       shallow_img_emb_feats=shallow_img_emb_feats,
                                       fused_visual_features=fused_visual_features
                                       )
        else:
            return self._forward_test(box_regression, centerness, anchors,
                                      box_cls,
                                      token_logits,
                                      dot_product_logits,
                                      positive_map,
                                      fused_visual_features=fused_visual_features
                                      )

    def _forward_train(self, box_cls, box_regression, centerness, targets, anchors,
                       captions=None,
                       positive_map=None,
                       token_logits=None,
                       proj_tokens=None,
                       contrastive_logits=None,
                       dot_product_logits=None,
                       text_masks=None,
                       mlm_logits=None,
                       mlm_labels=None,
                       shallow_img_emb_feats=None,
                       fused_visual_features=None
                       ):
        """Compute the weighted loss dict.

        Returns ``(None, losses, None)`` when RPN_ONLY; otherwise
        ``(train_boxes, losses, fused_visual_features)`` where train_boxes are predictions
        concatenated with the ground-truth boxes.
        """
        fuse_cfg = self.cfg.MODEL.DYHEAD.FUSE_CONFIG

        loss_box_cls, loss_box_reg, loss_centerness, loss_token, loss_contrastive_align, \
            loss_dot_product_token, loss_shallow_contrastive = self.loss_evaluator(
                box_cls, box_regression, centerness, targets, anchors,
                captions,
                positive_map,
                token_logits,
                proj_tokens,
                contrastive_logits,
                dot_product_logits,
                text_masks,
                shallow_img_emb_feats
            )

        losses = {
            # "loss_cls": loss_box_cls,
            "loss_reg": loss_box_reg,
            "loss_centerness": loss_centerness
        }

        if mlm_labels is not None and mlm_logits is not None:
            losses["mlm_loss"] = nn.CrossEntropyLoss(ignore_index=-100)(
                mlm_logits.view(-1, mlm_logits.size(-1)),
                mlm_labels.view(-1)) * fuse_cfg.MLM_LOSS_COEF

        if fuse_cfg.USE_CLASSIFICATION_LOSS:
            losses["loss_cls"] = loss_box_cls
        else:
            # Keep the term (zeroed) so the classification branch stays in the graph.
            losses["loss_cls"] = 0.0 * loss_box_cls

        if fuse_cfg.USE_TOKEN_LOSS:
            losses["loss_token"] = loss_token * fuse_cfg.TOKEN_LOSS_WEIGHT
        if fuse_cfg.USE_CONTRASTIVE_ALIGN_LOSS:
            losses["loss_contrastive_align"] = loss_contrastive_align * \
                fuse_cfg.CONTRASTIVE_ALIGN_LOSS_WEIGHT
        if fuse_cfg.USE_DOT_PRODUCT_TOKEN_LOSS:
            losses["loss_dot_product_token"] = loss_dot_product_token * \
                fuse_cfg.DOT_PRODUCT_TOKEN_LOSS_WEIGHT
        if fuse_cfg.USE_SHALLOW_CONTRASTIVE_LOSS or \
                fuse_cfg.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS:
            losses["loss_shallow_contrastive"] = loss_shallow_contrastive * \
                fuse_cfg.SHALLOW_CONTRASTIVE_LOSS_WEIGHT

        if self.cfg.MODEL.RPN_ONLY:
            return None, losses, None
        else:
            # Let's just use one image per batch
            assert (box_regression[0].shape[0]) == 1
            positive_map_label_to_token = create_positive_map_label_to_token_from_positive_map(positive_map, plus=1)
            boxes = self.box_selector_train(box_regression, centerness, anchors,
                                            box_cls,
                                            token_logits,
                                            dot_product_logits,
                                            positive_map=positive_map_label_to_token
                                            )
            train_boxes = []
            for b, t in zip(boxes, targets):
                tb = t.copy_with_fields(["labels"])
                # NOTE(review): GT "scores" are created as bool; confirm cat_boxlist handles the
                # dtype mix with the predicted boxes' scores.
                tb.add_field("scores", torch.ones(tb.bbox.shape[0], dtype=torch.bool, device=tb.bbox.device))
                train_boxes.append(cat_boxlist([b, tb]))
            return train_boxes, losses, fused_visual_features

    def _forward_test(self, box_regression, centerness, anchors,
                      box_cls=None,
                      token_logits=None,
                      dot_product_logits=None,
                      positive_map=None,
                      fused_visual_features=None
                      ):
        """Post-process predictions; returns ``(boxes, {}, fused_visual_features)``."""
        boxes = self.box_selector_test(box_regression, centerness, anchors,
                                       box_cls,
                                       token_logits,
                                       dot_product_logits,
                                       positive_map,
                                       )
        return boxes, {}, fused_visual_features