# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Llava-NeXT model."""

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn

from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
from transformers.image_processing_utils import select_best_resolution
from transformers.modeling_outputs import ModelOutput
from transformers.models.auto import CONFIG_MAPPING, AutoModel, AutoModelForCausalLM
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)

# from .configuration_llava_next import LlavaNextConfig

logger = logging.get_logger(__name__)
class LlavaNextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`LlavaNextForConditionalGeneration`]. It is used to instantiate an
    Llava-NeXT model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the [llava-hf/llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf)
    model.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`):
            The config object or dictionary of the vision backbone.
        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
            The config object or dictionary of the text backbone.
        ignore_index (`int`, *optional*, defaults to -100):
            The ignore index for the loss function.
        image_token_index (`int`, *optional*, defaults to 32000):
            The image token index to encode the image prompt.
        projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
            The activation function used by the multimodal projector.
        vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
            The feature selection strategy used to select the vision feature from the vision backbone.
            Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
            If `"full"`, the full vision features are used.
        vision_feature_layer (`int`, *optional*, defaults to -2):
            The index of the layer to select the vision feature.
        image_grid_pinpoints (`List`, *optional*, defaults to `[[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]`):
            A list of possible resolutions to use for processing high resolution images. Each item in the list should be a tuple or list
            of the form `(height, width)`.

    Example:

    ```python
    >>> from transformers import LlavaNextForConditionalGeneration, LlavaNextConfig, CLIPVisionConfig, LlamaConfig

    >>> # Initializing a CLIP-vision config
    >>> vision_config = CLIPVisionConfig()

    >>> # Initializing a Llama config
    >>> text_config = LlamaConfig()

    >>> # Initializing a Llava-Next llava-hf/llava-v1.6-mistral-7b-hf style configuration
    >>> configuration = LlavaNextConfig(vision_config, text_config)

    >>> # Initializing a model from the llava-hf/llava-v1.6-mistral-7b-hf style configuration
    >>> model = LlavaNextForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "llava_next"
    is_composition = False

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=32000,
        projector_hidden_act="gelu",
        vision_feature_select_strategy="default",
        vision_feature_layer=-2,
        image_grid_pinpoints=None,
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.projector_hidden_act = projector_hidden_act
        if vision_feature_select_strategy not in ["default", "full"]:
            raise ValueError(
                "vision_feature_select_strategy should be one of 'default', 'full'. "
                f"Got: {vision_feature_select_strategy}"
            )
        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer
        image_grid_pinpoints = (
            image_grid_pinpoints
            if image_grid_pinpoints is not None
            else [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
        )
        self.image_grid_pinpoints = image_grid_pinpoints

        if isinstance(vision_config, dict):
            vision_config["model_type"] = (
                vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
            )
            vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
        elif vision_config is None:
            vision_config = CONFIG_MAPPING["clip_vision_model"](
                intermediate_size=4096,
                hidden_size=1024,
                patch_size=14,
                image_size=336,
                num_hidden_layers=24,
                num_attention_heads=16,
                vocab_size=32000,
                projection_dim=768,
            )
        self.vision_config = vision_config

        if isinstance(text_config, dict):
            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["llama"]()
        self.text_config = text_config

        super().__init__(**kwargs)
_CONFIG_FOR_DOC = "LlavaNextConfig"

LLAVA_NEXT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "llava-hf/llava-v1.6-mistral-7b-hf",
    # See all LLaVA-NeXT models at https://huggingface.co/models?filter=llava_next
]
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (`tuple`):
            The size of the input image in the format (height, width).
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (height, width).
    """
    if not isinstance(grid_pinpoints, list):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")
    # ! VERY IMPORTANT: if image_size is a tensor, it must be converted to a list first,
    # otherwise select_best_resolution will compute the wrong resolution.
    if not isinstance(image_size, (list, tuple)):
        assert isinstance(
            image_size, (torch.Tensor, np.ndarray)
        ), f"image_size invalid type: {type(image_size)} | {image_size}"
        image_size = image_size.tolist()
    height, width = select_best_resolution(image_size, grid_pinpoints)
    return height // patch_size, width // patch_size
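# A minimal worked example (illustrative comment only, not executed at import time), assuming the
# default grid pinpoints above and the 336-pixel CLIP crop size. For an 800x500 image passed as
# (height, width) == (500, 800), select_best_resolution picks (672, 672), so:
#
#     get_anyres_image_grid_shape(
#         (500, 800),
#         [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]],
#         336,
#     )
#     # -> (2, 2), i.e. a 2 x 2 grid of 336-pixel crops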
def image_size_to_num_patches(image_size, grid_pinpoints, patch_size: int):
    """
    Calculate the number of image patches obtained after preprocessing an image of any resolution,
    including the additional base patch.

    Args:
        image_size (`tuple`):
            The size of the input image in the format (height, width).
        grid_pinpoints (`List`):
            A list containing possible resolutions. Each item in the list should be a tuple or list
            of the form `(height, width)`.
        patch_size (`int`):
            The size of each image patch.

    Returns:
        int: The number of patches, i.e. the number of `patch_size` crops of the best resolution plus one base patch.
    """
    if not isinstance(grid_pinpoints, list):
        raise ValueError("grid_pinpoints should be a list of tuples or lists")
    # ! VERY IMPORTANT: if image_size is a tensor, it must be converted to a list first,
    # otherwise select_best_resolution will compute the wrong resolution.
    if not isinstance(image_size, (list, tuple)):
        assert isinstance(
            image_size, (torch.Tensor, np.ndarray)
        ), f"image_size invalid type: {type(image_size)} | {image_size}"
        image_size = image_size.tolist()
    best_resolution = select_best_resolution(image_size, grid_pinpoints)
    height, width = best_resolution
    num_patches = 0
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            num_patches += 1
    # add the base patch
    num_patches += 1
    return num_patches
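# A rough sanity check of the helper above (kept as a comment so it does not run at import time),
# again assuming the default grid pinpoints and patch_size == 336: for an image of size
# (height, width) == (500, 800) the best resolution is (672, 672), which tiles into 2 * 2 = 4 crops,
# plus the low-resolution base patch:
#
#     image_size_to_num_patches((500, 800), [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]], 336)
#     # -> 5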
def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    Args:
        tensor (`torch.Tensor`):
            The image tensor, assumed to be of shape (num_channels, height, width).
        original_size (`tuple`):
            The original size of the image (height, width).

    Returns:
        `torch.Tensor`: The unpadded image tensor.
    """
    original_height, original_width = original_size
    current_height, current_width = tensor.shape[1:]

    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    if original_aspect_ratio > current_aspect_ratio:
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding : current_height - padding, :]
    else:
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding : current_width - padding]

    return unpadded_tensor
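# Worked example for unpad_image (illustrative comment only): take a feature map of shape
# (C, 24, 24) coming from an originally 400x300 image, passed as original_size == (300, 400).
# The original aspect ratio (400 / 300) is wider than the square feature map, so the padding sits
# on the height axis: scale_factor = 24 / 400, new_height = int(300 * 24 / 400) = 18, and
# padding = (24 - 18) // 2 = 3, giving an unpadded map of shape (C, 18, 24):
#
#     unpad_image(torch.zeros(3, 24, 24), (300, 400)).shape
#     # -> torch.Size([3, 18, 24])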
@dataclass
# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->LlavaNext
class LlavaNextCausalLMOutputWithPast(ModelOutput):
    """
    Base class for LlavaNext causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
            sequence_length, hidden_size)`.

            image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
class LlavaNextMultiModalProjector(nn.Module):
    def __init__(self, config: LlavaNextConfig):
        super().__init__()

        self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)

    def forward(self, image_features):
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
LLAVA_NEXT_START_DOCSTRING = r""" | |
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the | |
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads | |
etc.) | |
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. | |
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage | |
and behavior. | |
Parameters: | |
config ([`LlavaNextConfig`] or [`LlavaNextVisionConfig`]): | |
Model configuration class with all the parameters of the model. Initializing with a config file does not | |
load the weights associated with the model, only the configuration. Check out the | |
[`~PreTrainedModel.from_pretrained`] method to load the model weights. | |
""" | |
# Copied from transformers.models.llava.modeling_llava.LlavaPreTrainedModel with Llava->LlavaNext,llava->llava_next
class LlavaNextPreTrainedModel(PreTrainedModel):
    config_class = LlavaNextConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlavaNextVisionAttention"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        # important: this ported version of LlavaNext isn't meant for training from scratch - only
        # inference and fine-tuning - so the proper init weights code has been removed - the original codebase
        # https://github.com/haotian-liu/LLaVA/tree/main/llava_next should serve for that purpose
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
    @property
    def _supports_sdpa(self):
        """
        Retrieve language_model's attribute to check whether the model supports
        SDPA or not.
        """
        return self.language_model._supports_sdpa
LLAVA_NEXT_INPUTS_DOCSTRING = r""" | |
Args: | |
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): | |
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide | |
it. | |
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and | |
[`PreTrainedTokenizer.__call__`] for details. | |
[What are input IDs?](../glossary#input-ids) | |
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): | |
The tensors corresponding to the input images. Pixel values can be obtained using | |
[`AutoImageProcessor`]. See [`LlavaNextImageProcessor.__call__`] for details. [`LlavaProcessor`] uses | |
[`LlavaNextImageProcessor`] for processing images. | |
image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*): | |
The sizes of the images in the batch, being (height, width) for each image. | |
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): | |
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: | |
- 1 for tokens that are **not masked**, | |
- 0 for tokens that are **masked**. | |
[What are attention masks?](../glossary#attention-mask) | |
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and | |
[`PreTrainedTokenizer.__call__`] for details. | |
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see | |
`past_key_values`). | |
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] | |
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more | |
information on the default strategy. | |
- 1 indicates the head is **not masked**, | |
- 0 indicates the head is **masked**. | |
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | |
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, | |
config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) | |
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): | |
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape | |
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape | |
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. | |
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention | |
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. | |
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that | |
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all | |
`decoder_input_ids` of shape `(batch_size, sequence_length)`. | |
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): | |
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This | |
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the | |
model's internal embedding lookup matrix. | |
vision_feature_layer (`int`, *optional*, defaults to -2): | |
The index of the layer to select the vision feature. | |
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): | |
The feature selection strategy used to select the vision feature from the vision backbone. | |
Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features. | |
If `"full"`, the full vision features are used. | |
use_cache (`bool`, *optional*): | |
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see | |
`past_key_values`). | |
output_attentions (`bool`, *optional*): | |
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned | |
tensors for more detail. | |
output_hidden_states (`bool`, *optional*): | |
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for | |
more detail. | |
return_dict (`bool`, *optional*): | |
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. | |
""" | |
class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel):
    def __init__(self, config: LlavaNextConfig):
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config.vision_config)

        self.multi_modal_projector = LlavaNextMultiModalProjector(config)

        self.image_newline = nn.Parameter(torch.empty(config.text_config.hidden_size, dtype=self.dtype))

        self.vocab_size = config.text_config.vocab_size
        self.language_model = AutoModelForCausalLM.from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings
    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings
    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings
    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings
    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder
    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder
    def get_decoder(self):
        return self.language_model.get_decoder()

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.tie_weights
    def tie_weights(self):
        return self.language_model.tie_weights()

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.resize_token_embeddings
    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds
    def _merge_input_ids_with_image_features(
        self,
        image_features,
        feature_lens,
        inputs_embeds,
        input_ids,
        attention_mask,
        position_ids=None,
        labels=None,
        image_token_index=None,
        ignore_index=-100,
        padding_side: Optional[str] = "left",
    ):
        """
        Merge the text embeddings with a variable number of visual embeddings per image.

        Args:
            input_ids: [batch_size, tlen]
            inputs_embeds: [batch_size, tlen, dt]
            image_features: [all_feat_lens, di]
            feature_lens: [num_images],
                num_images = number of images in the batch,
                each value is the length of the embedding features of one image.
                Note: sum(feature_lens) == all_feat_lens
            labels: None or [batch_size, tlen] --> labels must be extended the same way as input_ids
            padding_side: `left` or `right`,
                must be specified for generation because it cannot always be inferred from input_ids
                (see the edge case below)

        Returns:
            final_embedding, final_attention_mask, position_ids, final_labels

        Explanation:
            Each image has a variable-length embedding, with the length given by feature_lens.
            image_features is the concatenation of all visual embedding vectors.
            Task: fill each <image> token with the correct number of visual embeddings.
            Example:
                X (5 patches), Y (3 patches), Z (8 patches)
                X, Y are in the same sequence (in-context learning)
            if right padding
                input_ids: [
                    a b c d e f X g h i j k Y l m
                    o p q r Z s t u v _ _ _ _ _ _
                ]
                input_ids should be: [
                    a b c d e f X X X X X g h i j k Y Y Y l m
                    o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
                ]
                labels should be: [
                    a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
                    o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
                ]
            elif left padding
                input_ids: [
                    a b c d e f X g h i j k Y l m
                    _ _ _ _ _ _ o p q r Z s t u v
                ]
                input_ids should be: [
                    a b c d e f X X X X X g h i j k Y Y Y l m
                    _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
                ]
                labels should be: [
                    a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
                    _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
                ]
            Edge cases:
                * If the text tokens are identical but the images have different numbers of patches,
                  left vs. right padding cannot be inferred from input_ids alone:
                ```python
                cat_img = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
                chart_img = Image.open(requests.get("https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true", stream=True).raw)
                prompts = [
                    "[INST] <image>\nWhat is shown in this image? [/INST]",
                    "[INST] <image>\nWhat is shown in this image? [/INST]",
                ]
                inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda")
                # chart_img has 2634 tokens, while cat_img has 2340 tokens
                ```
                input_ids: [
                    a b c d X g h
                    i j Y k l m n
                ]
                where X is 3 tokens while Y is 5. This means that, after merging,
                if left-padding (batched generation)
                    input_ids should be: [
                        _ _ a b c d X X X g h
                        i j Y Y Y Y Y k l m n
                    ]
                elif right-padding (training)
                    input_ids should be: [
                        a b c d X X X g h _ _
                        i j Y Y Y Y Y k l m n
                    ]
        """
        image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index
        ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index

        with torch.no_grad():
            # ! in llava 1.6, the number of patches per image is variable
            num_images = feature_lens.size(0)
            num_image_features, embed_dim = image_features.shape
            assert (
                feature_lens.sum() == num_image_features
            ), f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}"
            batch_size, sequence_length = input_ids.shape

            _left_padding = torch.any(attention_mask[:, 0] == 0)
            _right_padding = torch.any(attention_mask[:, -1] == 0)
            if _left_padding and not _right_padding:
                left_padding = True
            elif not _left_padding and _right_padding:
                left_padding = False
            elif not _left_padding and not _right_padding:
                # both sides are fully attended, so the padding side cannot be inferred; fall back to the argument
                left_padding = padding_side == "left"
            else:
                # invalid attention_mask
                raise ValueError(f"both sides of attention_mask have zeros, invalid. {attention_mask}")

            # 1. Create a mask to know where special image tokens are
            special_image_token_mask = input_ids == image_token_index
            # special_image_token_mask: [bsz, seqlen]
            num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
            # num_special_image_tokens: [bsz]
            # the total number of <image> tokens must match the number of images
            total_num_special_image_tokens = torch.sum(special_image_token_mask)
            assert total_num_special_image_tokens == num_images, (
                f"{total_num_special_image_tokens=} != {num_images=} | {image_features.shape} {input_ids}"
            )

            # Compute the maximum embed dimension
            # max_image_feature_lens is max_feature_lens per batch
            feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
            feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=feature_lens.device)
            embed_sequence_lengths = (
                (attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
            )
            max_embed_dim = embed_sequence_lengths.max()

            batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1))

            # 2. Compute the positions where text should be written
            # Calculate new positions for text tokens in merged image-text sequence.
            # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
            # `torch.cumsum` computes how each image token shifts subsequent text token positions.
            # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
            # ! instead of special_image_token_mask * (num_image_patches - 1), we use
            #   special_image_token_mask * (num_feature_len - 1), since the feature length varies per image
            special_image_len_mask = special_image_token_mask.clone().long()
            special_image_len_mask[special_image_len_mask == 1] = feature_lens - 1
            new_token_positions = torch.cumsum((special_image_len_mask + 1), -1) - 1
            if left_padding:
                # shift the token positions to the right so that every sequence ends at the same position
                new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:]

            text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )
        final_attention_mask = torch.zeros(
            batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
        )
        final_labels = None
        if labels is not None:
            final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long)

        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
        target_device = inputs_embeds.device
        batch_indices, non_image_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_image_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
        if labels is not None:
            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

        # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling
        with torch.no_grad():
            image_to_overwrite = torch.all(final_embedding == 0, dim=-1)
            if left_padding:
                # exclude padding on the left
                val = (
                    max_embed_dim
                    - torch.arange(max_embed_dim).unsqueeze(0).to(target_device).expand(batch_size, max_embed_dim)
                ) <= embed_sequence_lengths[:, None].to(target_device)
                image_to_overwrite &= val
            else:
                # exclude padding on the right
                val = torch.arange(max_embed_dim).unsqueeze(0).to(target_device).expand(
                    batch_size, max_embed_dim
                ) < embed_sequence_lengths[:, None].to(target_device)
                image_to_overwrite &= val

            if image_to_overwrite.sum() != num_image_features:
                raise ValueError(
                    f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model is wrong. "
                    f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
                    f" the number of images given to the model is {num_images}. "
                    f"This prevents correct indexing and breaks batch generation."
                )

        final_embedding[image_to_overwrite] = image_features.to(target_device)
        final_attention_mask |= image_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

        if not left_padding:
            # Sanity check: with right padding, every sequence must be a contiguous block of 1s followed by 0s
            seq_lens = final_attention_mask.sum(-1)
            for i, (mask, seq_len) in enumerate(zip(final_attention_mask, seq_lens)):
                assert torch.all(
                    mask[:seq_len] == 1
                ), f"final 1 mask[{i}]: {seq_len=} {final_attention_mask.size()=} {final_attention_mask.tolist()=} \n{text_to_overwrite.tolist()=}"
                assert torch.all(
                    mask[seq_len:] == 0
                ), f"final 0 mask[{i}]: {seq_len=} {final_attention_mask.size()=} {final_attention_mask.tolist()=}"

        return final_embedding, final_attention_mask, position_ids, final_labels
    def pack_image_features(self, image_features, image_sizes, image_newline=None):
        """
        Pack a list of per-image features into a single flat tensor.

        Args:
            image_features: list of length num_images, each element of shape [num_patches, num_patch_features, embed_dim]
            image_sizes: the original (height, width) of each image
            image_newline: optional learned "newline" embedding appended after each row of patches

        Returns:
            image_features: [all_feat_len, embed_dim]
            feature_lens: [num_images], the number of packed feature vectors per image
        """
        new_image_features = []
        feature_lens = []
        for image_idx, image_feature in enumerate(image_features):
            if image_feature.shape[0] > 1:
                base_image_feature = image_feature[0]
                image_feature = image_feature[1:]
                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
                if height * width != base_image_feature.shape[0]:
                    raise ValueError("The number of patches is not consistent with the image size.")
                # get_anyres_image_grid_shape returns the grid as (num_patch_height, num_patch_width),
                # matching the (height, width) convention used by select_best_resolution
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    image_sizes[image_idx],
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                image_feature = unpad_image(image_feature, image_sizes[image_idx])
                if image_newline is not None:
                    image_feature = torch.cat(
                        (
                            image_feature,
                            image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature),
                        ),
                        dim=-1,
                    )
                image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                image_feature = torch.cat((base_image_feature, image_feature), dim=0)
            else:
                image_feature = image_feature[0]
                if image_newline is not None:
                    image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
            new_image_features.append(image_feature)
            feature_lens.append(image_feature.size(0))
        image_features = torch.cat(new_image_features, dim=0)
        feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
        return image_features, feature_lens
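    # Shape walk-through for pack_image_features (illustrative comment only, assuming the default
    # CLIP-L/336 tower: image_size 336, patch_size 14, so height = width = 24 patch tokens per crop).
    # For one image whose best resolution yields a 2 x 2 grid (5 crops including the base patch),
    # the projected features arrive as [5, 576, embed_dim]; the 4 high-resolution crops are reshaped
    # to [embed_dim, 48, 48], unpadded along the padded axis, extended with one image_newline column
    # per row, flattened back to [n_tokens, embed_dim], and concatenated after the 576 base tokens.
    # feature_lens then records that per-image token count, so the merge step knows how many slots
    # each <image> placeholder must expand to.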
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        padding_side: Optional[str] = "left",
    ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, LlavaNextForConditionalGeneration

        >>> model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
        >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")

        >>> prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot (...)"
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )
        if inputs_embeds is None:
            # 1. Extract the input embeddings
            # In case image_token_index is not in the embeddings (extra token, but the embedding matrix does not
            # have a row for it), map it to a valid index before the lookup.
            for_inputs_embeds_ids = input_ids.clone()
            for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
            inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)

            # 2. Merge text and images
            if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
                # ! infer image_num_patches from image_sizes
                image_num_patches = [
                    image_size_to_num_patches(
                        image_size=imsize,
                        grid_pinpoints=self.config.image_grid_pinpoints,
                        patch_size=self.config.vision_config.image_size,
                    )
                    for imsize in image_sizes
                ]
                image_features = self.vision_tower(pixel_values, output_hidden_states=True)
                selected_image_feature = image_features.hidden_states[vision_feature_layer]

                if vision_feature_select_strategy == "default":
                    selected_image_feature = selected_image_feature[:, 1:]
                elif vision_feature_select_strategy == "full":
                    selected_image_feature = selected_image_feature

                image_features = self.multi_modal_projector(selected_image_feature)

                image_features = torch.split(image_features, image_num_patches, dim=0)

                # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
                height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size

                image_features, feature_lens = self.pack_image_features(
                    image_features,
                    image_sizes,
                    image_newline=self.image_newline,
                )

                inputs_embeds, attention_mask, position_ids, labels = self._merge_input_ids_with_image_features(
                    image_features,
                    feature_lens,
                    inputs_embeds,
                    input_ids,
                    attention_mask,
                    position_ids,
                    labels=labels,
                    padding_side=padding_side,
                )

            # pixel_values is not None but is empty ---> text-only case
            elif pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) == 0:
                # there are no images
                pass

            # In case input_ids.shape[1] == 1 & pixel_values is not None & past_key_values is not None,
            # we are in the case of generation with cache
            elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
                # Retrieve the first layer to inspect the logits and mask out the hidden states
                # that are set to 0
                first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                # Get the target length
                target_seqlen = first_layer_past_key_value.shape[-1] + 1
                extended_attention_mask = torch.ones(
                    (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )

                # Filter out only the tokens that can be un-attended, this can happen
                # if one uses Llava + Fused modules where the cache on the
                # first iteration is already big enough, or if one passes custom cache
                valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                new_batch_index = batch_index[valid_indices]
                new_non_attended_tokens = non_attended_tokens[valid_indices]

                # Zero-out the places where we don't need to attend
                extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

                # !(nxphi47) must ensure left-padding here;
                # attention_mask is the newly incoming mask, while extended_attention_mask covers the cached tokens
                assert padding_side == "left", f"{padding_side=} is invalid for batched generation mode"
                attention_mask = torch.cat((extended_attention_mask, attention_mask), dim=1)
                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                shift_attention_mask = attention_mask[..., 1:]
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return LlavaNextCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        pixel_values=None,
        image_sizes=None,
        attention_mask=None,
        **kwargs,
    ):
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
            elif self.config.image_token_index in input_ids:
                input_ids = input_ids[:, input_ids.shape[1] - 1 :]
            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
            # older attention values, as their corresponding values are not part of the input.
            if cache_length < past_length and attention_mask is not None:
                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "image_sizes": image_sizes,
            }
        )
        return model_inputs
    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._reorder_cache
    def _reorder_cache(self, *args, **kwargs):
        return self.language_model._reorder_cache(*args, **kwargs)