# japanese-llava-small-stair / modeling_llava.py
# ohashi56225's picture
# Upload LlavaForConditionalGeneration
# 20150e3
# Copyright 2023 Stability AI team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union, Any
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
PreTrainedModel,
CLIPVisionModel,
)
from transformers.utils import logging, ModelOutput
from .configuration_llava import LlavaConfig
# Module-level logger, obtained through the `logging` helper imported from `transformers.utils`.
logger = logging.get_logger(__name__)
@dataclass
class LlavaForConditionalGenerationModelOutput(ModelOutput):
    """
    Output container built by `LlavaForConditionalGeneration.forward` when a dict-style return is requested.

    Fields:
        loss: Loss value, or `None` when no loss was computed.
        logits: Prediction scores produced by the language model.
        vision_outputs: Output object of the vision encoder. `to_tuple` calls `.to_tuple()` on it, so it must
            expose that method.
        language_model_outputs: Output object of the language model. Same `.to_tuple()` requirement as above.
    """

    loss: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    vision_outputs: Optional[torch.FloatTensor] = None
    language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None

    def to_tuple(self) -> Tuple[Any]:
        """Return the stored fields as a tuple, converting the nested sub-model outputs to tuples as well."""
        nested_fields = ["vision_outputs", "language_model_outputs"]
        flattened = []
        for key in self.keys():
            if key in nested_fields:
                # Nested outputs are themselves converted through their own `to_tuple`.
                flattened.append(getattr(self, key).to_tuple())
            else:
                flattened.append(self[key])
        return tuple(flattened)
class LlavaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = LlavaConfig
    base_model_prefix = "llava"

    # Adapted from transformers.models.blip_2.modeling_blip_2.Blip2PreTrainedModel._init_weights
    def _init_weights(self, module):
        """
        Initialize the weights of a single sub-module in place.

        `nn.Conv2d`, `nn.Embedding` and `nn.Linear` weights are drawn from a normal distribution with mean 0 and
        standard deviation `config.initializer_range`, and their bias (when present) is zeroed. `nn.LayerNorm`
        modules get a zero bias and a unit weight. Any other module type is left untouched.

        Args:
            module: The sub-module to initialize.
        """
        std = self.config.initializer_range
        if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=std)
            # `nn.Embedding` has no `bias` attribute and `nn.Linear`/`nn.Conv2d` may be built with `bias=False`.
            bias = getattr(module, "bias", None)
            if bias is not None:
                bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            # Both parameters are `None` when the layer was created with `elementwise_affine=False`.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
