from typing import Optional, Tuple, Union, Any
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    PreTrainedModel,
    CLIPVisionModel,
)
from transformers.utils import logging, ModelOutput

from .configuration_llava import LlavaConfig


logger = logging.get_logger(__name__)


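# Output container for LlavaForConditionalGeneration: bundles the language
# model's loss and logits with the raw vision-tower and language-model
# outputs so callers can inspect either sub-model directly.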
@dataclass
class LlavaForConditionalGenerationModelOutput(ModelOutput):
    loss: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    vision_outputs: Optional[torch.FloatTensor] = None
    language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k]
            if k not in ["vision_outputs", "language_model_outputs"]
            else getattr(self, k).to_tuple()
            for k in self.keys()
        )


class LlavaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = LlavaConfig
    base_model_prefix = "llava"

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_range
        if (
            isinstance(module, nn.Conv2d)
            or isinstance(module, nn.Embedding)
            or isinstance(module, nn.Linear)
        ):
            module.weight.data.normal_(mean=0.0, std=factor)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


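# LLaVA-style conditional generation model: a CLIP vision tower encodes the
# image, an MLP projector maps the selected vision features into the language
# model's embedding space, and the projected image tokens are prepended to the
# text token embeddings before running a decoder-only or seq2seq language
# model.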
class LlavaForConditionalGeneration(LlavaPreTrainedModel):
    config_class = LlavaConfig
    main_input_name = "pixel_values"
    _no_split_modules = []
    _keep_in_fp32_modules = []

    def __init__(self, config: LlavaConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionModel(config.vision_config)
        if config.use_decoder_only_language_model:
            language_model = AutoModelForCausalLM.from_config(config.text_config)
        else:
            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)

        if language_model._no_split_modules is not None:
            self._no_split_modules.extend(language_model._no_split_modules)

        if language_model._keep_in_fp32_modules is not None:
            self._keep_in_fp32_modules.extend(language_model._keep_in_fp32_modules)

        self.language_model = language_model

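        # Projector MLP: the first Linear maps the vision hidden size to the
        # text hidden size; each additional layer configured via
        # `config.mlp_config.num_hidden_layers` adds a GELU followed by a
        # Linear over the text hidden size.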
        modules = [
            nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size)
        ]
        for _ in range(1, config.mlp_config.num_hidden_layers):
            modules.append(nn.GELU())
            modules.append(
                nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
            )
        self.mlp = nn.Sequential(*modules)

        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        return self.language_model.get_encoder()

    def get_decoder(self):
        return self.language_model.get_decoder()

    def _tie_weights(self):
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if (
            len(hf_device_map) > 1
            and "language_model" not in hf_device_map
            and torch.cuda.device_count() > 1
        ):
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. This may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models."
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True

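    # forward() implements the LLaVA fusion: encode the image with the CLIP
    # vision tower, select the hidden states of one layer, project them with
    # the MLP, prepend the resulting image embeddings (plus an all-ones
    # attention mask) to the text embeddings, then run the language model on
    # the fused sequence.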
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LlavaForConditionalGenerationModelOutput]:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Always request a ModelOutput from the vision tower so that its
        # `hidden_states` can be selected by name below.
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            return_dict=True,
            output_hidden_states=True,
        )

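        # `vision_select_layer` indexes into the CLIP hidden states (a negative
        # value counts back from the last layer); `vision_select_feature`
        # controls whether the CLS token is kept ("cls_patch") or dropped
        # ("patch").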
        image_embeds = vision_outputs.hidden_states[self.config.vision_select_layer]
        if self.config.vision_select_feature == "patch":
            image_embeds = image_embeds[:, 1:]
        elif self.config.vision_select_feature == "cls_patch":
            image_embeds = image_embeds
        else:
            raise ValueError(
                f"Unexpected select feature: {self.config.vision_select_feature}"
            )

        image_embeds = self.mlp(image_embeds)
        image_attention_mask = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device
        )
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        inputs_embeds = torch.cat(
            [image_embeds, inputs_embeds.to(image_embeds.device)],
            dim=1,
        )

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        attention_mask = torch.cat(
            [image_attention_mask.to(attention_mask.device), attention_mask],
            dim=1,
        )

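        # Decoder-only path: the prepended image tokens lengthen the sequence,
        # so only the last `labels.size(1)` logit positions (the text part) are
        # kept before the shifted cross-entropy loss. Seq2seq path: labels are
        # passed through and the language model computes its own loss.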
        if self.config.use_decoder_only_language_model:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            logits = outputs.logits if return_dict else outputs[0]
            loss = None

            if labels is not None:
                labels = labels.to(logits.device)
                logits = logits[:, -labels.size(1) :, :]

                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous().to(logits.device)

                loss_fct = CrossEntropyLoss(reduction="mean")
                loss = loss_fct(
                    shift_logits.view(-1, self.config.text_config.vocab_size),
                    shift_labels.view(-1),
                )
        else:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                labels=labels,
            )
            loss = outputs.loss if return_dict else outputs[0]
            logits = outputs.logits if return_dict else outputs[1]

        if not return_dict:
            output = (logits, vision_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return LlavaForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            language_model_outputs=outputs,
        )

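    # Generation-time helper: returns the projected image embeddings together
    # with an all-ones attention mask of matching length.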
    def get_image_embeds(self, pixel_values: torch.FloatTensor):
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=True,
        )
        image_embeds = vision_outputs.hidden_states[self.config.vision_select_layer]
        if self.config.vision_select_feature == "patch":
            image_embeds = image_embeds[:, 1:]
        elif self.config.vision_select_feature == "cls_patch":
            image_embeds = image_embeds
        else:
            raise ValueError(
                f"Unexpected select feature: {self.config.vision_select_feature}"
            )

        image_embeds = self.mlp(image_embeds)
        image_attention_mask = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device
        )
        return dict(
            image_embeds=image_embeds,
            image_attention_mask=image_attention_mask,
        )

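    # Builds the fused inputs for `language_model.generate`: image embeddings
    # are prepended to the text embeddings, and a single BOS token serves as
    # the prompt when no `input_ids` are given.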
    def prepare_for_lm_generation(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        batch_size = pixel_values.shape[0]
        vision_outputs = self.get_image_embeds(pixel_values)
        image_embeds = vision_outputs["image_embeds"]
        image_attention_mask = vision_outputs["image_attention_mask"]

        if input_ids is None:
            input_ids = (
                torch.LongTensor([[self.config.text_config.bos_token_id]])
                .repeat(batch_size, 1)
                .to(image_embeds.device)
            )
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        attention_mask = torch.cat(
            [
                image_attention_mask,
                attention_mask.to(image_attention_mask.device),
            ],
            dim=1,
        )

        inputs_embeds = self.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat(
            [image_embeds, inputs_embeds.to(image_embeds.device)],
            dim=1,
        )
        return dict(inputs_embeds=inputs_embeds, attention_mask=attention_mask)

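    # Generation delegates to the wrapped language model: the image enters only
    # through `inputs_embeds`, and extra `generate_kwargs` (e.g.
    # `max_new_tokens`, `do_sample`) are forwarded unchanged.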
    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        if hasattr(self, "hf_device_map"):
            self._preprocess_accelerate()
        encodings = self.prepare_for_lm_generation(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        outputs = self.language_model.generate(
            **encodings,
            **generate_kwargs,
        )
        return outputs
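

# Example usage (a minimal sketch, assuming a `LlavaConfig` whose
# `vision_config`/`text_config` describe an existing CLIP vision tower and
# language model, and `pixel_values`/`input_ids` produced by an image
# processor and tokenizer defined elsewhere):
#
#     config = LlavaConfig(...)
#     model = LlavaForConditionalGeneration(config)
#     generated_ids = model.generate(
#         pixel_values=pixel_values,
#         input_ids=input_ids,
#         max_new_tokens=64,
#     )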