# cogcap / modeling_cogvlm.py
# Uploaded by asabet in commit b5aa3ad (verified): "Upload CogVLMForCausalLM".
"""Largely copied from LLaMA and adapted for CogVLM."""
import warnings
import packaging.version
from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
import math
import torch
import transformers
from torch import nn
from torch.nn import CrossEntropyLoss
from torchvision import transforms
from einops import rearrange
from torch.utils.checkpoint import checkpoint
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.utils.logging import get_logger
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from .configuration_cogvlm import CogVLMConfig
from .util import FastRotaryEmbedding
from .visual import EVA2CLIPModel
if TYPE_CHECKING:
    from transformers.utils import ModelOutput

logger = get_logger(__name__)

# Values stored in `token_type_ids`: text tokens vs. image (vision) tokens.
LANGUAGE_TOKEN_TYPE = 0
VISION_TOKEN_TYPE = 1

# Despite its name, this flag is True for transformers >= 4.42.0. It selects between the two
# `_extract_past_from_model_output` calling conventions in `_update_model_kwargs_for_generation`.
# The name is kept unchanged for backward compatibility.
TRANSFORMERS_ABOVE_441 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.42.0")
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned per-channel scale.

    The statistic is computed in float32; the result is cast back to the input dtype.
    """

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        normalized = x * inv_rms
        return (self.weight * normalized).to(original_dtype)
class MLP(nn.Module):
    """Gated feed-forward network: ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    The activation is ``ACT2FN[config.hidden_act]``; none of the projections have a bias.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        width, inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(width, inner, bias=False)
        self.up_proj = nn.Linear(width, inner, bias=False)
        self.down_proj = nn.Linear(inner, width, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
def get_expert_mask(token_type_ids: "torch.LongTensor(B, L)") -> "[torch.BoolTensor(B, L), torch.BoolTensor(B, L)]":
    """Split positions into vision-expert and language-expert masks.

    A position uses the vision expert only if both it and the following position are vision tokens.
    The last token of each vision run, and the final column, therefore go to the language expert.
    The two returned masks are exact complements.
    """
    is_vision = token_type_ids == VISION_TOKEN_TYPE
    vision_token_mask = torch.zeros_like(token_type_ids, dtype=torch.bool)
    vision_token_mask[:, :-1] = is_vision[:, :-1] & is_vision[:, 1:]
    return vision_token_mask, ~vision_token_mask
class VisionExpertMLP(nn.Module):
    """Two independent MLPs; each token is processed by the one selected by ``get_expert_mask``."""

    def __init__(self, config):
        super().__init__()
        self.language_mlp = MLP(config)
        self.vision_mlp = MLP(config)

    def forward(self, hidden_states: "torch.Tensor(B, L, D)", token_type_ids: "torch.LongTensor(B, L)"):
        vision_mask, language_mask = get_expert_mask(token_type_ids)
        # The two masks cover every position, so every element of `result` is written below.
        result = hidden_states.new_empty(hidden_states.shape)
        for mask, expert in ((vision_mask, self.vision_mlp), (language_mask, self.language_mlp)):
            result[mask] = expert(hidden_states[mask])
        return result
# Parsed once at import time rather than on every attention call.
_TORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])


