import warnings
from dataclasses import dataclass
from typing import List

import torch
from einops import rearrange
from PIL import Image
from torch import nn
from transformers.models.bert import BertConfig, BertModel
from transformers.models.bloom import BloomConfig, BloomForCausalLM, BloomTokenizerFast
from transformers.models.convnext import ConvNextImageProcessor
from transformers.models.convnextv2 import ConvNextV2Config
from transformers.models.convnextv2.modeling_convnextv2 import ConvNextV2Model


# Copied from
# https://github.com/dinhanhx/velvet/blob/b70730654d26d399920964ed7e606a8f5586c9d1/velvet/collator.py#L13-L32
@dataclass
class ImageFeatureCollator:
    image_processor: ConvNextImageProcessor
    image_model: ConvNextV2Model

    def __call__(self, batch_image: List[Image.Image]):
        return self.tensorize_batch_image(batch_image=batch_image)

    def tensorize_batch_image(self, batch_image: List[Image.Image]):
        image_inputs = self.image_processor(batch_image, return_tensors="pt")
        with torch.no_grad():
            image_outputs = self.image_model(**image_inputs)
        # Flatten the ConvNeXt V2 spatial grid (b, c, h, w) into a token sequence (b, h*w, c)
        image_features = image_outputs["last_hidden_state"]
        image_features = rearrange(image_features, "b c h w -> b h w c")
        image_features = rearrange(image_features, "b h w c -> b (h w) c")
        image_attentions = torch.ones(image_features.size()[:-1], dtype=torch.long)
        return image_features, image_attentions
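

# Illustrative usage sketch (an assumption, not in the original file): for the
# "facebook/convnextv2-large-22k-224" checkpoint used below, a 224x224 image yields
# a 7x7 feature map with 1536 channels, so the collator should return features of
# shape (batch, 49, 1536) and attentions of shape (batch, 49).
#
#   collator = ImageFeatureCollator(image_processor, image_model)
#   features, attentions = collator([Image.open("example.jpg").convert("RGB")])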


# Copied from
# https://github.com/dinhanhx/velvet/blob/b70730654d26d399920964ed7e606a8f5586c9d1/velvet/model/cutie.py#L6C1-L78C28
class IdentityForBertEmbeddings(nn.Module):
    """Skips BertEmbeddings entirely, because the text embeddings are provided by another model."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, **bert_embeddings_args):
        inputs_embeds = bert_embeddings_args.get("inputs_embeds", None)
        return inputs_embeds


class Cutie(nn.Module):
    """Cutie - Qt - Query Transformer - Q-Former

    Cutie is motivated by the underlying theoretical foundations of the Q-Former
    presented in BLIP-2 (https://arxiv.org/abs/2301.12597), but it differs from
    the specific approach described in that paper. Both Cutie and Q-Former use
    query tokens; however, Cutie uses an unmodified BERT, whereas Q-Former
    modifies BERT to behave differently on some tasks.
    """

    def __init__(
        self,
        bert_config: BertConfig,
        max_query_length: int = 32,
        language_model_ignore_label: int = -100,
    ) -> None:
        assert bert_config.is_decoder, "BERT must be a decoder"
        assert bert_config.add_cross_attention, "BERT must have cross attention layer"
        super().__init__()
        self.bert_model = BertModel(bert_config, add_pooling_layer=False)
        self.bert_model.embeddings = IdentityForBertEmbeddings()
        # Learnable query tokens; their attention mask is all ones and their labels
        # carry the ignore index so they never contribute to the language-model loss.
        self.query_tokens = nn.Parameter(
            torch.zeros(1, max_query_length, bert_config.hidden_size)
        )
        self.query_tokens.data.normal_(mean=0.0, std=bert_config.initializer_range)
        self.query_attentions = torch.ones(
            self.query_tokens.size()[:-1], dtype=torch.long
        )
        self.query_labels = torch.full(
            self.query_tokens.size()[:-1], language_model_ignore_label, dtype=torch.long
        )

    def forward(
        self,
        image_features: torch.Tensor,
        image_attentions: torch.Tensor,
        instruction_embeds: torch.Tensor,
        instruction_attention_mask: torch.Tensor,
    ):
        batch_size = image_features.size(0)
        query_tokens = self.query_tokens.expand(batch_size, -1, -1).to(
            self.query_tokens.device
        )
        query_attentions = self.query_attentions.expand(batch_size, -1).to(
            self.query_tokens.device
        )
        # Query tokens and instruction embeddings interact through BERT self-attention,
        # while the image features enter through cross-attention.
        cat_embeds = torch.cat([query_tokens, instruction_embeds], dim=1)
        cat_attentions = torch.cat(
            [query_attentions, instruction_attention_mask], dim=1
        )
        bert_outputs = self.bert_model(
            inputs_embeds=cat_embeds,
            attention_mask=cat_attentions,
            encoder_hidden_states=image_features,
            encoder_attention_mask=image_attentions,
        )
        # Keep only the hidden states at the query positions
        cutie_output = bert_outputs.last_hidden_state[:, : query_tokens.size(1), :]
        return cutie_output
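

# Illustrative shape check (an assumption, not in the original repo): with the default
# 32 query tokens and a hidden size of 1024 shared by BERT and the image features,
#
#   bert_config = BertConfig(hidden_size=1024, num_hidden_layers=6,
#                            num_attention_heads=16, is_decoder=True,
#                            add_cross_attention=True)
#   cutie = Cutie(bert_config)
#   out = cutie(
#       image_features=torch.zeros(2, 49, 1024),
#       image_attentions=torch.ones(2, 49, dtype=torch.long),
#       instruction_embeds=torch.zeros(2, 10, 1024),
#       instruction_attention_mask=torch.ones(2, 10, dtype=torch.long),
#   )
#   # out.shape == torch.Size([2, 32, 1024])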


# Copied from
# https://github.com/dinhanhx/velvet/blob/b70730654d26d399920964ed7e606a8f5586c9d1/velvet/model/visual_bloom.py#L12C1-L162C31
class VisualBloom(nn.Module):
    """A BLOOM-based model that can take image inputs"""

    def __init__(
        self,
        convnextv2_config: ConvNextV2Config,
        bert_config: BertConfig,
        bloom_config: BloomConfig,
        bloom_name: str,
        use_frozen_bloom: bool = True,
    ) -> None:
        super().__init__()
        # Linear projections are only needed when the image, Cutie, and BLOOM
        # embedding dimensions differ.
        if (
            convnextv2_config.hidden_sizes[-1]
            == bert_config.hidden_size
            == bloom_config.hidden_size
        ):
            self.use_projection = False
            warnings.warn(
                "All embedding dimensions are equal. No linear projection layers are created."
            )
        else:
            self.use_projection = True
            self.text_to_cutie = nn.Linear(
                bloom_config.hidden_size, bert_config.hidden_size
            )
            self.image_to_cutie = nn.Linear(
                convnextv2_config.hidden_sizes[-1], bert_config.hidden_size
            )
            self.cutie_to_text = nn.Linear(
                bert_config.hidden_size, bloom_config.hidden_size
            )

        self.cutie_model = Cutie(bert_config)

        if use_frozen_bloom:
            # Load and freeze the pretrained BLOOM model
            self.bloom_model = BloomForCausalLM.from_pretrained(bloom_name)
            for param in self.bloom_model.parameters():
                param.requires_grad = False
        else:
            # Initialize BLOOM from config; weights are expected to come from a state dict
            self.bloom_model = BloomForCausalLM(bloom_config)

    def forward(
        self,
        # Image model outputs - Q-Former inputs
        image_features: torch.Tensor,
        image_attentions: torch.Tensor,
        # Q-Former inputs
        instruction_input_ids: torch.Tensor,
        instruction_attention_mask: torch.Tensor,
        # Frozen language model inputs
        language_model_input_ids: torch.Tensor,
        language_model_attention_mask: torch.Tensor,
        language_model_labels: torch.Tensor,
    ):
        # Embed the instruction with BLOOM's own embedding layers
        instruction_embeds = self.bloom_model.transformer.word_embeddings(
            instruction_input_ids
        )
        instruction_embeds = self.bloom_model.transformer.word_embeddings_layernorm(
            instruction_embeds
        )

        if self.use_projection:
            image_features = self.image_to_cutie(image_features)
            instruction_embeds = self.text_to_cutie(instruction_embeds)

        cutie_output = self.cutie_model(
            image_features=image_features,
            image_attentions=image_attentions,
            instruction_embeds=instruction_embeds,
            instruction_attention_mask=instruction_attention_mask,
        )

        if self.use_projection:
            cutie_output = self.cutie_to_text(cutie_output)

        cutie_attentions = self.cutie_model.query_attentions.expand(
            cutie_output.size(0), -1
        ).to(cutie_output.device)
        # Query positions carry the ignore label, so they are excluded from the loss
        cutie_labels = self.cutie_model.query_labels.expand(
            cutie_output.size(0), -1
        ).to(cutie_output.device)

        language_model_embeds = self.bloom_model.transformer.word_embeddings(
            language_model_input_ids
        )
        language_model_embeds = self.bloom_model.transformer.word_embeddings_layernorm(
            language_model_embeds
        )

        # Prepend the Cutie outputs to the language-model sequence
        cat_embeds = torch.cat([cutie_output, language_model_embeds], dim=1)
        cat_attentions = torch.cat(
            [cutie_attentions, language_model_attention_mask], dim=1
        )
        cat_labels = torch.cat([cutie_labels, language_model_labels], dim=1)

        bloom_outputs = self.bloom_model(
            inputs_embeds=cat_embeds, attention_mask=cat_attentions, labels=cat_labels
        )
        return bloom_outputs

    def generate(
        self,
        # Image model outputs - Q-Former inputs
        image_features: torch.Tensor,
        image_attentions: torch.Tensor,
        # Q-Former inputs
        instruction_input_ids: torch.Tensor,
        instruction_attention_mask: torch.Tensor,
    ):
        instruction_embeds = self.bloom_model.transformer.word_embeddings(
            instruction_input_ids
        )
        instruction_embeds = self.bloom_model.transformer.word_embeddings_layernorm(
            instruction_embeds
        )

        # Fall back to the unprojected instruction embeddings when no projection layers exist
        cutie_instruction_embeds = instruction_embeds
        if self.use_projection:
            image_features = self.image_to_cutie(image_features)
            cutie_instruction_embeds = self.text_to_cutie(instruction_embeds)

        cutie_output = self.cutie_model(
            image_features=image_features,
            image_attentions=image_attentions,
            instruction_embeds=cutie_instruction_embeds,
            instruction_attention_mask=instruction_attention_mask,
        )

        if self.use_projection:
            cutie_output = self.cutie_to_text(cutie_output)

        cutie_attentions = self.cutie_model.query_attentions.expand(
            cutie_output.size(0), -1
        ).to(cutie_output.device)

        # Prepend the Cutie outputs to the instruction embeddings and decode with
        # contrastive search (penalty_alpha + top_k)
        cat_embeds = torch.cat([cutie_output, instruction_embeds], dim=1)
        cat_attentions = torch.cat(
            [cutie_attentions, instruction_attention_mask], dim=1
        )
        language_output = self.bloom_model.generate(
            inputs_embeds=cat_embeds,
            attention_mask=cat_attentions,
            max_length=96,
            penalty_alpha=0.6,
            top_k=4,
        )
        return language_output
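

# Note (summarizing the class above, not additional behavior): image features and
# instruction embeddings are fused by Cutie, projected back to BLOOM's hidden size
# when needed, and prepended to BLOOM's input sequence. During training the query
# positions are masked out of the loss; during generation BLOOM continues the
# sequence with contrastive search.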


def setup_models(visual_bloom_state_dict_path: str):
    # Image backbone and processor
    image_model_name = "facebook/convnextv2-large-22k-224"
    image_config = ConvNextV2Config.from_pretrained(image_model_name)
    image_processor = ConvNextImageProcessor.from_pretrained(image_model_name)
    image_model = ConvNextV2Model.from_pretrained(image_model_name)
    image_feature_collator = ImageFeatureCollator(image_processor, image_model)

    # Language model config and tokenizer
    bloom_model_name = "bigscience/bloomz-1b7"
    bloom_config = BloomConfig.from_pretrained(bloom_model_name)
    tokenizer = BloomTokenizerFast.from_pretrained(bloom_model_name)
    tokenizer.padding_side = "right"

    # BERT config for Cutie (decoder with cross-attention)
    bert_config = BertConfig(
        hidden_size=1024,
        num_hidden_layers=6,
        num_attention_heads=16,
        is_decoder=True,
        add_cross_attention=True,
    )

    # BLOOM is initialized from config here; its weights come from the saved state dict
    visual_bloom = VisualBloom(
        image_config,
        bert_config,
        bloom_config,
        bloom_model_name,
        use_frozen_bloom=False,
    )
    visual_bloom.load_state_dict(torch.load(visual_bloom_state_dict_path))
    visual_bloom = visual_bloom.eval()

    return {
        "visual_bloom": visual_bloom,
        "tokenizer": tokenizer,
        "image_feature_collator": image_feature_collator,
    }
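

# Hedged usage sketch (not part of the original file): the checkpoint path, image
# path, and instruction below are placeholders. With inputs_embeds-based generation,
# BloomForCausalLM.generate typically returns only the newly generated token ids,
# so the decoded string is the model's answer without the prompt.
if __name__ == "__main__":
    models = setup_models("visual_bloom.pt")  # hypothetical state-dict path
    visual_bloom = models["visual_bloom"]
    tokenizer = models["tokenizer"]
    image_feature_collator = models["image_feature_collator"]

    image_features, image_attentions = image_feature_collator(
        [Image.open("example.jpg").convert("RGB")]  # hypothetical image
    )
    instruction = tokenizer("Describe the image.", return_tensors="pt")
    with torch.no_grad():
        output_ids = visual_bloom.generate(
            image_features=image_features,
            image_attentions=image_attentions,
            instruction_input_ids=instruction["input_ids"],
            instruction_attention_mask=instruction["attention_mask"],
        )
    print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])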