# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from collections import OrderedDict from typing import Sequence import torch from mmengine.model import BaseModel from torch import nn try: from transformers import AutoTokenizer, BertConfig from transformers import BertModel as HFBertModel except ImportError: AutoTokenizer = None HFBertModel = None from mmdet.registry import MODELS def generate_masks_with_special_tokens_and_transfer_map( tokenized, special_tokens_list): """Generate attention mask between each pair of special tokens. Only token pairs in between two special tokens are attended to and thus the attention mask for these pairs is positive. Args: input_ids (torch.Tensor): input ids. Shape: [bs, num_token] special_tokens_mask (list): special tokens mask. Returns: Tuple(Tensor, Tensor): - attention_mask is the attention mask between each tokens. Only token pairs in between two special tokens are positive. Shape: [bs, num_token, num_token]. - position_ids is the position id of tokens within each valid sentence. The id starts from 0 whenenver a special token is encountered. Shape: [bs, num_token] """ input_ids = tokenized['input_ids'] bs, num_token = input_ids.shape # special_tokens_mask: # bs, num_token. 1 for special tokens. 0 for normal tokens special_tokens_mask = torch.zeros((bs, num_token), device=input_ids.device).bool() for special_token in special_tokens_list: special_tokens_mask |= input_ids == special_token # idxs: each row is a list of indices of special tokens idxs = torch.nonzero(special_tokens_mask) # generate attention mask and positional ids attention_mask = ( torch.eye(num_token, device=input_ids.device).bool().unsqueeze(0).repeat( bs, 1, 1)) position_ids = torch.zeros((bs, num_token), device=input_ids.device) previous_col = 0 for i in range(idxs.shape[0]): row, col = idxs[i] if (col == 0) or (col == num_token - 1): attention_mask[row, col, col] = True position_ids[row, col] = 0 else: attention_mask[row, previous_col + 1:col + 1, previous_col + 1:col + 1] = True position_ids[row, previous_col + 1:col + 1] = torch.arange( 0, col - previous_col, device=input_ids.device) previous_col = col return attention_mask, position_ids.to(torch.long) @MODELS.register_module() class BertModel(BaseModel): """BERT model for language embedding only encoder. Args: name (str, optional): name of the pretrained BERT model from HuggingFace. Defaults to bert-base-uncased. max_tokens (int, optional): maximum number of tokens to be used for BERT. Defaults to 256. pad_to_max (bool, optional): whether to pad the tokens to max_tokens. Defaults to True. use_sub_sentence_represent (bool, optional): whether to use sub sentence represent introduced in `Grounding DINO `. Defaults to False. special_tokens_list (list, optional): special tokens used to split subsentence. It cannot be None when `use_sub_sentence_represent` is True. Defaults to None. add_pooling_layer (bool, optional): whether to adding pooling layer in bert encoder. Defaults to False. num_layers_of_embedded (int, optional): number of layers of the embedded model. Defaults to 1. use_checkpoint (bool, optional): whether to use gradient checkpointing. Defaults to False. """ def __init__(self, name: str = 'bert-base-uncased', max_tokens: int = 256, pad_to_max: bool = True, use_sub_sentence_represent: bool = False, special_tokens_list: list = None, add_pooling_layer: bool = False, num_layers_of_embedded: int = 1, use_checkpoint: bool = False, **kwargs) -> None: super().__init__(**kwargs) self.max_tokens = max_tokens self.pad_to_max = pad_to_max if AutoTokenizer is None: raise RuntimeError( 'transformers is not installed, please install it by: ' 'pip install transformers.') self.tokenizer = AutoTokenizer.from_pretrained(name) self.language_backbone = nn.Sequential( OrderedDict([('body', BertEncoder( name, add_pooling_layer=add_pooling_layer, num_layers_of_embedded=num_layers_of_embedded, use_checkpoint=use_checkpoint))])) self.use_sub_sentence_represent = use_sub_sentence_represent if self.use_sub_sentence_represent: assert special_tokens_list is not None, \ 'special_tokens should not be None \ if use_sub_sentence_represent is True' self.special_tokens = self.tokenizer.convert_tokens_to_ids( special_tokens_list) def forward(self, captions: Sequence[str], **kwargs) -> dict: """Forward function.""" device = next(self.language_backbone.parameters()).device tokenized = self.tokenizer.batch_encode_plus( captions, max_length=self.max_tokens, padding='max_length' if self.pad_to_max else 'longest', return_special_tokens_mask=True, return_tensors='pt', truncation=True).to(device) input_ids = tokenized.input_ids if self.use_sub_sentence_represent: attention_mask, position_ids = \ generate_masks_with_special_tokens_and_transfer_map( tokenized, self.special_tokens) token_type_ids = tokenized['token_type_ids'] else: attention_mask = tokenized.attention_mask position_ids = None token_type_ids = None tokenizer_input = { 'input_ids': input_ids, 'attention_mask': attention_mask, 'position_ids': position_ids, 'token_type_ids': token_type_ids } language_dict_features = self.language_backbone(tokenizer_input) if self.use_sub_sentence_represent: language_dict_features['position_ids'] = position_ids language_dict_features[ 'text_token_mask'] = tokenized.attention_mask.bool() return language_dict_features class BertEncoder(nn.Module): """BERT encoder for language embedding. Args: name (str): name of the pretrained BERT model from HuggingFace. Defaults to bert-base-uncased. add_pooling_layer (bool): whether to add a pooling layer. num_layers_of_embedded (int): number of layers of the embedded model. Defaults to 1. use_checkpoint (bool): whether to use gradient checkpointing. Defaults to False. """ def __init__(self, name: str, add_pooling_layer: bool = False, num_layers_of_embedded: int = 1, use_checkpoint: bool = False): super().__init__() if BertConfig is None: raise RuntimeError( 'transformers is not installed, please install it by: ' 'pip install transformers.') config = BertConfig.from_pretrained(name) config.gradient_checkpointing = use_checkpoint # only encoder self.model = HFBertModel.from_pretrained( name, add_pooling_layer=add_pooling_layer, config=config) self.language_dim = config.hidden_size self.num_layers_of_embedded = num_layers_of_embedded def forward(self, x) -> dict: mask = x['attention_mask'] outputs = self.model( input_ids=x['input_ids'], attention_mask=mask, position_ids=x['position_ids'], token_type_ids=x['token_type_ids'], output_hidden_states=True, ) # outputs has 13 layers, 1 input layer and 12 hidden layers encoded_layers = outputs.hidden_states[1:] features = torch.stack(encoded_layers[-self.num_layers_of_embedded:], 1).mean(1) # language embedding has shape [len(phrase), seq_len, language_dim] features = features / self.num_layers_of_embedded if mask.dim() == 2: embedded = features * mask.unsqueeze(-1).float() else: embedded = features results = { 'embedded': embedded, 'masks': mask, 'hidden': encoded_layers[-1] } return results