def attention_fn(
    query_layer: "torch.tensor(B, H, L, HD)",
    key_layer: "torch.tensor(B, H, S, HD)",
    value_layer: "torch.tensor(B, H, S, HD)",
    attention_mask: "Optional[torch.tensor(B, 1, L, S)]",
    *,
    scaling_attention_score: bool = True,
    attention_dropout: nn.Module = None
):
    """Scaled dot-product attention with an additive mask.

    Args:
        query_layer, key_layer, value_layer: per-head projections.
        attention_mask: additive mask (0 = attend, large negative = blocked) broadcastable to the score
            matrix, or ``None`` for unrestricted (full) attention.
        scaling_attention_score: divide scores by ``sqrt(head_dim)``.
        attention_dropout: optional dropout module applied to the attention probabilities.

    The fused ``torch.nn.functional.scaled_dot_product_attention`` kernel is used on torch >= 2 when
    scaling is on and the mask is either all-zero (full attention) or exactly the top-left-aligned
    lower-triangular causal pattern. Any other mask (e.g. padding, or causal with a KV-cache offset)
    takes the explicit matmul/softmax path.
    """
    if _TORCH_MAJOR_VERSION < 2:
        warnings.warn("It's recommended to use torch2.0 or higher.")
    if attention_mask is None:
        is_full, is_low_triangle = True, False
    else:
        attention_mask_bool = (attention_mask == 0)
        is_low_triangle = (attention_mask_bool == torch.ones_like(attention_mask_bool, dtype=torch.float).tril()).all()
        is_full = attention_mask_bool.all()
    if _TORCH_MAJOR_VERSION >= 2 and scaling_attention_score and (is_full or is_low_triangle):
        dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
        return torch.nn.functional.scaled_dot_product_attention(
            query_layer, key_layer, value_layer,
            attn_mask=None,
            dropout_p=dropout_p,
            # The mask was verified to be exactly causal (tril) or all-visible, so it can be dropped.
            is_causal=not is_full
        )
    if scaling_attention_score:
        query_layer = query_layer / math.sqrt(query_layer.shape[-1])
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    if attention_mask is not None:
        attention_scores = attention_scores + attention_mask
    # Softmax in float32 for stability, then back to the activation dtype.
    attention_scores = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(query_layer.dtype)
    if attention_dropout is not None:
        attention_scores = attention_dropout(attention_scores)
    context_layer = torch.matmul(attention_scores, value_layer)
    return context_layer
class VisionExpertAttention(nn.Module):
    """Grouped-query self-attention with separate vision-expert and language-expert projections.

    Each token is routed (via ``get_expert_mask``) to either the vision-expert or the language-expert
    fused QKV and output projections. The attention computation itself is shared by all tokens.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.num_multi_query_heads = config.num_multi_query_heads
        self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
        # Split sizes (in units of head_dim) of the fused QKV output: query heads, key heads, value heads.
        self.stride = [self.num_attention_heads, self.num_multi_query_heads, self.num_multi_query_heads]
        self.qkv_size = self.hidden_size + self.hidden_size_per_attention_head * self.num_multi_query_heads * 2
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False, base=500000)
        self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.qkv_size, bias=True)
        self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.qkv_size, bias=False)
        self.language_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def _transpose_for_scores(self, tensor):
        """Transpose a 3D tensor [B, L, N*HD] into a 4D tensor with size [B N L HD].

        N is inferred, so the same helper serves query heads and the fewer multi-query KV heads.
        """
        new_tensor_shape = tensor.size()[:-1] + \
            (-1,  # flexible for multi-query
             self.hidden_size_per_attention_head)
        tensor = tensor.view(*new_tensor_shape)
        return tensor.permute(0, 2, 1, 3)

    def _expand_kv_heads(self, tensor):
        """Repeat KV heads so every query head has one: [B, KVH, S, HD] -> [B, H, S, HD].

        Query head ``h`` is paired with KV head ``h // (H // KVH)``.
        """
        n_rep = self.num_attention_heads // self.num_multi_query_heads
        return tensor.unsqueeze(2).expand(-1, -1, n_rep, -1, -1).contiguous().view(
            tensor.shape[0], self.num_attention_heads, *tensor.shape[2:])

    def forward(
        self,
        hidden_states: torch.Tensor,
        token_type_ids: torch.LongTensor,
        position_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Attend over the (cached + new) sequence.

        Returns ``(attn_output, None, past_key_value)``. Attention weights are never returned, and
        ``past_key_value`` is ``None`` unless ``use_cache``.
        """
        bsz, q_len, _ = hidden_states.size()
        vision_token_mask, language_token_mask = get_expert_mask(token_type_ids)

        # Fused QKV projection, routed per token to the vision or language expert.
        shape = list(hidden_states.shape)
        shape[-1] = self.qkv_size
        mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device)
        mixed_raw_layer[vision_token_mask] = self.vision_expert_query_key_value(hidden_states[vision_token_mask])
        mixed_raw_layer[language_token_mask] = self.language_expert_query_key_value(hidden_states[language_token_mask])

        factor = mixed_raw_layer.size()[-1] // sum(self.stride)  # == head_dim, since qkv_size == head_dim * sum(stride)
        query_states, key_states, value_states = torch.split(mixed_raw_layer, [factor * x for x in self.stride], dim=-1)
        query_states = self._transpose_for_scores(query_states)  # B, H, L, HD
        key_states = self._transpose_for_scores(key_states)  # B, KVH, L, HD
        value_states = self._transpose_for_scores(value_states)  # B, KVH, L, HD

        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1)

        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        # The cache stores the compact (un-expanded) KV heads.
        past_key_value = (key_states, value_states) if use_cache else None

        key_states = self._expand_kv_heads(key_states)
        value_states = self._expand_kv_heads(value_states)

        context_layer = attention_fn(
            query_layer=query_states, key_layer=key_states, value_layer=value_states, attention_mask=attention_mask,
            scaling_attention_score=True, attention_dropout=None)
        if context_layer.size() != (bsz, self.num_attention_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_attention_heads, q_len, self.head_dim)}, but is"
                f" {context_layer.size()}"
            )
        context_layer = context_layer.transpose(1, 2).contiguous().reshape(bsz, q_len, self.hidden_size)

        # Output projection, routed with the same expert masks.
        attn_output = torch.empty(context_layer.shape, dtype=hidden_states.dtype, device=hidden_states.device)
        attn_output[vision_token_mask] = self.vision_expert_dense(context_layer[vision_token_mask])
        attn_output[language_token_mask] = self.language_expert_dense(context_layer[language_token_mask])

        if output_attentions:
            warnings.warn("output_attentions is not implemented.")

        return attn_output, None, past_key_value