class LlavaForConditionalGeneration(LlavaPreTrainedModel):
    """
    Vision-language model made of a CLIP vision encoder, an MLP projector and a language model.

    Hidden states of the vision encoder are projected by `self.mlp` into the language model's embedding size and
    concatenated in front of the text token embeddings before being fed to the language model. The language model
    is either decoder-only or sequence-to-sequence, depending on `config.use_decoder_only_language_model`.
    """

    config_class = LlavaConfig
    main_input_name = "pixel_values"
    _no_split_modules = []

    def __init__(self, config: LlavaConfig):
        """
        Build the vision encoder, the language model and the MLP projector from `config`.

        Args:
            config: Model configuration. This constructor reads `vision_config`, `text_config`, `mlp_config` and
                `use_decoder_only_language_model` from it.
        """
        super().__init__(config)

        self.vision_model = CLIPVisionModel(config.vision_config)

        if config.use_decoder_only_language_model:
            language_model = AutoModelForCausalLM.from_config(config.text_config)
        else:
            language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)

        # Build fresh per-instance lists instead of calling `.extend` on the inherited attributes: the class-level
        # `_no_split_modules` list is shared by every instance (it would grow on each instantiation), and the
        # inherited `_keep_in_fp32_modules` may be `None`, which has no `.extend`.
        # NOTE(review): some `transformers` versions may also consult the class-level attributes during loading;
        # confirm against the pinned version.
        lm_no_split = getattr(language_model, "_no_split_modules", None)
        if lm_no_split is not None:
            self._no_split_modules = self._merge_module_names(
                getattr(self, "_no_split_modules", None), lm_no_split
            )
        lm_keep_fp32 = getattr(language_model, "_keep_in_fp32_modules", None)
        if lm_keep_fp32 is not None:
            self._keep_in_fp32_modules = self._merge_module_names(
                getattr(self, "_keep_in_fp32_modules", None), lm_keep_fp32
            )

        self.language_model = language_model

        # Projector: one Linear from the vision width to the text width, followed by
        # (num_hidden_layers - 1) blocks of GELU + Linear at the text width.
        text_hidden_size = config.text_config.hidden_size
        projector_layers = [nn.Linear(config.vision_config.hidden_size, text_hidden_size)]
        for _ in range(1, config.mlp_config.num_hidden_layers):
            projector_layers.append(nn.GELU())
            projector_layers.append(nn.Linear(text_hidden_size, text_hidden_size))
        self.mlp = nn.Sequential(*projector_layers)

        # Initialize weights and apply final processing
        self.post_init()

    @staticmethod
    def _merge_module_names(own, inherited):
        """Return a new list holding the names in `own` (may be `None`) then `inherited`, without duplicates."""
        return list(dict.fromkeys([*(own or ()), *inherited]))

    def get_input_embeddings(self):
        """Return the input embedding module of the wrapped language model."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the input embedding module of the wrapped language model with `value`."""
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        """Replace the output embedding module of the wrapped language model with `new_embeddings`."""
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        """Return the output embedding module of the wrapped language model."""
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        """Return the encoder of the wrapped language model."""
        return self.language_model.get_encoder()

    def get_decoder(self):
        """Return the decoder of the wrapped language model."""
        return self.language_model.get_decoder()

    def _tie_weights(self):
        """For a non decoder-only language model, point the encoder and decoder token embeddings at `shared`."""
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if (
            len(hf_device_map) > 1
            and "language_model" not in hf_device_map
            and torch.cuda.device_count() > 1
        ):
            # warn users about unexpected behavior when using multi-GPU + this model + `accelerate`.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models.",
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = (
                True  # For `generate` compatibility
            )

    def _embed_image_features(self, hidden_states):
        """
        Turn the vision encoder's hidden states into language-model-sized image embeddings.

        Picks the layer `config.vision_select_layer`, optionally drops the first token (`"patch"` mode), applies
        the MLP projector and builds an all-ones attention mask over the resulting image tokens.

        Args:
            hidden_states: Indexable collection of per-layer hidden states from the vision encoder.

        Returns:
            Tuple `(image_embeds, image_attention_mask)`. The mask has the shape of `image_embeds` without its
            last dimension and uses torch's default dtype, since no dtype is passed to `torch.ones`.

        Raises:
            ValueError: If `config.vision_select_feature` is neither `"patch"` nor `"cls_patch"`.
        """
        # (bsz, seq len, hidden_size)
        image_embeds = hidden_states[self.config.vision_select_layer]

        select_feature = self.config.vision_select_feature
        if select_feature == "patch":
            # Skip the first position of the sequence and keep the rest.
            image_embeds = image_embeds[:, 1:]
        elif select_feature != "cls_patch":
            raise ValueError(f"Unexpected select feature: {select_feature}")

        image_embeds = self.mlp(image_embeds)
        image_attention_mask = torch.ones(
            image_embeds.size()[:-1], device=image_embeds.device
        )
        return image_embeds, image_attention_mask

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LlavaForConditionalGenerationModelOutput]:
        """
        Encode the image, prepend its projected embeddings to the text embeddings and run the language model.

        Args:
            pixel_values: Image input forwarded to the vision encoder.
            input_ids: Text token ids. Required, even though the signature keeps a `None` default.
            attention_mask: Mask for `input_ids`; defaults to all ones. The image mask is prepended to it.
            decoder_input_ids: Forwarded to the language model only in the sequence-to-sequence case.
            decoder_attention_mask: Forwarded to the language model only in the sequence-to-sequence case.
            output_attentions: Forwarded to both the vision encoder and the language model.
            output_hidden_states: Forwarded to the language model. The vision encoder always returns its hidden
                states because they are needed to build the image embeddings.
            labels: Target token ids. In the decoder-only case the loss is computed here, on the last
                `labels.size(1)` positions of the logits, and the returned `logits` are truncated to those
                positions. In the sequence-to-sequence case `labels` are passed to the language model.
            return_dict: Whether to return a `LlavaForConditionalGenerationModelOutput` instead of a tuple.
                Defaults to `config.use_return_dict`.

        Returns:
            A `LlavaForConditionalGenerationModelOutput`, or the tuple
            `(loss?, logits, vision_outputs, language_model_outputs)` when `return_dict` is false.

        Raises:
            ValueError: If `input_ids` is `None` or the configured select feature is not supported.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is None:
            raise ValueError("`input_ids` must be provided to `forward`.")

        # step 1: forward the images through the vision encoder.
        # Always request a dict-style output here: the hidden states are read by attribute below, which would
        # fail on a plain tuple. The output is converted back to a tuple at return time when needed.
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            return_dict=True,
            output_hidden_states=True,
        )

        # step 2: select the image features and forward them through the mlp
        image_embeds, image_attention_mask = self._embed_image_features(
            vision_outputs.hidden_states
        )

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        # step 3: concatenate image embeddings in front of the text embeddings
        inputs_embeds = torch.cat(
            [image_embeds, inputs_embeds.to(image_embeds.device)],
            dim=1,
        )

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, device=input_ids.device)
        attention_mask = torch.cat(
            [image_attention_mask.to(attention_mask.device), attention_mask],
            dim=1,
        )

        if self.config.use_decoder_only_language_model:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            logits = outputs.logits if return_dict else outputs[0]
            loss = None
            # we compute the loss here since we need to take into account the sequence length of the image embeds
            if labels is not None:
                labels = labels.to(logits.device)
                # Keep only the trailing positions that line up with the labels (drops the image positions).
                logits = logits[:, -labels.size(1) :, :]
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous().to(logits.device)

                # Flatten the tokens
                loss_fct = CrossEntropyLoss(reduction="mean")
                loss = loss_fct(
                    shift_logits.view(-1, self.config.text_config.vocab_size),
                    shift_labels.view(-1),
                )
        else:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                labels=labels,
            )
            loss = outputs.loss if return_dict else outputs[0]
            logits = outputs.logits if return_dict else outputs[1]

        if not return_dict:
            output = (logits, vision_outputs.to_tuple(), outputs)
            return ((loss,) + output) if loss is not None else output

        return LlavaForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            language_model_outputs=outputs,
        )

    def get_image_embeds(self, pixel_values: torch.FloatTensor):
        """
        Encode images and project them into the language model's embedding space.

        Args:
            pixel_values: Image input forwarded to the vision encoder.

        Returns:
            Dict with keys `image_embeds` (projected image features) and `image_attention_mask` (all-ones mask
            over the image tokens).

        Raises:
            ValueError: If the configured select feature is not supported.
        """
        # Dict-style output is requested explicitly because the hidden states are read by attribute.
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=True,
            return_dict=True,
        )
        image_embeds, image_attention_mask = self._embed_image_features(
            vision_outputs.hidden_states
        )
        return dict(
            image_embeds=image_embeds,
            image_attention_mask=image_attention_mask,
        )

    def prepare_for_lm_generation(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Build the `inputs_embeds` / `attention_mask` pair passed to the language model's `generate`.

        Args:
            pixel_values: Image input; its first dimension is used as the batch size.
            input_ids: Prompt token ids. When `None`, a single `config.text_config.bos_token_id` per batch item
                is used as the prompt.
            attention_mask: Mask for `input_ids`; defaults to all ones.

        Returns:
            Dict with keys `inputs_embeds` and `attention_mask`, each with the image part placed before the
            text part along dimension 1.
        """
        batch_size = pixel_values.shape[0]
        image_features = self.get_image_embeds(pixel_values)
        image_embeds = image_features["image_embeds"]
        image_attention_mask = image_features["image_attention_mask"]

        if input_ids is None:
            input_ids = (
                torch.LongTensor([[self.config.text_config.bos_token_id]])
                .repeat(batch_size, 1)
                .to(image_embeds.device)
            )
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        attention_mask = torch.cat(
            [
                image_attention_mask,
                attention_mask.to(image_attention_mask.device),
            ],
            dim=1,
        )

        # concatenate image embeddings with prompt embeddings
        inputs_embeds = self.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat(
            [image_embeds, inputs_embeds.to(image_embeds.device)],
            dim=1,
        )
        return dict(inputs_embeds=inputs_embeds, attention_mask=attention_mask)

    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """
        Generate text conditioned on the image and an optional text prompt.

        Args:
            pixel_values: Image input forwarded to the vision encoder.
            input_ids: Optional prompt token ids (see `prepare_for_lm_generation` for the default).
            attention_mask: Optional mask for `input_ids`.
            **generate_kwargs: Extra keyword arguments forwarded unchanged to the language model's `generate`.

        Returns:
            Whatever the language model's `generate` returns for the prepared inputs.
        """
        if hasattr(self, "hf_device_map"):
            # preprocess for `accelerate`
            self._preprocess_accelerate()

        encodings = self.prepare_for_lm_generation(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        outputs = self.language_model.generate(
            **encodings,
            **generate_kwargs,
        )
        return outputs