# Pisces: llava/model/llava_arch.py
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
from .multimodal_encoder.builder import build_vision_tower, build_gen_vision_tower
from .multimodal_projector.builder import build_vision_projector, build_down_projector, build_gen_vision_projector
# DEFAULT_IMAGE_PATCH_TOKEN is imported because initialize_vision_tokenizer references it;
# it is assumed to be defined in llava.constants (as in upstream LLaVA).
from llava.constants import (
    IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN,
    DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_IDX,
    DEFAULT_IM_START_TOKEN_IDX, DEFAULT_IM_END_TOKEN_IDX,
)
class LlavaMetaModel:
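    """Mixin for the language-model backbone that owns the multimodal components: the
    understanding tower (`vision_tower`, `mm_projector`, `down_projector`) and the
    generation tower (`gen_vision_tower`, `gen_projector`). Towers are built with
    `delay_load=True` here and fully set up later via `initialize_vision_modules`.
    """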
def __init__(self, config):
super(LlavaMetaModel, self).__init__(config)
if hasattr(config, "mm_vision_tower"):
self.vision_tower = build_vision_tower(config, delay_load=True)
self.mm_projector = build_vision_projector(config)
self.down_projector = build_down_projector(config)
if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
self.image_newline = nn.Parameter(
torch.empty(config.hidden_size, dtype=self.dtype)
)
if hasattr(config, "gen_vision_tower"):
self.gen_vision_tower = build_gen_vision_tower(config, delay_load=True)
self.gen_projector = build_gen_vision_projector(config)
if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
self.image_newline = nn.Parameter(
torch.empty(config.hidden_size, dtype=self.dtype)
)
def get_vision_tower(self):
vision_tower = getattr(self, 'vision_tower', None)
if type(vision_tower) is list:
vision_tower = vision_tower[0]
return vision_tower
def get_gen_vision_tower(self):
gen_vision_tower = getattr(self, 'gen_vision_tower', None)
if type(gen_vision_tower) is list:
gen_vision_tower = gen_vision_tower[0]
return gen_vision_tower
def initialize_vision_modules(self, model_args, fsdp=None):
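        """Build (or reuse) the understanding and generation vision towers and their projectors
        from `model_args`, copy the relevant options onto `self.config`, and optionally load
        pretrained projector weights from `pretrain_mm_mlp_adapter` / `pretrain_gen_mlp_adapter`.
        When `fsdp` is a non-empty list, the towers are wrapped in single-element lists.
        """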
vision_tower = model_args.vision_tower
gen_vision_tower = model_args.gen_vision_tower
mm_vision_select_layer = model_args.mm_vision_select_layer
mm_vision_select_feature = model_args.mm_vision_select_feature
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
pretrain_gen_mlp_adapter = model_args.pretrain_gen_mlp_adapter
mm_patch_merge_type = model_args.mm_patch_merge_type
self.config.mm_vision_tower = vision_tower
self.config.gen_vision_tower = gen_vision_tower
self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
if self.get_vision_tower() is None:
vision_tower = build_vision_tower(model_args)
if fsdp is not None and len(fsdp) > 0:
self.vision_tower = [vision_tower]
else:
self.vision_tower = vision_tower
else:
if fsdp is not None and len(fsdp) > 0:
vision_tower = self.vision_tower[0]
else:
vision_tower = self.vision_tower
vision_tower.load_model()
if self.get_gen_vision_tower() is None:
gen_vision_tower = build_gen_vision_tower(model_args)
if fsdp is not None and len(fsdp) > 0:
self.gen_vision_tower = [gen_vision_tower]
else:
self.gen_vision_tower = gen_vision_tower
else:
if fsdp is not None and len(fsdp) > 0:
gen_vision_tower = self.gen_vision_tower[0]
else:
gen_vision_tower = self.gen_vision_tower
gen_vision_tower.load_model()
self.config.use_mm_proj = True
self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
self.config.gen_projector_type = getattr(model_args, 'gen_projector_type', 'linear')
self.config.mm_hidden_size = vision_tower.hidden_size
self.config.gen_hidden_size = gen_vision_tower.hidden_size
self.config.mm_vision_select_layer = mm_vision_select_layer
self.config.mm_vision_select_feature = mm_vision_select_feature
self.config.mm_patch_merge_type = mm_patch_merge_type
self.config.n_query = model_args.n_query
self.config.gen_pooling = model_args.gen_pooling
if getattr(self, 'mm_projector', None) is None:
print("random initiation the mm_project !!!")
self.mm_projector = build_vision_projector(self.config)
if 'unpad' in mm_patch_merge_type:
embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
self.image_newline = nn.Parameter(
torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
)
else:
# In case it is frozen by LoRA
for p in self.mm_projector.parameters():
p.requires_grad = True
if getattr(self, 'gen_projector', None) is None:
print("random initiation the gen_projector !!!")
self.gen_projector = build_gen_vision_projector(self.config)
if 'unpad' in mm_patch_merge_type:
embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
self.image_newline = nn.Parameter(
torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
)
else:
# In case it is frozen by LoRA
for p in self.gen_projector.parameters():
p.requires_grad = True
if getattr(self, 'down_projector', None) is None:
print("random initiation the down_projector !!!")
self.down_projector = build_down_projector(self.config)
else:
# In case it is frozen by LoRA
for p in self.down_projector.parameters():
p.requires_grad = True
        def get_w(weights, keyword):
            # Strip the '<keyword>.' prefix so the keys match the projector's state_dict.
            return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))

        if pretrain_gen_mlp_adapter is not None:
            gen_projector_weights = torch.load(pretrain_gen_mlp_adapter, map_location='cpu')
            # Note: the generation adapter checkpoint also stores its projector weights under
            # the 'mm_projector' prefix, hence the same keyword here.
            self.gen_projector.load_state_dict(get_w(gen_projector_weights, 'mm_projector'))
def unpad_image(tensor, original_size):
"""
Unpads a PyTorch tensor of a padded and resized image.
Args:
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
original_size (tuple): The original size of PIL image (width, height).
Returns:
torch.Tensor: The unpadded image tensor.
"""
original_width, original_height = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding:current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding:current_width - padding]
return unpadded_tensor
class LlavaMetaForCausalLM(ABC):
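    """Mixin for the causal-LM wrapper: exposes the vision towers and projectors of the
    underlying model and builds multimodal `inputs_embeds` via
    `prepare_inputs_labels_for_multimodal` (joint understanding + generation) and
    `prepare_inputs_labels_for_understanding` (understanding only).
    """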
@abstractmethod
def get_model(self):
pass
def get_vision_tower(self):
return self.get_model().get_vision_tower()
def get_gen_vision_tower(self):
return self.get_model().get_gen_vision_tower()
def encode_images(self, images):
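        """Encode `images` with the understanding vision tower.

        When `gen_pooling` contains 'pool2d', the token grid is average-pooled with the stride
        parsed from the last '_'-separated field of `gen_pooling`. The grid side is derived from
        `n_query` (729 when `gen_pooling` contains 'early'), which is assumed to equal the
        tower's token count. Features are returned without `mm_projector` applied.
        """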
device = self.get_vision_tower().device
images = images.to(device)
image_features = self.get_model().get_vision_tower()(images)
num_img, _, c = image_features.shape
gen_pooling = self.get_gen_pooling()
        n_query = self.get_n_query() if 'early' not in gen_pooling else 729
        if 'pool2d' in gen_pooling:
            stride = int(gen_pooling.split('_')[-1])
            # Side length of the square token grid; n_query is assumed to equal the tower's token count.
            sqrt_n = int(n_query ** 0.5)
            image_features = image_features.permute(0, 2, 1).view(num_img, -1, sqrt_n, sqrt_n)
            image_features = F.avg_pool2d(image_features, kernel_size=(stride, stride), stride=stride)
            image_features = image_features.reshape(num_img, c, -1).permute(0, 2, 1)
# image_features = image_features.contiguous().view(-1, c)
# image_features = self.get_model().mm_projector(image_features)
return image_features
def get_mm_projector(self):
return self.get_model().mm_projector
def get_gen_projector(self):
return self.get_model().gen_projector
def get_n_query(self):
return self.get_model().config.n_query
def get_gen_pooling(self):
return self.get_model().config.gen_pooling
def pool_img(self, image_features):
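        """Average-pool image features over their 2D token grid.

        Input is (num_img, n, c) with n assumed to be a perfect square; the stride is the last
        '_'-separated field of `gen_pooling`. For example, (B, 729, C) with stride 3 becomes
        (B, 81, C).
        """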
num_img, n, c = image_features.shape
gen_pooling = self.get_gen_pooling()
# n_query = self.get_n_query()
stride = int(gen_pooling.split('_')[-1])
sqrt_n = int(n**0.5)
image_features = image_features.permute(0, 2, 1).view(num_img, c, sqrt_n, sqrt_n)
image_features = F.avg_pool2d(image_features, kernel_size=(stride, stride), stride=stride)
image_features = image_features.view(num_img, c, -1).permute(0,2,1).contiguous()
return image_features
def prepare_inputs_labels_for_multimodal(
self, input_ids, position_ids, attention_mask, past_key_values, labels,
gen_images, und_images, image_sizes=None
):
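        """Build multimodal `inputs_embeds` for joint understanding and generation.

        Generation images (`gen_images`) are encoded with `gen_vision_tower`, projected with
        `gen_projector`, and written into the image-token positions that lie in the supervised
        (output) part of the sequence; their un-projected, detached features are returned as
        `target_image_embeds` for the image-generation loss. Understanding images (`und_images`)
        are encoded with `vision_tower`, projected with `mm_projector`, and written into the
        image-token positions in the prompt. When one image batch is absent, a dummy forward
        pass through the unused tower is added with a 1e-9 scale, presumably to keep its
        parameters in the computation graph.

        Returns a tuple matching the original signature: (None, position_ids, attention_mask,
        past_key_values, text_embeds, labels, img_loss_indicator, img_indicator,
        target_image_embeds).
        """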
vision_tower = self.get_vision_tower()
mm_projector = self.get_mm_projector()
gen_vision_tower = self.get_gen_vision_tower()
gen_projector = self.get_gen_projector()
if (gen_images is None and und_images is None) or input_ids.shape[1] == 1:
return input_ids, position_ids, attention_mask, past_key_values, None, labels, None, None, None
        if gen_images is not None:
            prompt_image_embeds = gen_vision_tower(gen_images)  # TODO: check dimension
            if 'early' in self.get_gen_pooling():
                prompt_image_embeds = self.pool_img(prompt_image_embeds)
            num_img, _, c = prompt_image_embeds.shape  # [batch, 729, 1152]
            prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c)
            # Un-projected features are kept (detached) as targets for the image-generation loss.
            target_image_embeds = torch.clone(prompt_image_embeds).detach()
            prompt_image_embeds = gen_projector(prompt_image_embeds)
        else:
            target_image_embeds = None
            # Quick fix: no generation images in this batch. Run the generation tower on a dummy
            # input and scale the projected result by 1e-9, presumably so that its parameters stay
            # in the computation graph without affecting the outputs.
            # (Original notes: und_images torch.Size([2, 3, 336, 336]), gen_images torch.Size([2, 3, 384, 384]).)
            num_img = und_images.shape[0]
            dummy = torch.zeros(num_img, 3, 448, 448, dtype=und_images.dtype, device=und_images.device)  # TODO
            temp = gen_vision_tower(dummy)[:, :729, :]
            num_img, _, c = temp.shape
            temp = temp.contiguous().view(-1, c)
            temp = gen_projector(temp) * 1e-9
        if und_images is not None:
            und_image_embeds = vision_tower(und_images)
            num_img, _, c = und_image_embeds.shape
            und_image_embeds = und_image_embeds.contiguous().view(-1, c)
            und_image_embeds = mm_projector(und_image_embeds)
            if gen_images is None:
                # Fold in the (1e-9-scaled) dummy generation-tower output computed above.
                und_image_embeds += temp
        else:
            # No understanding images: same dummy-forward trick for the understanding tower.
            num_img = gen_images.shape[0]
            dummy = torch.zeros(num_img, 3, 384, 384, dtype=gen_images.dtype, device=gen_images.device)  # clip (3, 336, 336)
            temp = vision_tower(dummy)
            if 'early' in self.get_gen_pooling():
                temp = temp[:, :64, :]
            num_img, _, c = temp.shape
            temp = temp.contiguous().view(-1, c)
            temp = mm_projector(temp) * 1e-9
            prompt_image_embeds += temp
        image_idx = (input_ids == IMAGE_TOKEN_IDX)
        img_indicator = torch.clone(image_idx)
        output_indicator = labels != -100  # supervised (output) positions
        input_indicator = labels == -100   # prompt (input) positions
        # Image tokens in the output carry the image-generation loss; both masks are rolled left
        # by one position, presumably to align them with next-token predictions.
        img_loss_indicator = torch.logical_and(output_indicator, img_indicator)
        img_loss_indicator = torch.cat(
            [img_loss_indicator[:, 1:], img_loss_indicator[:, :1]], dim=1)
        img_indicator = torch.cat(
            [img_indicator[:, 1:], img_indicator[:, :1]], dim=1)
        if target_image_embeds is not None:
            # Keep only the last `img_loss_indicator.sum()` target rows (one per image-loss position).
            target_image_embeds = target_image_embeds[-img_loss_indicator.sum():, :]
        text_embeds = self.get_model().embed_tokens(input_ids)
        N_QUERY = self.get_n_query()
        # Image tokens on the output side receive the projected generation-tower embeddings.
        gen_img_idx = torch.logical_and(output_indicator, image_idx)
        if target_image_embeds is not None:
            text_embeds[gen_img_idx] = prompt_image_embeds.to(text_embeds.device)[:gen_img_idx.sum(), :]
            target_image_embeds = target_image_embeds.to(text_embeds.device)[:gen_img_idx.sum(), :]
        # Image tokens on the input (prompt) side receive the projected understanding-tower embeddings.
        und_img_idx = torch.logical_and(input_indicator, image_idx)
        if und_images is not None:
            text_embeds[und_img_idx] = und_image_embeds.to(text_embeds.device)[:und_img_idx.sum(), :]
        # Image-token positions do not contribute to the text (cross-entropy) loss.
        labels[image_idx] = -100
        return None, position_ids, attention_mask, past_key_values, text_embeds, labels, img_loss_indicator, img_indicator, target_image_embeds
def prepare_inputs_labels_for_understanding(
self, input_ids, position_ids, attention_mask, past_key_values, labels,
batch_images, image_sizes=None
):
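        """Understanding-only variant: encode `batch_images` with `vision_tower`, project with
        `mm_projector`, and write the result into the image-token positions of the embedded
        `input_ids`. Returns (None, position_ids, attention_mask, past_key_values, text_embeds,
        img_indicator, labels).
        """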
vision_tower = self.get_vision_tower()
mm_projector = self.get_mm_projector()
        prompt_image_embeds = vision_tower(batch_images)  # TODO: check dimension
        num_img, _, c = prompt_image_embeds.shape  # [batch, 576, 1024]
        all_image_embeds = torch.clone(prompt_image_embeds).detach()
        prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c)
        prompt_image_embeds = mm_projector(prompt_image_embeds)
# IMAGE = 128259
image_idx = (input_ids == IMAGE_TOKEN_IDX)
# print(f"image_idx {image_idx[0]}")
img_indicator = torch.clone(image_idx)
img_indicator = torch.cat(
[img_indicator[:, 1:], img_indicator[:, :1]], dim=1)
# print(f"before embed input ids")
# print(f"image_idx.sum() {image_idx.sum()}")
# print(f"input_ids {input_ids[0,:]}")
# print(f"self.model.decoder.lm.model.emb {self.model.decoder.lm.get_input_embeddings().weight.data.shape}")
text_embeds = self.get_model().embed_tokens(input_ids)
# print(f"text_embeds {text_embeds}")
# print(f"break 1")
N_QUERY = self.get_n_query()
# if not image_idx.sum()/N_QUERY == image_idx.sum()//N_QUERY:
# print('warning half image: ', image_idx.sum()/N_QUERY, image_idx.sum()//N_QUERY)
# print(f"break 1.5")
# print(f"image_idx {image_idx}")
# print(f"text_embeds {text_embeds}, prompt_image_embeds {prompt_image_embeds}")
text_embeds[image_idx] = prompt_image_embeds.to(text_embeds.device)[:image_idx.sum(),:]
# print({'all_image_embeds':all_image_embeds.shape, 'num_output_img':num_output_img, 'num_img': num_img})
return None, position_ids, attention_mask, past_key_values, text_embeds, img_indicator, labels
def initialize_vision_tokenizer(self, model_args, tokenizer):
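        """Optionally add the image patch / start / end special tokens to the tokenizer and
        resize the embedding matrices, initializing new rows to the mean of the existing
        embeddings. Can also restore `embed_tokens` rows from a pretrained mm adapter
        checkpoint, and toggles `requires_grad` on the input/output embeddings when only
        the adapter is being tuned.
        """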
if model_args.mm_use_im_patch_token:
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if model_args.mm_use_im_start_end:
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = self.get_input_embeddings().weight.data
output_embeddings = self.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
if model_args.tune_mm_mlp_adapter:
for p in self.get_input_embeddings().parameters():
p.requires_grad = True
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
if model_args.pretrain_mm_mlp_adapter:
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
assert num_new_tokens == 2
if input_embeddings.shape == embed_tokens_weight.shape:
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
elif embed_tokens_weight.shape[0] == num_new_tokens:
input_embeddings[-num_new_tokens:] = embed_tokens_weight
else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.")
elif model_args.mm_use_im_patch_token:
if model_args.tune_mm_mlp_adapter:
for p in self.get_input_embeddings().parameters():
p.requires_grad = False
for p in self.get_output_embeddings().parameters():
p.requires_grad = False