|
|
|
|
|
|
|
|
|
|
|
from functools import partial |
|
from typing import Optional |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
import torch.utils.checkpoint |
|
from peft import LoraConfig, get_peft_model |
|
from timm.models.layers import DropPath |
|
from torch import nn |
|
from transformers import GenerationConfig |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.utils import logging |
|
|
|
from .configuration_internvl import InternVLConfig |
|
from .modeling_intern_vit import (InternVisionEmbeddings, InternVisionEncoder, |
|
InternVisionModel) |
|
from .modeling_qllama import LlamaForCausalLM, _expand_mask, _make_causal_mask |
|
|
|
# FlashAttention is an optional dependency: the module must still import when
# it is unavailable, so a failed import is reported and otherwise ignored.
try:
    from .flash_attention import FlashAttention
except Exception:
    # `Exception` (rather than a bare `except:`) keeps the best-effort
    # behaviour for any import-time failure while letting KeyboardInterrupt
    # and SystemExit propagate.
    print('FlashAttention is not installed.')

logger = logging.get_logger(__name__)
|
|
|
|
|
class InternVLPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = InternVLConfig
    base_model_prefix = 'internvl'
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [
        r'position_ids',
    ]
    _no_split_modules = ['InternAttention', 'LlamaDecoderLayer', 'LlamaForCausalLM']
    _skip_keys_device_placement = 'past_key_values'
    _keep_in_fp32_modules = ['wo']

    def _init_weights(self, module):
        """Initialize the weights of a single sub-module in place.

        * ``nn.Conv2d`` / ``nn.Embedding`` / ``nn.Linear``: weights are drawn
          from ``N(0, config.initializer_range)`` and an existing bias is zeroed.
        * ``InternVisionEmbeddings``: position and class embeddings get a
          truncated normal init, using ``config.vision_config.initializer_range``
          when the config has a ``vision_config``.
        * ``nn.LayerNorm``: bias is zeroed and weight set to one.

        Args:
            module: The sub-module to initialize.
        """
        factor = self.config.initializer_range
        if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=factor)
            # nn.Embedding has no `bias` attribute; Conv2d/Linear may have bias=None.
            if getattr(module, 'bias', None) is not None:
                module.bias.data.zero_()

        if isinstance(module, InternVisionEmbeddings):
            if hasattr(self.config, 'vision_config'):
                factor = self.config.vision_config.initializer_range
            nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
            nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
        elif isinstance(module, nn.LayerNorm):
            # A LayerNorm built without affine parameters (or without a bias)
            # holds None here, so only touch the tensors that exist.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set ``gradient_checkpointing`` on the vision model / vision encoder.

        Args:
            module: Candidate sub-module; anything that is neither an
                ``InternVisionModel`` nor an ``InternVisionEncoder`` is left untouched.
            value: Flag value to assign.
        """
        if isinstance(module, (InternVisionModel, InternVisionEncoder)):
            module.gradient_checkpointing = value
|
|
|
|
|
class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention with separate q / k / v inputs.

    The query, key and value projections are bias-free ``nn.Linear`` layers;
    when ``qkv_bias`` is true, three standalone bias parameters (``q_bias``,
    ``k_bias``, ``v_bias``) are applied on top of them. The attended values are
    projected to ``out_dim`` by ``proj``.

    Args:
        dim: Channel size of the inputs.
        num_heads: Number of attention heads.
        qkv_bias: Whether to add learnable biases to the q / k / v projections.
        qk_scale: Scale applied to the queries; a falsy value selects
            ``head_dim ** -0.5``.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability on the output projection.
        attn_head_dim: Optional explicit per-head size; ``attn_head_dim *
            num_heads`` must still equal ``dim``.
        out_dim: Output channel size; defaults to ``dim``.
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., attn_head_dim=None, out_dim=None):
        super().__init__()
        if out_dim is None:
            out_dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5
        assert all_head_dim == dim, (
            f'head_dim * num_heads ({all_head_dim}) must equal dim ({dim})')

        self.q = nn.Linear(dim, all_head_dim, bias=False)
        self.k = nn.Linear(dim, all_head_dim, bias=False)
        self.v = nn.Linear(dim, all_head_dim, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.k_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, out_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _split_heads(self, t):
        """Reshape ``(B, L, H * d)`` into ``(B, H, L, d)``."""
        return t.reshape(t.shape[0], t.shape[1], self.num_heads, -1).transpose(1, 2)

    def forward(self, x, k=None, v=None):
        """Attend from ``x`` (queries) over ``k`` / ``v``.

        Args:
            x: Query input of shape ``(B, N, dim)``.
            k: Key input of shape ``(B, N_k, dim)``. Defaults to ``x``.
            v: Value input of shape ``(B, N_k, dim)``. Defaults to ``k``.

        Returns:
            Tensor of shape ``(B, N, out_dim)``.
        """
        # The signature advertises optional k / v; fall back to self-attention
        # instead of failing on `None.shape`.
        if k is None:
            k = x
        if v is None:
            v = k

        B, N, _ = x.shape

        # The bias attributes are either all Parameters or all None, so they can
        # be handed to F.linear directly.
        q = self._split_heads(F.linear(input=x, weight=self.q.weight, bias=self.q_bias))
        k = self._split_heads(F.linear(input=k, weight=self.k.weight, bias=self.k_bias))
        v = self._split_heads(F.linear(input=v, weight=self.v.weight, bias=self.v_bias))

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, H, N, d) -> (B, N, H * d)
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
|
|
|
|
|
class AttentiveBlock(nn.Module):
    """Pre-normalized cross-attention block.

    Queries, keys and values each get their own normalization layer before
    being passed to a ``CrossAttention`` module. A ``drop_path`` module is
    built in ``__init__`` but is not applied in ``forward``.
    """

    def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
        super().__init__()

        # One independent norm per attention input.
        self.norm1_q = norm_layer(dim)
        self.norm1_k = norm_layer(dim)
        self.norm1_v = norm_layer(dim)

        attention_kwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            attn_head_dim=attn_head_dim,
            out_dim=out_dim,
        )
        self.cross_attn = CrossAttention(dim, **attention_kwargs)

        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
        """Cross-attend from ``x_q`` over ``x_kv``.

        ``pos_q`` / ``pos_k`` are added to the query / key inputs before
        normalization; the value input is normalized without a positional term.
        ``bool_masked_pos`` and ``rel_pos_bias`` are accepted but unused.
        """
        queries = self.norm1_q(x_q + pos_q)
        keys = self.norm1_k(x_kv + pos_k)
        values = self.norm1_v(x_kv)
        return self.cross_attn(queries, k=keys, v=values)
|
|
|
|
|
class AttentionPoolingBlock(AttentiveBlock):
    """Pool a token sequence into one vector via attention.

    The query is the mean of the input over dimension 1; keys and values are
    the input itself, with no positional terms.
    """

    def forward(self, x):
        """Return the pooled representation of ``x`` with dimension 1 squeezed out."""
        mean_query = x.mean(1, keepdim=True)
        # Positional terms are plain zeros, so the parent block adds nothing.
        pooled = super().forward(mean_query, x, 0, 0, bool_masked_pos=None, rel_pos_bias=None)
        return pooled.squeeze(1)
|
|
|
|
|
class InternVLModel(InternVLPreTrainedModel):
    """Vision encoder + QLLaMA model producing contrastive image/text logits.

    Holds an ``InternVisionModel`` backbone, a ``LlamaForCausalLM`` (``qllama``)
    fed with learned query tokens, two attention-pooling projectors into the
    CLIP embedding space and a text projection matrix.
    """

    config_class = InternVLConfig
    main_input_name = 'pixel_values'

    def __init__(self, config: InternVLConfig):
        super().__init__(config)

        text_hidden_size = config.qllama_config.hidden_size
        vision_hidden_size = config.vision_config.hidden_size
        clip_embed_dim = config.clip_embed_dim
        attn_pool_num_heads = config.attn_pool_num_heads
        # Propagate the query-token count to the language-model config.
        config.qllama_config.num_query_token = config.num_query_token
        self.num_query_token = config.num_query_token
        self.label_smoothing = config.label_smoothing

        self.vision_model = InternVisionModel(config.vision_config)
        self.qllama = LlamaForCausalLM(config.qllama_config)
        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_token, text_hidden_size)
        )

        # Allocated uninitialized; no explicit init is visible in this class
        # (presumably filled when a checkpoint is loaded - TODO confirm).
        self.text_projection = nn.Parameter(torch.empty(text_hidden_size, clip_embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.clip_projector = AttentionPoolingBlock(
            dim=vision_hidden_size, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
            drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
        self.clip_projector2 = AttentionPoolingBlock(
            dim=text_hidden_size, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
            drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
        self.itm_head = nn.Linear(text_hidden_size, 2)
        self.gradient_checkpointing = True

        # The config values double as on/off switch and LoRA rank.
        if config.use_backbone_lora:
            self.wrap_backbone_lora(r=config.use_backbone_lora)
        if config.use_qllama_lora:
            self.wrap_qllama_lora(r=config.use_qllama_lora)
        if config.force_image_size:
            self.vision_model.resize_pos_embeddings(
                old_size=config.vision_config.image_size,
                new_size=config.force_image_size,
                patch_size=config.vision_config.patch_size
            )

    def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        """Wrap ``self.vision_model`` in a PEFT LoRA model and print its trainable-parameter summary."""
        lora_config = LoraConfig(
            r=r,
            target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        self.vision_model = get_peft_model(self.vision_model, lora_config)
        self.vision_model.print_trainable_parameters()

    def wrap_qllama_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
        """Wrap ``self.qllama`` in a PEFT LoRA model and print its trainable-parameter summary."""
        lora_config = LoraConfig(
            r=r,
            target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
                            'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj'],
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        self.qllama = get_peft_model(self.qllama, lora_config)
        self.qllama.print_trainable_parameters()

    def get_input_embeddings(self):
        return self.qllama.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.qllama.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.qllama.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.qllama.get_output_embeddings()

    def _qllama_decoder(self):
        """Return the module of ``self.qllama`` that exposes ``forward_train``.

        ``self.qllama.model`` is used directly unless it is itself a
        ``LlamaForCausalLM`` (presumably the case once ``qllama`` has been
        wrapped by PEFT - TODO confirm), in which case one more ``.model`` hop
        is taken.
        """
        inner = self.qllama.model
        if isinstance(inner, LlamaForCausalLM):
            return inner.model
        return inner

    def _qllama_last_hidden_state(
            self,
            inputs_embeds,
            vision_hidden_states,
            attention_mask,
            output_attentions,
            output_hidden_states,
            return_dict,
    ):
        """Run ``forward_train`` on the QLLaMA decoder and return ``last_hidden_state``."""
        return self._qllama_decoder().forward_train(
            inputs_embeds=inputs_embeds,
            vision_hidden_states=vision_hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        ).last_hidden_state

    @torch.no_grad()
    def generate(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor,
            attention_mask: torch.LongTensor,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate text conditioned on an image.

        The image is encoded by the vision model; the learned query tokens are
        prepended to the embedded ``input_ids`` and the attention mask is
        extended with ones for those positions before delegating to
        ``self.qllama.generate``.

        Args:
            pixel_values: Image batch for the vision model.
            input_ids: Prompt token ids.
            attention_mask: Mask matching ``input_ids``.
            generation_config: Optional generation settings forwarded to ``qllama``.
            output_hidden_states: Forwarded to the vision model.
            return_dict: Forwarded to the vision model.
            **generate_kwargs: Extra arguments for ``self.qllama.generate``.

        Returns:
            Whatever ``self.qllama.generate`` returns.
        """
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict)
        image_embeds = vision_outputs[0]

        batch_size = image_embeds.shape[0]
        input_embeds = self.get_input_embeddings()(input_ids)
        query_tokens = self.query_tokens.repeat(batch_size, 1, 1)
        input_embeds = torch.cat([query_tokens, input_embeds], dim=1)
        image_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)

        outputs = self.qllama.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            vision_hidden_states=image_embeds,
            generation_config=generation_config,
            use_zero_attention_mask=True,
            **generate_kwargs,
        )

        return outputs

    def get_text_features(
            self,
            input_ids: torch.Tensor,
            attention_mask: torch.Tensor,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        r"""
        Run the QLLaMA decoder on text only (no vision hidden states).

        The padding mask is expanded and combined with a causal mask before the
        call.

        Returns:
            The ``last_hidden_state`` of the decoder's ``forward_train`` output.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_embeds = self.get_input_embeddings()(input_ids)
        attention_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
            input_embeds.device)
        attention_mask += _make_causal_mask(
            (attention_mask.shape[0], attention_mask.shape[2]),
            input_embeds.dtype,
            device=input_embeds.device
        )
        return self._qllama_last_hidden_state(
            inputs_embeds=input_embeds,
            vision_hidden_states=None,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def get_image_features(
            self,
            pixel_values: torch.FloatTensor,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        """Encode images with the vision backbone, then with the query tokens.

        The query tokens (repeated per batch item) are the decoder input and the
        vision embeddings are passed as ``vision_hidden_states``; the expanded
        all-ones mask is used without a causal term.

        Returns:
            Tuple ``(backbone_embeds, outputs)``: the first element of the vision
            model output and the decoder's ``last_hidden_state``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict)
        image_embeds = vision_outputs[0]
        backbone_embeds = image_embeds

        batch_size = image_embeds.shape[0]
        input_embeds = self.query_tokens.repeat(batch_size, 1, 1)

        attention_mask = torch.ones(input_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
        attention_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
            input_embeds.device)
        outputs = self._qllama_last_hidden_state(
            inputs_embeds=input_embeds,
            vision_hidden_states=image_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return backbone_embeds, outputs

    def encode_image(self, image, mode):
        """Embed images in the CLIP space.

        Args:
            image: Pixel values for the vision model.
            mode: ``'InternVL-C'`` pools the backbone output with
                ``clip_projector``; ``'InternVL-G'`` additionally pools the
                query-token output with ``clip_projector2`` and returns the sum
                of the two L2-normalized embeddings.

        Raises:
            NotImplementedError: For any other ``mode``.
        """
        if mode == 'InternVL-C':
            vision_outputs = self.vision_model(
                pixel_values=image,
                output_hidden_states=False,
                return_dict=True)
            image_embeds = vision_outputs[0]
            image_embeds = self.clip_projector(image_embeds)
        elif mode == 'InternVL-G':
            backbone_embeds, image_embeds = self.get_image_features(
                pixel_values=image,
                output_hidden_states=False,
                return_dict=True,
            )
            backbone_embeds = self.clip_projector(backbone_embeds)
            image_embeds = self.clip_projector2(image_embeds)

            # Normalize each branch before summing.
            backbone_embeds = backbone_embeds / backbone_embeds.norm(dim=1, keepdim=True)
            image_embeds = image_embeds / image_embeds.norm(dim=1, keepdim=True)
            image_embeds = image_embeds + backbone_embeds
        else:
            raise NotImplementedError
        return image_embeds

    def encode_text(self, text):
        """Embed token ids in the CLIP space.

        Positions with an id greater than 0 count as real tokens. The hidden
        state at index ``(number of real tokens) - 1`` of each row is selected
        and multiplied by ``text_projection``.
        """
        attention_mask = text > 0
        text_embeds = self.get_text_features(
            input_ids=text,
            attention_mask=attention_mask,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        text_embeds = text_embeds[torch.arange(text_embeds.shape[0]), attention_mask.sum(1) - 1]
        text_embeds = text_embeds @ self.text_projection
        return text_embeds

    def forward(self, image, text, mode='InternVL-C'):
        """Compute scaled similarity logits between images and texts.

        Returns:
            Tuple ``(logits_per_image, logits_per_text)``; the second is the
            transpose of the first.
        """
        assert mode in ['InternVL-C', 'InternVL-G'], 'mode must be InternVL-C or InternVL-G'
        image_features = self.encode_image(image, mode)
        text_features = self.encode_text(text)

        # L2-normalize both sides.
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # logit_scale is stored in log space.
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text
|
|
|
|
|
class InternVL_C(InternVLModel):
    """Variant whose image embedding is the attention-pooled vision-backbone output."""

    def encode_image(self, image):
        """Pool the first output of the vision model with ``clip_projector``."""
        backbone_out = self.vision_model(
            pixel_values=image,
            output_hidden_states=False,
            return_dict=True,
        )
        return self.clip_projector(backbone_out[0])

    def encode_text(self, text):
        """Project the hidden state at each row's last id-greater-than-0 position with ``text_projection``."""
        token_mask = text > 0
        hidden = self.get_text_features(
            input_ids=text,
            attention_mask=token_mask,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        rows = torch.arange(hidden.shape[0])
        last_token = token_mask.sum(1) - 1
        return hidden[rows, last_token] @ self.text_projection

    def forward(self, image, text):
        """Return ``(logits_per_image, logits_per_text)`` from L2-normalized embeddings."""
        img = self.encode_image(image)
        txt = self.encode_text(text)

        img = img / img.norm(dim=1, keepdim=True)
        txt = txt / txt.norm(dim=1, keepdim=True)

        # The stored scale is in log space.
        scale = self.logit_scale.exp()
        logits_per_image = scale * img @ txt.t()
        return logits_per_image, logits_per_image.t()
|
|
|
|
|
class InternVL_G(InternVLModel):
    """Variant whose image embedding sums the backbone and query-token branches."""

    def encode_image(self, image):
        """Return the sum of the two L2-normalized, attention-pooled image embeddings.

        The backbone output goes through ``clip_projector`` and the query-token
        output through ``clip_projector2``.
        """
        backbone_feats, query_feats = self.get_image_features(
            pixel_values=image,
            output_hidden_states=False,
            return_dict=True,
        )
        backbone_vec = self.clip_projector(backbone_feats)
        query_vec = self.clip_projector2(query_feats)

        backbone_vec = backbone_vec / backbone_vec.norm(dim=1, keepdim=True)
        query_vec = query_vec / query_vec.norm(dim=1, keepdim=True)
        return query_vec + backbone_vec

    def encode_text(self, text):
        """Project the hidden state at each row's last id-greater-than-0 position with ``text_projection``."""
        token_mask = text > 0
        hidden = self.get_text_features(
            input_ids=text,
            attention_mask=token_mask,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        rows = torch.arange(hidden.shape[0])
        last_token = token_mask.sum(1) - 1
        return hidden[rows, last_token] @ self.text_projection

    def forward(self, image, text):
        """Return ``(logits_per_image, logits_per_text)`` from L2-normalized embeddings."""
        img = self.encode_image(image)
        txt = self.encode_text(text)

        img = img / img.norm(dim=1, keepdim=True)
        txt = txt / txt.norm(dim=1, keepdim=True)

        # The stored scale is in log space.
        scale = self.logit_scale.exp()
        logits_per_image = scale * img @ txt.t()
        return logits_per_image, logits_per_image.t()
|
|