# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
from typing import Optional, Sequence

import torch
from mmengine.model import BaseModel
from torch import nn

try:
    from transformers import AutoTokenizer, BertConfig
    from transformers import BertModel as HFBertModel
except ImportError:
    AutoTokenizer = None
    BertConfig = None
    HFBertModel = None

from mmdet.registry import MODELS


def generate_masks_with_special_tokens_and_transfer_map(
        tokenized: dict, special_tokens_list: list) -> tuple:
"""Generate attention mask between each pair of special tokens.
Only token pairs in between two special tokens are attended to
and thus the attention mask for these pairs is positive.
Args:
input_ids (torch.Tensor): input ids. Shape: [bs, num_token]
special_tokens_mask (list): special tokens mask.
Returns:
Tuple(Tensor, Tensor):
- attention_mask is the attention mask between each tokens.
Only token pairs in between two special tokens are positive.
Shape: [bs, num_token, num_token].
- position_ids is the position id of tokens within each valid sentence.
The id starts from 0 whenenver a special token is encountered.
Shape: [bs, num_token]
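
    A toy example (illustrative; assumes ``bert-base-uncased`` ids, where
    101 is ``[CLS]`` and 102 is ``[SEP]``, used here as the special tokens):

    Examples:
        >>> import torch
        >>> tokenized = {
        ...     'input_ids': torch.tensor([[101, 7592, 102, 2088, 102]])}
        >>> mask, pos = generate_masks_with_special_tokens_and_transfer_map(
        ...     tokenized, [101, 102])
        >>> mask.shape, pos.shape
        (torch.Size([1, 5, 5]), torch.Size([1, 5]))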
"""
input_ids = tokenized['input_ids']
bs, num_token = input_ids.shape
# special_tokens_mask:
# bs, num_token. 1 for special tokens. 0 for normal tokens
special_tokens_mask = torch.zeros((bs, num_token),
device=input_ids.device).bool()
for special_token in special_tokens_list:
special_tokens_mask |= input_ids == special_token
# idxs: each row is a list of indices of special tokens
idxs = torch.nonzero(special_tokens_mask)
# generate attention mask and positional ids
attention_mask = (
torch.eye(num_token,
device=input_ids.device).bool().unsqueeze(0).repeat(
bs, 1, 1))
position_ids = torch.zeros((bs, num_token), device=input_ids.device)
previous_col = 0
for i in range(idxs.shape[0]):
row, col = idxs[i]
if (col == 0) or (col == num_token - 1):
attention_mask[row, col, col] = True
position_ids[row, col] = 0
else:
attention_mask[row, previous_col + 1:col + 1,
previous_col + 1:col + 1] = True
position_ids[row, previous_col + 1:col + 1] = torch.arange(
0, col - previous_col, device=input_ids.device)
previous_col = col
    return attention_mask, position_ids.to(torch.long)


@MODELS.register_module()
class BertModel(BaseModel):
"""BERT model for language embedding only encoder.
Args:
name (str, optional): name of the pretrained BERT model from
HuggingFace. Defaults to bert-base-uncased.
max_tokens (int, optional): maximum number of tokens to be
used for BERT. Defaults to 256.
pad_to_max (bool, optional): whether to pad the tokens to max_tokens.
Defaults to True.
use_sub_sentence_represent (bool, optional): whether to use sub
sentence represent introduced in `Grounding DINO
<https://arxiv.org/abs/2303.05499>`. Defaults to False.
special_tokens_list (list, optional): special tokens used to split
subsentence. It cannot be None when `use_sub_sentence_represent`
is True. Defaults to None.
add_pooling_layer (bool, optional): whether to adding pooling
layer in bert encoder. Defaults to False.
num_layers_of_embedded (int, optional): number of layers of
the embedded model. Defaults to 1.
use_checkpoint (bool, optional): whether to use gradient checkpointing.
Defaults to False.
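
    A minimal usage sketch (assumes ``transformers`` is installed and the
    ``bert-base-uncased`` weights can be downloaded):

    Examples:
        >>> lang_model = BertModel(name='bert-base-uncased')
        >>> out = lang_model(['a cat.', 'a remote control.'])
        >>> sorted(out.keys())
        ['embedded', 'hidden', 'masks']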
"""
def __init__(self,
name: str = 'bert-base-uncased',
max_tokens: int = 256,
pad_to_max: bool = True,
use_sub_sentence_represent: bool = False,
                 special_tokens_list: Optional[list] = None,
add_pooling_layer: bool = False,
num_layers_of_embedded: int = 1,
use_checkpoint: bool = False,
**kwargs) -> None:
super().__init__(**kwargs)
self.max_tokens = max_tokens
self.pad_to_max = pad_to_max
if AutoTokenizer is None:
raise RuntimeError(
                'transformers is not installed, please install it with: '
'pip install transformers.')
self.tokenizer = AutoTokenizer.from_pretrained(name)
self.language_backbone = nn.Sequential(
OrderedDict([('body',
BertEncoder(
name,
add_pooling_layer=add_pooling_layer,
num_layers_of_embedded=num_layers_of_embedded,
use_checkpoint=use_checkpoint))]))
self.use_sub_sentence_represent = use_sub_sentence_represent
if self.use_sub_sentence_represent:
            assert special_tokens_list is not None, \
                'special_tokens_list should not be None ' \
                'if use_sub_sentence_represent is True'
self.special_tokens = self.tokenizer.convert_tokens_to_ids(
                special_tokens_list)

    def forward(self, captions: Sequence[str], **kwargs) -> dict:
"""Forward function."""
device = next(self.language_backbone.parameters()).device
tokenized = self.tokenizer.batch_encode_plus(
captions,
max_length=self.max_tokens,
padding='max_length' if self.pad_to_max else 'longest',
return_special_tokens_mask=True,
return_tensors='pt',
truncation=True).to(device)
input_ids = tokenized.input_ids
if self.use_sub_sentence_represent:
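            # Build a block-diagonal attention mask so that tokens attend
            # only within their own sub-sentence, and restart the position
            # ids at each special token (the Grounding DINO text encoding).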
attention_mask, position_ids = \
generate_masks_with_special_tokens_and_transfer_map(
tokenized, self.special_tokens)
token_type_ids = tokenized['token_type_ids']
else:
attention_mask = tokenized.attention_mask
position_ids = None
token_type_ids = None
tokenizer_input = {
'input_ids': input_ids,
'attention_mask': attention_mask,
'position_ids': position_ids,
'token_type_ids': token_type_ids
}
language_dict_features = self.language_backbone(tokenizer_input)
if self.use_sub_sentence_represent:
language_dict_features['position_ids'] = position_ids
language_dict_features[
'text_token_mask'] = tokenized.attention_mask.bool()
        return language_dict_features


class BertEncoder(nn.Module):
"""BERT encoder for language embedding.
Args:
name (str): name of the pretrained BERT model from HuggingFace.
Defaults to bert-base-uncased.
add_pooling_layer (bool): whether to add a pooling layer.
num_layers_of_embedded (int): number of layers of the embedded model.
Defaults to 1.
use_checkpoint (bool): whether to use gradient checkpointing.
Defaults to False.
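
    A minimal sketch (assumes the ``bert-base-uncased`` weights can be
    downloaded; 768 is the bert-base hidden size):

    Examples:
        >>> encoder = BertEncoder('bert-base-uncased')
        >>> encoder.language_dim
        768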
"""
def __init__(self,
name: str,
add_pooling_layer: bool = False,
num_layers_of_embedded: int = 1,
use_checkpoint: bool = False):
super().__init__()
if BertConfig is None:
raise RuntimeError(
                'transformers is not installed, please install it with: '
'pip install transformers.')
config = BertConfig.from_pretrained(name)
config.gradient_checkpointing = use_checkpoint
# only encoder
self.model = HFBertModel.from_pretrained(
name, add_pooling_layer=add_pooling_layer, config=config)
self.language_dim = config.hidden_size
        self.num_layers_of_embedded = num_layers_of_embedded

    def forward(self, x: dict) -> dict:
        """Encode tokenized captions into language features."""
        mask = x['attention_mask']
outputs = self.model(
input_ids=x['input_ids'],
attention_mask=mask,
position_ids=x['position_ids'],
token_type_ids=x['token_type_ids'],
output_hidden_states=True,
)
        # hidden_states holds the embedding output followed by the output
        # of every encoder layer (13 entries for bert-base); drop the
        # embedding output and keep only the encoder layer outputs.
        encoded_layers = outputs.hidden_states[1:]
        features = torch.stack(encoded_layers[-self.num_layers_of_embedded:],
                               1).mean(1)
        # language embedding has shape [len(phrase), seq_len, language_dim].
        # NOTE: ``mean(1)`` already averages over the selected layers; the
        # extra division below further scales the features and is kept to
        # preserve the original behavior.
        features = features / self.num_layers_of_embedded
if mask.dim() == 2:
embedded = features * mask.unsqueeze(-1).float()
else:
embedded = features
results = {
'embedded': embedded,
'masks': mask,
'hidden': encoded_layers[-1]
}
return results
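

# Usage sketch (illustrative, not part of the module): once this file is
# imported, the registered ``BertModel`` can be built from an MMDetection
# config dict. The ``special_tokens_list`` values below are assumptions
# mirroring typical Grounding DINO configs.
#
#     lang_model = MODELS.build(
#         dict(
#             type='BertModel',
#             name='bert-base-uncased',
#             max_tokens=256,
#             pad_to_max=True,
#             use_sub_sentence_represent=True,
#             special_tokens_list=['[CLS]', '[SEP]', '.', '?'],
#             add_pooling_layer=False))
#     out = lang_model(['a cat. a dog.'])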