class CogVLMDecoderLayer(nn.Module):
    """Pre-norm decoder block: RMSNorm -> vision-expert attention -> residual, then RMSNorm -> vision-expert MLP -> residual."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = VisionExpertAttention(config=config)
        self.mlp = VisionExpertMLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        token_type_ids: torch.LongTensor,
        position_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Return ``(hidden_states[, attn_weights][, present_key_value])``; the optional entries follow the flags."""
        # Attention sub-block.
        attn_out, attn_weights, present_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states), token_type_ids=token_type_ids)
        hidden_states = hidden_states + mlp_out

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        if use_cache:
            outputs += (present_key_value,)
        return outputs  # type: ignore
class CogVLMPreTrainedModel(PreTrainedModel):
    """Shared Hugging Face plumbing (config class, weight init, device-placement hints) for CogVLM models."""
    config_class = CogVLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False
    _no_split_modules = ["CogVLMDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """Normal(0, initializer_range) weights; zero Linear biases and the Embedding padding row."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return
        if not isinstance(module, nn.Embedding):
            return
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
def is_empty(images_list: Optional[List[List[torch.Tensor]]]):
    """Return True when there is no image at all: ``None``, no samples, or only empty per-sample lists."""
    if images_list is None or not len(images_list):
        return True
    return all(len(image_list) == 0 for image_list in images_list)
def build_position_ids(x: "torch.BoolTensor(B, L)", attention_mask: Optional["torch.BoolTensor(B, L)"] = None) -> "torch.LongTensor(B, L)":
    """Build position ids in which the interior tokens of an image share a single position.

    The first and last token of every vision run (boi/eoi) are counted like language tokens. Language
    tokens, and the first interior vision token after a language token, advance the position by one.
    Masked-out positions (``attention_mask == 0``) never advance it. Column 0 always starts at 0.
    """
    # Copy so the caller's tensor is untouched; masked positions get -1, which matches neither token type.
    kinds = x.clone()
    if attention_mask is not None:
        kinds[~(attention_mask.bool())] = -1

    # image boi eoi token as LANGUAGE_TOKEN_TYPE
    is_vision = kinds == VISION_TOKEN_TYPE
    is_language = kinds == LANGUAGE_TOKEN_TYPE
    boundary = torch.zeros_like(x, dtype=torch.bool)
    boundary[:, 1:] |= is_vision[:, 1:] & is_language[:, :-1]   # vision run starts after language
    boundary[:, 0] |= is_vision[:, 0]                             # vision run starts the row
    boundary[:, :-1] |= is_vision[:, :-1] & is_language[:, 1:]  # vision run ends before language
    boundary[:, -1] |= is_vision[:, -1]                           # vision run ends the row
    kinds[boundary] = LANGUAGE_TOKEN_TYPE

    # final position ids: cumulative count of "advance" steps
    step = torch.zeros_like(x, dtype=torch.long)
    step[:, 1:] = (kinds[:, 1:] == LANGUAGE_TOKEN_TYPE) | (
        (kinds[:, 1:] == VISION_TOKEN_TYPE) & (kinds[:, :-1] == LANGUAGE_TOKEN_TYPE)
    )
    return step.cumsum(dim=-1)
class CogVLMModel(CogVLMPreTrainedModel):
    """CogVLM backbone: token embeddings, an EVA2-CLIP vision tower and a stack of vision-expert decoder layers.

    Image features replace the embeddings at positions whose ``token_type_ids`` equal ``VISION_TOKEN_TYPE``.
    Each decoder layer then routes tokens to vision- or language-expert weights.
    """

    def __init__(self, config):
        super().__init__(config)
        # Hard-coded padding id; `build_conversation_input_ids` uses the same id for image placeholder tokens.
        self.padding_idx = 128002
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([CogVLMDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.vision = EVA2CLIPModel(config)
        # Not consulted by `llm_forward`, which checkpoints layers whenever autograd is recording.
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def encode_images(self, images: List[List[torch.Tensor]]) -> torch.Tensor:
        """Flatten the per-sample image lists (sample order, then image order), stack them and run the vision tower."""
        flat_images = [image for image_list in images for image in image_list]
        images_features = self.vision(torch.stack(flat_images))
        return images_features

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        images: List[List[torch.Tensor]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Take care of image encoding, ``token_type_ids`` and ``position_ids``, then call ``llm_forward``.

        Without ``past_key_values``, ``input_ids`` is required and ``inputs_embeds`` must be ``None``,
        because image features are merged into the token embeddings here. With ``past_key_values`` (a
        generation step), the inputs are forwarded as-is. ``attention_mask=None`` is fine.
        """
        if past_key_values is None:
            # not allow for inputs_embeds, because we want to process image feature
            assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
            if not is_empty(images):  # multi-modality
                assert token_type_ids is not None, f"multi-modality requires `token_type_ids`!"
                assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
                inputs_embeds = self.embed_tokens(input_ids)
                images_features = self.encode_images(images)
                images_features = rearrange(images_features, 'b n d -> (b n) d')
                images_features = images_features.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
                # Scatter the image features, in order, into the vision-token positions.
                inputs_embeds = inputs_embeds.index_put([token_type_ids == VISION_TOKEN_TYPE], images_features)
            else:  # single-modality
                if token_type_ids is None:
                    token_type_ids = torch.full_like(input_ids, LANGUAGE_TOKEN_TYPE, dtype=torch.long)
                assert not (token_type_ids == VISION_TOKEN_TYPE).any(), f"{(token_type_ids == VISION_TOKEN_TYPE).sum()}"
                inputs_embeds = self.embed_tokens(input_ids)
            if position_ids is None:
                position_ids = build_position_ids(token_type_ids, attention_mask)
            input_ids = None
        # else: generation step with past_key_values; the image features are already in the cache.

        return self.llm_forward(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def _decoder_layer_forward(self, decoder_layer, hidden_states, **layer_kwargs):
        """Run one decoder layer, activation-checkpointed whenever autograd is recording."""
        def layer_fn(states):
            return decoder_layer(states, **layer_kwargs)

        if torch.is_grad_enabled():
            # Recompute activations during backward instead of storing them (memory saving for training).
            return checkpoint(layer_fn, hidden_states, use_reentrant=False)
        # Nothing will be backpropagated (e.g. torch.no_grad() during generation): skip the checkpoint overhead.
        return layer_fn(hidden_states)

    def llm_forward(
        self,
        input_ids: torch.LongTensor = None,
        token_type_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """largely copy from llama forward and adapt for cogvlm with `token_type_ids`"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values is not None:
            # Cached keys are (B, KVH, S_past, HD).
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            past_key_value = past_key_values[idx] if past_key_values is not None else None
            layer_outputs = self._decoder_layer_forward(
                decoder_layer,
                hidden_states,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # noinspection PyMethodMayBeStatic
    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        """Combine the causal mask (only when more than one new token) with the expanded padding mask.

        Returns an additive mask of shape [bsz, 1, tgt_seq_len, src_seq_len], or ``None`` when both parts are absent.
        """
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask
def _history_to_prompt(signal_type, history, query):
if signal_type == 'base':
return query
elif signal_type == 'vqa':
answer_format = 'Short answer:'
elif signal_type == 'chat':
answer_format = 'Answer:'
else:
assert False, f"Unknown signal type {signal_type}"
prompt = ''
for i, (old_query, response) in enumerate(history):
prompt += 'Question: ' + old_query + " {} ".format(answer_format) + response + "\n"
prompt += 'Question: {} {}'.format(query, answer_format)
return prompt
class CogVLMForCausalLM(CogVLMPreTrainedModel):
    """CogVLM backbone plus a language-modeling head, with generation hooks and a prompt builder."""
    _auto_class = "AutoModelForCausalLM"

    def __init__(self, config):
        super().__init__(config)
        self.model = CogVLMModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        images: List[List[torch.Tensor]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone and project to vocabulary logits, which are upcast to float32.

        When ``labels`` is given, the loss is next-token cross-entropy: logits at position ``t`` are
        scored against ``labels[..., t + 1]``. Label value -100 is ignored (the ``CrossEntropyLoss``
        default ``ignore_index``).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            images=images,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _prepare_attention_mask_for_generation(
        self,
        inputs: torch.Tensor,
        pad_token_id: Optional[int],
        eos_token_id: Optional[Union[int, List[int]]],
    ) -> torch.LongTensor:
        """Always an all-ones mask over ``inputs.shape[:2]``; pad/eos ids are deliberately ignored."""
        return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)  # type: ignore

    def prepare_inputs_for_generation(
        self, input_ids, token_type_ids=None, images=None, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Assemble the model inputs for one generation step.

        ``token_type_ids`` may be omitted for text-only prompts; every token is then a language token,
        the same default ``CogVLMModel.forward`` applies. Unless supplied via ``kwargs``, position ids
        are built from the full ``token_type_ids`` / ``attention_mask``. Once a KV cache exists, only
        the newest token is fed.
        """
        if token_type_ids is None:
            token_type_ids = torch.full_like(input_ids, LANGUAGE_TOKEN_TYPE)

        # build position_ids if needed
        position_ids = kwargs.get("position_ids", None)
        if position_ids is None:
            position_ids = build_position_ids(token_type_ids, attention_mask)

        if past_key_values:
            input_ids = input_ids[:, -1:]
            token_type_ids = token_type_ids[:, -1:]
            position_ids = position_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "token_type_ids": token_type_ids,
                "images": images,
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    def _update_model_kwargs_for_generation(
        self,
        outputs: "ModelOutput",
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        """Advance the KV cache, ``token_type_ids`` and the attention mask by one generated token."""
        # update past_key_values
        if TRANSFORMERS_ABOVE_441:
            # Newer transformers returns a (cache_name, cache) pair.
            cache_name, cache = self._extract_past_from_model_output(outputs)
            model_kwargs[cache_name] = cache
        else:
            model_kwargs["past_key_values"] = self._extract_past_from_model_output(
                outputs, standardize_cache_format=standardize_cache_format
            )
        if getattr(outputs, "state", None) is not None:
            model_kwargs["state"] = outputs.state

        # The newly generated token is always a language token. The key may be absent, or present with
        # value None (e.g. `token_type_ids=None` passed to generate); skip it in both cases.
        token_type_ids = model_kwargs.get("token_type_ids")
        if token_type_ids is not None:
            new_token_type_ids = token_type_ids.new_full((token_type_ids.shape[0], 1), LANGUAGE_TOKEN_TYPE)
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)

        if not is_encoder_decoder:
            # update attention mask
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                )
        else:
            # update decoder attention mask
            if "decoder_attention_mask" in model_kwargs:
                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
                model_kwargs["decoder_attention_mask"] = torch.cat(
                    [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
                    dim=-1,
                )

        return model_kwargs

    def _reorder_cache(self, past_key_values, beam_idx):
        """Select, for every layer's cached tensors, the batch rows given by ``beam_idx``."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past

    def build_conversation_input_ids(
        self,
        tokenizer: "PreTrainedTokenizer",
        *,
        query: str,
        history: Optional[List[Tuple[str, str]]] = None,
        images: Optional[List["PIL.Image"]] = None,
        template_version: Optional[Literal["base", "chat", "vqa"]] = None,
        answer: str = None,
    ):
        """Tokenize one conversation (at most one image) into model-ready tensors.

        The layout is ``[bos] + [image placeholders] + prompt tokens (+ answer tokens + eos)``. Image
        placeholders use id 128002, and ``tokenizer.pad_token_id`` is set to that value as a side
        effect. When ``answer`` is given, ``labels`` is -100 everywhere except on the answer tokens;
        otherwise ``labels`` is ``None``.

        Returns:
            dict with ``input_ids``, ``token_type_ids``, ``attention_mask`` (1-D long tensors),
            ``images`` (list with the transformed image tensor, or the input as given) and ``labels``.
        """
        image_size: int = self.config.vision_config['image_size']
        patch_size: int = self.config.vision_config['patch_size']
        template_version = template_version or self.config.template_version
        assert images is None or len(images) <= 1, f"not support multi images by now."
        history = history or []
        text = _history_to_prompt(template_version, history, query)

        input_ids = [tokenizer.bos_token_id]
        token_type_ids = [LANGUAGE_TOKEN_TYPE]
        if images is not None and len(images) == 1:
            # vision
            transform = transforms.Compose(
                [
                    transforms.Resize(
                        (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC
                    ),
                    transforms.ToTensor(),
                    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
                ]
            )
            images = [transform(images[0])]
            # language
            # One placeholder per 2x2 patch group of the image grid, plus 2 (the boi/eoi ends of the vision run).
            vision_token_num = (image_size // patch_size // 2) * (image_size // patch_size // 2) + 2
            tokenizer.pad_token_id = 128002  # llama3 adapt for cogvlm
            input_ids += [tokenizer.pad_token_id] * vision_token_num
            token_type_ids += [VISION_TOKEN_TYPE] * vision_token_num

        text_ids = tokenizer.encode(text, add_special_tokens=False)
        if answer is not None:
            answer_ids = tokenizer.encode(answer, add_special_tokens=False)
            answer_ids += [tokenizer.eos_token_id]
            text_ids += answer_ids

        input_ids += text_ids
        token_type_ids += [LANGUAGE_TOKEN_TYPE] * len(text_ids)
        attention_mask = [1] * len(input_ids)
        if answer is not None:
            # Only the answer tokens (incl. eos) are supervised.
            labels = [-100 for _ in range(len(input_ids) - len(answer_ids))] + answer_ids
            labels = torch.tensor(labels, dtype=torch.long)
        else:
            labels = None

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'images': images,
            'labels': labels,
        }