# HPT / modeling_hformer.py
# Uploaded by xwwu ("Upload folder using huggingface_hub"), commit 26c2f02 (verified).
import torch
# NOTE(review): this seeds the *global* torch RNG as an import side effect. It makes
# the random initialisation of HformerModel's parameters (query tokens, Linear
# layers) reproducible. It also resets the RNG state of any program that imports
# this module; confirm that is intended before relying on it.
torch.manual_seed(1024)
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_hformer import HformerConfig
from .qformer_src import BertConfig, BertLMHeadModel
from transformers import BertTokenizerFast as BertTokenizer
from .configuration_projector import ProjectorConfig
from .modeling_projector import ProjectorModel
import torch.nn.functional as F
from transformers.activations import ACT2FN
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` subclass whose ``forward`` defers unchanged to the parent.

    It adds no behaviour of its own; it exists only as a named type used by the
    model definitions in this file.
    """

    def forward(self, x: torch.Tensor):
        # Explicit two-argument super(): same resolution as the zero-argument form.
        return super(LayerNorm, self).forward(x)
class HformerModel(PreTrainedModel):
    """Vision-to-LLM connector that fuses a Q-Former branch with an MLP-projector branch.

    ``forward`` takes visual features ``x_`` and combines two branches:

    * Q-Former branch: ``x_`` is normalised by ``ln_vision``. A BERT-based
      Q-Former with ``config.num_query_token`` learned query tokens then
      cross-attends to it. The query outputs are projected to
      ``config.llm_hidden_size`` by ``llm_proj`` and averaged over the query axis.
    * MLP branch: the un-normalised ``x_`` goes through ``connector`` (a
      ``ProjectorModel`` of depth 2).

    The averaged Q-Former vector is broadcast-added to the MLP output. A
    residual bottleneck FFN (``llm_hidden_size -> llm_hidden_size // 4 ->
    llm_hidden_size``) is applied on top. Gradient checkpointing is not supported.
    """

    _auto_class = 'AutoModel'
    config_class = HformerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = False

    def __init__(self, config) -> None:
        """Build all sub-modules from ``config``.

        Config fields read:
            ``visual_hidden_size``: width of the incoming visual features.
            ``num_query_token``: number of learned Q-Former query tokens.
            ``bert``: name/path passed to ``BertConfig``, ``BertLMHeadModel`` and
                the tokenizer ``from_pretrained`` calls.
            ``llm_hidden_size``: output width.
            ``cross_attention_freq``: forwarded to the Q-Former config.
            ``qformer_pth``: optional checkpoint path loaded into this model, or ``None``.
            ``bias``: whether ``llm_proj`` has a bias.

        NOTE: parameter initialisation draws from the global torch RNG, so the
        order in which sub-modules are created below determines the initial
        weights. Do not reorder them.
        """
        super().__init__(config)
        self.gradient_checkpointing = False
        vision_width = config.visual_hidden_size
        num_query_token = config.num_query_token
        bert = config.bert
        llm_hidden_size = config.llm_hidden_size
        cross_attention_freq = config.cross_attention_freq
        qformer_pth = config.qformer_pth

        # Q-Former: a 12-layer BERT with cross-attention into features of width vision_width.
        encoder_config = BertConfig.from_pretrained(bert)
        encoder_config.encoder_width = vision_width
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        encoder_config.num_hidden_layers = 12
        Qformer = BertLMHeadModel.from_pretrained(bert, config=encoder_config)

        # Hard-wired off. When enabled it strips the Q-Former's text-side modules
        # (LM head, word/position embeddings, per-layer output/intermediate).
        # NOTE(review): with word_embeddings set to None, the resize_token_embeddings
        # call below would presumably fail; confirm before enabling.
        remove_text = False
        if remove_text:
            Qformer.cls = None
            Qformer.bert.embeddings.word_embeddings = None
            Qformer.bert.embeddings.position_embeddings = None
            for layer in Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None

        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        self.Qformer = Qformer
        self.query_tokens = query_tokens
        self.llm_proj = nn.Linear(encoder_config.hidden_size, llm_hidden_size, bias=config.bias)
        self.ln_vision = LayerNorm(encoder_config.encoder_width)
        # Not used by forward(); presumably kept so checkpoint keys keep matching.
        self.ln_llava = LayerNorm(encoder_config.encoder_width)

        # Grow the Q-Former vocabulary by the extra "[DEC]" BOS token.
        tokenizer = BertTokenizer.from_pretrained(bert, truncation_side='right')
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        self.Qformer.resize_token_embeddings(len(tokenizer))

        if qformer_pth is not None:
            # NOTE(security): torch.load unpickles the file and can execute arbitrary
            # code. Only point qformer_pth at trusted checkpoints.
            checkpoint = torch.load(qformer_pth, map_location='cpu')
            # Accept both {'model': state_dict} checkpoints and bare state dicts.
            pretrained_state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
            print(f'Load Qformer from {qformer_pth}')
            # connector/ffn are created after this point, so the checkpoint cannot
            # populate them here. strict=False silently skips mismatched keys, so
            # report how many there were to make a wrong checkpoint noticeable.
            incompatible = self.load_state_dict(pretrained_state_dict, strict=False)
            print(
                f'Missing keys: {len(incompatible.missing_keys)}, '
                f'unexpected keys: {len(incompatible.unexpected_keys)}'
            )
            print('Done.')

        # MLP branch applied to the raw visual features.
        projector_config = ProjectorConfig(
            visual_hidden_size=config.visual_hidden_size,
            llm_hidden_size=config.llm_hidden_size,
            projector_depth=2,
        )
        self.connector = ProjectorModel(projector_config)

        # Bottleneck FFN used residually on the fused features in forward().
        modules = [
            nn.Linear(config.llm_hidden_size, config.llm_hidden_size // 4, bias=False),
            ACT2FN['gelu'],
            nn.Linear(config.llm_hidden_size // 4, config.llm_hidden_size, bias=False),
        ]
        self.ffn = nn.Sequential(*modules)

    def enable_input_require_grads(self):
        """Register forward hooks that mark sub-module outputs as requiring grad.

        Overrides the ``PreTrainedModel`` method of the same name. Instead of
        hooking the input embeddings, it hooks ``Qformer``, ``llm_proj``,
        ``ln_vision``, ``connector`` and ``ffn``. The hook handles are stored so
        ``disable_input_require_grads`` can remove them. Calling this again
        replaces the previous hooks rather than stacking duplicates.

        NOTE: ``forward`` calls ``self.Qformer.bert``, never ``self.Qformer``
        itself, so the ``Qformer`` hook only fires if the Q-Former is invoked
        directly.
        """
        def make_inputs_require_grad(module, input, output):
            # An output may be a bare tensor, a tuple, or a dict-like ModelOutput.
            # Mark every floating-point tensor; integer tensors cannot require
            # grad, and None entries are skipped.
            if isinstance(output, dict):
                candidates = output.values()
            elif isinstance(output, tuple):
                candidates = output
            else:
                candidates = (output,)
            for tensor in candidates:
                if isinstance(tensor, torch.Tensor) and tensor.is_floating_point():
                    tensor.requires_grad_(True)

        self.disable_input_require_grads()
        hooked = (self.Qformer, self.llm_proj, self.ln_vision, self.connector, self.ffn)
        self._require_grads_hooks = [
            submodule.register_forward_hook(make_inputs_require_grad)
            for submodule in hooked
        ]

    def disable_input_require_grads(self):
        """Remove the hooks added by ``enable_input_require_grads``; a no-op if none exist."""
        for handle in getattr(self, '_require_grads_hooks', []):
            handle.remove()
        self._require_grads_hooks = []

    def _set_gradient_checkpointing(self, module, value=False):
        """No-op: gradient checkpointing is not supported by this model."""
        pass

    def forward(self, x_):
        """Fuse Q-Former and MLP-projector features of the visual input ``x_``.

        Args:
            x_: visual features with last dimension ``config.visual_hidden_size``;
                presumably (batch, num_tokens, visual_hidden_size), TODO confirm.

        Returns:
            ``connector(x_)`` plus the query-averaged Q-Former feature (broadcast
            along dim 1), followed by a residual pass through ``ffn``.
        """
        if self.gradient_checkpointing and self.training:
            print('Not support gradient checkpointing')
        x = self.ln_vision(x_)
        # The same learned queries are shared by every sample in the batch.
        query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=x,
            return_dict=True,
        )
        q_feat = self.llm_proj(query_output.last_hidden_state)
        # The MLP branch sees the raw features, not the ln_vision-normalised ones.
        mlp_feat = self.connector(x_)
        # Average over the query axis, then re-insert it as a length-1 axis so the
        # vector broadcasts over dim 1 of mlp_feat.
        int_feat = mlp_feat + q_feat.mean(dim=1)[:, None]
        return int_feat + self.ffn(int_feat)