|
from typing import List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from torch.nn import CrossEntropyLoss |
|
|
|
from transformers import AutoConfig, AutoModelForCausalLM, \ |
|
LlamaConfig, LlamaModel, LlamaForCausalLM |
|
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
|
from PIL import Image |
|
|
|
from abc import ABC, abstractmethod |
|
import os |
|
|
|
import math |
|
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig |
|
from functools import partial |
|
from transformers.configuration_utils import PretrainedConfig |
|
|
|
from timm.models.layers import LayerNorm, LayerNorm2d |
|
from timm.models.regnet import RegStage |
|
from torch.nn import functional as F |
|
import math |
|
from einops import rearrange |
|
|
|
|
|
|
|
# Heartbeat timing constants (not referenced in this file; units not shown here,
# presumably seconds — TODO confirm).
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

# Log directory (not referenced in this file).
LOGDIR = "."

# Label value for positions excluded from the loss; it equals the default
# ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
IGNORE_INDEX = -100
# Placeholder id in ``input_ids`` marking where image features are spliced in.
IMAGE_TOKEN_INDEX = -200

# Textual image-related special tokens.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
|
|
|
|
|
|
|
|
|
|
|
class CLIPVisionTower(nn.Module):
    """Frozen CLIP vision encoder returning the hidden states of one layer.

    Args:
        vision_tower: Model name or path passed to ``from_pretrained``.
        args: Config object providing ``mm_vision_select_layer`` and, optionally,
            ``mm_vision_select_feature`` (defaults to ``'patch'``).
        delay_load: If True, only the vision config is loaded now; call
            :meth:`load_model` later to load the processor and weights.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        else:
            # Keep just the config so ``config``/``hidden_size``/``num_patches`` work
            # before the weights are loaded.
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self):
        """Load the image processor and the frozen vision model.

        Calling it again after a successful load is a no-op. Without this guard a
        second call would silently replace the current weights with freshly
        loaded ones.
        """
        if self.is_loaded:
            return
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
        # The encoder is used as a fixed feature extractor.
        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick the hidden states of ``select_layer`` and optionally drop token 0.

        ``'patch'`` drops index 0 along dim 1 (the CLS position in CLIP outputs);
        ``'cls_patch'`` keeps every token.

        Raises:
            ValueError: if ``select_feature`` is neither of the above.
        """
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            image_features = image_features[:, 1:]
        elif self.select_feature != 'cls_patch':
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    @torch.no_grad()
    def forward(self, images):
        """Encode images without tracking gradients.

        Args:
            images: A batched tensor, or a list of per-image tensors. Each list
                entry gets a leading batch dimension of 1.

        Returns:
            Selected features cast back to the input dtype. The result is a list
            when ``images`` is a list, otherwise a tensor.
        """
        if isinstance(images, list):
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_features.append(self.feature_select(image_forward_out).to(image.dtype))
            return image_features

        image_forward_outs = self.vision_tower(
            images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
        )
        return self.feature_select(image_forward_outs).to(images.dtype)

    @property
    def dummy_feature(self):
        """A ``(1, hidden_size)`` zero tensor on the tower's device/dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Dtype of the wrapped vision model."""
        return self.vision_tower.dtype

    @property
    def device(self):
        """Device of the wrapped vision model."""
        return self.vision_tower.device

    @property
    def config(self):
        """Loaded model config, or the config fetched at delayed-load time."""
        if self.is_loaded:
            return self.vision_tower.config
        return self.cfg_only

    @property
    def hidden_size(self):
        """Hidden size from :attr:`config`."""
        return self.config.hidden_size

    @property
    def num_patches(self):
        """Number of patches: ``(image_size // patch_size) ** 2``."""
        return (self.config.image_size // self.config.patch_size) ** 2
|
|
|
|
|
def build_vision_tower(vision_tower_cfg, **kwargs):
    """Build a :class:`CLIPVisionTower` from a model config.

    The tower name is read from ``mm_vision_tower``, falling back to
    ``vision_tower``. It must be an existing local path or a name starting with
    ``"openai"`` or ``"laion"``. Extra ``kwargs`` (e.g. ``delay_load``) are
    forwarded to the tower.

    Raises:
        ValueError: if no tower is configured or the name is not recognised.
    """
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower',
                           getattr(vision_tower_cfg, 'vision_tower', None))
    if vision_tower is None:
        # Fail clearly instead of letting os.path.exists(None) raise TypeError.
        raise ValueError(
            'No vision tower configured (expected `mm_vision_tower` or `vision_tower`).'
        )

    if os.path.exists(vision_tower) or vision_tower.startswith(("openai", "laion")):
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f'Unknown vision tower: {vision_tower}')
|
|
|
|
|
|
|
|
|
|
|
class HoneybeeVisualProjectorConfig(PretrainedConfig):
    """Configuration for the Honeybee-style visual projector (abstractor).

    Remaining keyword arguments go to ``PretrainedConfig.__init__``.
    ``CAbstractor`` also reads ``depth``, ``mlp_depth`` and ``num_queries`` from
    this config; they are not set explicitly here (presumably supplied through
    ``**kwargs`` / the JSON config — confirm).
    """

    model_type = "mllm_visual_projector"

    def __init__(
        self,
        projector_type: str = "resampler",
        hidden_size: int = 1024,
        num_hidden_layers: int = 6,
        num_attention_heads: int = 16,
        intermediate_size: int = 4096,
        attention_probs_dropout_prob: float = 0.1,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-6,
        encoder_hidden_size: int = 1024,
        pos_emb=False,
        feature_layer_index=-1,
        num_eos_tokens=1,
        use_cls=True,
        prenorm=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Assigned in declaration order, after the base initialiser.
        explicit_fields = {
            "projector_type": projector_type,
            "hidden_size": hidden_size,
            "num_hidden_layers": num_hidden_layers,
            "num_attention_heads": num_attention_heads,
            "intermediate_size": intermediate_size,
            "attention_probs_dropout_prob": attention_probs_dropout_prob,
            "initializer_range": initializer_range,
            "layer_norm_eps": layer_norm_eps,
            "encoder_hidden_size": encoder_hidden_size,
            "pos_emb": pos_emb,
            "feature_layer_index": feature_layer_index,
            "num_eos_tokens": num_eos_tokens,
            "use_cls": use_cls,
            "prenorm": prenorm,
        }
        for field_name, field_value in explicit_fields.items():
            setattr(self, field_name, field_value)

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
    ) -> "PretrainedConfig":
        """Load the projector config from a standalone or full-model config.

        If the loaded dict describes a whole ``"QH_360VL"`` model, its nested
        ``"visual_projector_config"`` entry is used instead.
        """
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        is_full_model_config = config_dict.get("model_type") == "QH_360VL"
        projector_dict = (
            config_dict["visual_projector_config"] if is_full_model_config else config_dict
        )
        return cls.from_dict(projector_dict, **kwargs)
|
|
|
def build_pos_embeds(
    config: HoneybeeVisualProjectorConfig, num_input_tokens: int, vision_hidden_size: int
):
    """Create a learnable positional embedding, or return ``None`` when disabled.

    Returns:
        An ``nn.Parameter`` of shape ``(1, num_input_tokens, vision_hidden_size)``
        initialised from a truncated normal (mean 0, std 0.02) when
        ``config.pos_emb`` is truthy; otherwise ``None``.
    """
    if not config.pos_emb:
        return None

    embedding = torch.nn.Parameter(torch.zeros(1, num_input_tokens, vision_hidden_size))
    nn.init.trunc_normal_(embedding, mean=0.0, std=0.02)
    return embedding
|
|
|
|
|
def build_eos_tokens(config: HoneybeeVisualProjectorConfig, output_hidden_size: int):
    """Create learnable tokens that the projector appends after its output.

    Returns:
        An ``nn.Parameter`` of shape ``(1, config.num_eos_tokens, output_hidden_size)``
        initialised from a truncated normal with std ``config.initializer_range``,
        or ``None`` when ``config.num_eos_tokens`` is falsy (e.g. 0).
    """
    token_count = config.num_eos_tokens
    if not token_count:
        return None

    tokens = torch.nn.Parameter(torch.randn(1, token_count, output_hidden_size))
    # The randn values are immediately overwritten by the truncated-normal init.
    nn.init.trunc_normal_(tokens, mean=0.0, std=config.initializer_range)
    return tokens
|
|
|
|
|
def build_prenorm(config: HoneybeeVisualProjectorConfig):
    """Return a timm ``LayerNorm`` over ``config.encoder_hidden_size`` if
    ``config.prenorm`` is truthy, else ``None``."""
    return LayerNorm(config.encoder_hidden_size) if config.prenorm else None
|
|
|
|
|
def build_mlp(depth, hidden_size, output_hidden_size):
    """Build an MLP made of one input ``Linear`` and ``depth - 1`` (SiLU, Linear) pairs.

    Args:
        depth: Total number of Linear layers; values below 1 behave like 1.
        hidden_size: Input feature size.
        output_hidden_size: Width of every Linear output.

    Returns:
        ``nn.Sequential`` of the layers in order.
    """
    head = [nn.Linear(hidden_size, output_hidden_size)]
    tail = [
        layer
        for _ in range(1, depth)
        for layer in (nn.SiLU(), nn.Linear(output_hidden_size, output_hidden_size))
    ]
    return nn.Sequential(*head, *tail)
|
|
|
def get_abs_pos(abs_pos, tgt_size):
    """Resize a square grid of absolute position embeddings to ``tgt_size`` tokens.

    Args:
        abs_pos: Tensor of shape ``(1, N, C)`` whose ``N`` tokens form a
            ``sqrt(N) x sqrt(N)`` grid.
        tgt_size: Target token count; its (floored) square root is the target
            grid side.

    Returns:
        ``abs_pos`` itself when the grid sides already match. Otherwise a
        bicubically interpolated tensor of shape ``(1, side * side, C)`` in
        ``abs_pos``'s dtype. Both branches keep the leading batch dimension.
    """
    src_side = int(math.sqrt(abs_pos.size(1)))
    tgt_side = int(math.sqrt(tgt_size))
    if src_side == tgt_side:
        return abs_pos

    dtype = abs_pos.dtype
    # (1, N, C) -> (1, C, H, W); interpolate in float32 for accuracy.
    grid = abs_pos.float().reshape(1, src_side, src_side, -1).permute(0, 3, 1, 2)
    resized = F.interpolate(
        grid,
        size=(tgt_side, tgt_side),
        mode="bicubic",
        align_corners=False,
    )
    # (1, C, H, W) -> (1, H*W, C). flatten(1, 2) keeps the batch dim, so the
    # result has the same rank as the un-resized branch.
    return resized.permute(0, 2, 3, 1).flatten(1, 2).to(dtype=dtype)
|
|
|
|
|
class Projector(nn.Module):
    """Base projector mapping visual-encoder tokens to the language-model width.

    Subclasses implement :meth:`build_net` (called at the end of ``__init__``)
    and :meth:`_forward`.
    """

    def __init__(
        self,
        config: HoneybeeVisualProjectorConfig,
        num_input_tokens: int,
        output_hidden_size: int,
    ):
        super().__init__()
        self.config = config
        self.num_input_tokens = num_input_tokens
        self.output_hidden_size = output_hidden_size

        # Optional learnable tokens appended after the projected features (or None).
        self.eos_tokens = build_eos_tokens(config, output_hidden_size)

        # Optional positional embedding over the encoder tokens (or None).
        self.pos_emb = build_pos_embeds(config, num_input_tokens, config.encoder_hidden_size)

        # Optional normalisation applied to the encoder features first (or None).
        self.prenorm = build_prenorm(config)

        self.build_net()

    def build_net(self):
        """Create the projector's layers; must be implemented by subclasses."""
        raise NotImplementedError()

    def _forward(self, x):
        """Project ``x``; must be implemented by subclasses."""
        raise NotImplementedError()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (B, L, encoder_hidden_size) tensor of visual-encoder features.

        Returns:
            ``_forward(x)`` with ``eos_tokens`` (if configured) concatenated
            along dim 1. The input tensor is never modified in place.
        """
        if self.prenorm is not None:
            x = self.prenorm(x)

        if self.pos_emb is not None:
            # Position 0 of pos_emb is skipped; the rest is resized to the
            # incoming token count.
            pos_emb = get_abs_pos(self.pos_emb[:, 1:], x.size(1))
            pos_emb = pos_emb.to(device=x.device)
            # Out-of-place add so the caller's tensor is left untouched. The
            # result is cast back to x's dtype, as an in-place add would keep it.
            x = (x + pos_emb).to(dtype=x.dtype)

        x = self._forward(x)

        if self.eos_tokens is not None:
            batch_size = x.size(0)
            x = torch.cat([x, self.eos_tokens.expand(batch_size, -1, -1)], dim=1)
        return x
|
|
|
|
|
class ConvProjector(Projector):
    """Projector whose network runs on a square 2-D grid of tokens."""

    def _forward(self, x):
        """Apply ``self.net`` on ``x`` laid out as a grid, then ``self.readout`` per token.

        ``x`` is (B, L, D). The grid side is ``int(sqrt(L))``; einops rejects the
        rearrange if ``L`` is not that side squared.
        """
        side = int(x.size(1) ** 0.5)
        grid = rearrange(x, "b (h w) d -> b d h w", h=side, w=side)
        tokens = rearrange(self.net(grid), "b d h w -> b (h w) d")
        return self.readout(tokens)
|
|
|
|
|
class CAbstractor(ConvProjector):
    """C-Abstractor: RegStage -> adaptive average pool -> RegStage, then an MLP readout.

    Reads ``encoder_hidden_size``, ``hidden_size``, ``depth``, ``mlp_depth`` and
    ``num_queries`` from ``self.config``. ``num_queries`` must be a perfect square:
    its root is the side of the pooled grid.
    """

    def build_net(self):
        cfg = self.config
        in_dim, mid_dim = cfg.encoder_hidden_size, cfg.hidden_size
        out_dim = self.output_hidden_size
        stage_depth, readout_depth = cfg.depth, cfg.mlp_depth

        n_queries = cfg.num_queries
        assert (n_queries ** 0.5).is_integer(), "n_queries must be square number"
        pooled_side = int(n_queries ** 0.5)

        # RegNet stages at stride 1 with SiLU activations and channel-wise LayerNorm.
        make_stage = partial(
            RegStage,
            stride=1,
            dilation=1,
            act_layer=nn.SiLU,
            norm_layer=LayerNorm2d,
        )

        # Sequential indices 0/1/2 = stage, pool, stage (these name the state_dict keys).
        self.net = nn.Sequential(
            make_stage(stage_depth, in_dim, mid_dim),
            nn.AdaptiveAvgPool2d((pooled_side, pooled_side)),
            make_stage(stage_depth, mid_dim, mid_dim),
        )
        self.readout = build_mlp(readout_depth, mid_dim, out_dim)
|
|
|
class IdentityMap(nn.Module):
    """Parameter-free projector that returns its input unchanged."""

    # nn.Module.__init__ is inherited as-is; no extra state is needed.

    def forward(self, x, *args, **kwargs):
        """Return ``x`` itself; any extra arguments are accepted and ignored."""
        return x

    @property
    def config(self):
        """Minimal config dict identifying this projector type."""
        return {"mm_projector_type": 'identity'}
|
|
|
|
|
class SimpleResBlock(nn.Module):
    """LayerNorm followed by a residual two-layer GELU MLP.

    Computes ``n + proj(n)`` with ``n = pre_norm(x)``. The residual is taken from
    the normalised input, not from ``x``.
    """

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)
        self.proj = nn.Sequential(
            nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
        )

    def forward(self, x):
        normed = self.pre_norm(x)
        return normed + self.proj(normed)
|
|
|
|
|
def build_honeybee_projector(config, projector_type, num_tokens, lm_hidden_size):
    """Instantiate the Honeybee abstractor selected by ``projector_type``.

    Only ``"c-abs"`` (:class:`CAbstractor`) is registered; any other value raises
    ``KeyError``. The abstractor gets ``num_tokens`` input tokens and projects to
    ``lm_hidden_size``.
    """
    registry = {"c-abs": CAbstractor}
    abstractor_cls = registry[projector_type]
    return abstractor_cls(config, num_tokens, lm_hidden_size)
|
|
|
|
|
def build_vision_projector(config, delay_load=False, **kwargs):
    """Build the vision-to-LM projector described by ``config.mm_projector_type``.

    Supported types (default ``'linear'``):
      * ``'linear'``     -- ``nn.Linear(mm_hidden_size, hidden_size)``.
      * ``'c-abs'``      -- Honeybee C-Abstractor configured from
        ``config.mm_projector_config`` with ``config.mm_num_tokens`` input tokens.
      * ``'mlpNx_gelu'`` -- N Linear layers separated by GELU.
      * ``'identity'``   -- :class:`IdentityMap`.

    ``delay_load`` and ``kwargs`` are accepted for call compatibility and unused.

    Raises:
        ValueError: for any other projector type.
    """
    # ``re`` is needed for the mlpNx_gelu pattern and is not imported at module level.
    import re

    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    if projector_type == 'c-abs':
        honeybee_config = HoneybeeVisualProjectorConfig.from_pretrained(config.mm_projector_config)
        return build_honeybee_projector(
            honeybee_config, projector_type, config.mm_num_tokens, config.hidden_size
        )

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')
|
|
|
|
|
|
|
|
|
class QH360_VL_MetaModel:
    """Mixin adding a vision tower and two vision projectors to a language model.

    Must precede the LM base class in the MRO; ``__init__`` forwards ``config`` to
    the next class.
    """

    def __init__(self, config):
        super().__init__(config)
        if hasattr(config, "mm_vision_tower"):
            # delay_load=True: only the vision config is fetched here.
            self.vision_tower = build_vision_tower(config, delay_load=True)
            # prepare_inputs_labels_for_multimodal applies ``mm_projector_ctt`` to the
            # four stitched crops and ``mm_projector_ori`` to the fifth (whole) view.
            self.mm_projector_ctt = build_vision_projector(config)
            self.mm_projector_ori = build_vision_projector(config)

    def get_vision_tower(self):
        """Return the vision tower (unwrapping a list), or ``None`` if absent."""
        vision_tower = getattr(self, 'vision_tower', None)
        if isinstance(vision_tower, list):
            vision_tower = vision_tower[0]
        return vision_tower
|
|
|
|
|
class QH360_VL_MetaForCausalLM(ABC):
    """Mixin adding image encoding and multimodal input preparation to a causal LM.

    Subclasses implement :meth:`get_model`. The returned model is used below
    through ``get_vision_tower()``, ``embed_tokens``, ``mm_projector_ori``,
    ``mm_projector_ctt`` and ``mm_projector``.
    """

    @abstractmethod
    def get_model(self):
        """Return the wrapped base model (implemented by subclasses)."""
        pass

    def get_vision_tower(self):
        """Delegate to ``self.get_model().get_vision_tower()``."""
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        """Run the vision tower and then ``mm_projector`` on ``images``.

        NOTE(review): QH360_VL_MetaModel in this file only creates
        ``mm_projector_ctt`` / ``mm_projector_ori``. Unless ``mm_projector`` is
        added elsewhere, this attribute lookup fails. This method is reached for
        3-D list entries and for the non-list / non-5-D branch of
        prepare_inputs_labels_for_multimodal — confirm.
        """
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def encode_images_noprojector(self, images):
        """Run only the vision tower on ``images`` and return detached features."""
        image_features = self.get_model().get_vision_tower()(images)
        image_features = image_features.detach()
        return image_features

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images
    ):
        """Splice image features into the token embeddings at IMAGE_TOKEN_INDEX positions.

        Returns:
            ``(input_ids, attention_mask, past_key_values, inputs_embeds, labels)``.

            Inputs pass through unchanged, with ``inputs_embeds=None``, when any of
            these holds: there is no vision tower, ``images`` is None, or
            ``input_ids`` has a single column. In the cached single-token case the
            attention mask is rebuilt from the cache length.

            Otherwise ``input_ids`` is returned as ``None`` and ``inputs_embeds``
            holds the merged ``(B, T, H)`` sequence. Label positions covered by image
            features are set to IGNORE_INDEX. Ragged sequences are right-padded with
            zeros (labels with IGNORE_INDEX, mask with False).
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
                # Single-token decode step with a cache: all-ones mask over cached +
                # new positions. Assumes the legacy tuple cache, where
                # past_key_values[-1][-1] is a tensor with the sequence length at
                # dim -2 — TODO confirm for the transformers version in use.
                attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
            return input_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            # One entry of image_features per image.
            image_features = []
            for image in images:
                if image.ndim == 3:
                    # Single view: encode as a batch of one and drop that batch dim.
                    image_features.append(self.encode_images(image.unsqueeze(0)).squeeze(0))
                elif image.ndim == 4:
                    # Stack of views, grouped 5 at a time by the reshape below.
                    # Presumably the 5-view output of process_images_slid_window
                    # (TL, TR, BL, BR crops, then the whole image) — TODO confirm.
                    temp_feats = self.encode_images_noprojector(image)
                    # Side of the square token grid of one view.
                    src_size = int(math.sqrt(temp_feats.shape[1]))
                    # -> (groups, 5, tokens, D)
                    temp_feats = temp_feats.reshape(temp_feats.shape[0]//5,5,-1, temp_feats.shape[-1])
                    # View 4 is projected separately by mm_projector_ori.
                    x1 = temp_feats[:,4,:,:]
                    # Views 0-3 -> (groups, 4, s, s, D)
                    x = temp_feats[:,:4,:,:]
                    x = x.reshape(x.shape[0], -1, src_size, src_size, x.shape[-1])
                    # -> (groups, s_row, 4, s_col, D) -> split 4 into (crop_row=2, crop_col=2)
                    x = x.transpose(1,2).reshape(x.shape[0], src_size,2,2, src_size, x.shape[-1])
                    # -> (groups, crop_row, s_row, crop_col, s_col, D) -> flatten to a
                    # row-major (2s x 2s) grid: views 0,1 on top and 2,3 below.
                    x = x.transpose(1,2).reshape(x.shape[0], -1, x.shape[-1])
                    # squeeze(0) drops the group dim only when there is exactly one group.
                    x1 = self.get_model().mm_projector_ori(x1).squeeze(0)
                    x = self.get_model().mm_projector_ctt(x).squeeze(0)
                    # Stitched-crop tokens first, then whole-view tokens.
                    temp_feats_all = torch.cat([x,x1],dim=0)
                    image_features.append(temp_feats_all)
        else:
            image_features = self.encode_images(images)

        new_input_embeds = []
        new_labels = [] if labels is not None else None
        # Index into image_features; advanced once per placeholder consumed.
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
                # Sample without an image placeholder: embed the text in two halves
                # with an empty (0-row) slice of image features between them, so no
                # image tokens are added. This still consumes one image_features entry.
                half_len = cur_input_ids.shape[0] // 2
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
                cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
                new_input_embeds.append(cur_input_embeds)
                if labels is not None:
                    new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue
            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            cur_new_input_embeds = []
            if labels is not None:
                cur_labels = labels[batch_idx]
                cur_new_labels = []
                assert cur_labels.shape == cur_input_ids.shape
            # Consume placeholders left to right; cur_input_ids/cur_labels are
            # re-sliced after each one so index 0 is always the next unprocessed token.
            while image_token_indices.numel() > 0:
                cur_image_features = image_features[cur_image_idx]
                image_token_start = image_token_indices[0]
                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
                    # Text before the token preceding the placeholder is detached. The
                    # tokens just before and just after the placeholder keep gradients.
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                        # NOTE(review): the embedding appended above comes from index
                        # image_token_start+1, but this label comes from index
                        # image_token_start (the placeholder) — confirm intended.
                        cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
                        cur_labels = cur_labels[image_token_start+2:]
                else:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
                    cur_new_input_embeds.append(cur_image_features)
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        # Image positions never contribute to the loss.
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                        cur_labels = cur_labels[image_token_start+1:]
                cur_image_idx += 1
                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
                    cur_input_ids = cur_input_ids[image_token_start+2:]
                else:
                    cur_input_ids = cur_input_ids[image_token_start+1:]
                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            # Trailing text after the last placeholder.
            if cur_input_ids.numel() > 0:
                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach())
                else:
                    cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
                if labels is not None:
                    cur_new_labels.append(cur_labels)
            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
            new_input_embeds.append(cur_new_input_embeds)
            if labels is not None:
                cur_new_labels = torch.cat(cur_new_labels, dim=0)
                new_labels.append(cur_new_labels)

        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
            # Ragged batch: right-pad every sequence to the longest one.
            max_len = max(x.shape[0] for x in new_input_embeds)

            new_input_embeds_align = []
            for cur_new_embed in new_input_embeds:
                cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
                new_input_embeds_align.append(cur_new_embed)
            new_input_embeds = torch.stack(new_input_embeds_align, dim=0)

            if labels is not None:
                new_labels_align = []
                # Unpadded labels, kept to measure each sample's true length below.
                _new_labels = new_labels
                for cur_new_label in new_labels:
                    cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
                    new_labels_align.append(cur_new_label)
                new_labels = torch.stack(new_labels_align, dim=0)

            if attention_mask is not None:
                # NOTE(review): this branch reads ``_new_labels`` and ``labels.shape``,
                # which exist only when labels is not None. A ragged batch with an
                # attention mask but no labels therefore fails here — confirm.
                new_attention_mask = []
                for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
                    # True for the extra (image) positions, then the original mask,
                    # then False for the right padding.
                    new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
                    new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
                    cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
                    new_attention_mask.append(cur_new_attention_mask)
                attention_mask = torch.stack(new_attention_mask, dim=0)
                assert attention_mask.shape == new_labels.shape
        else:
            new_input_embeds = torch.stack(new_input_embeds, dim=0)
            if labels is not None:
                new_labels = torch.stack(new_labels, dim=0)

            if attention_mask is not None:
                # Prepend True for every position added by the image features.
                new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
                assert attention_mask.shape == new_input_embeds.shape[:2]

        return None, attention_mask, past_key_values, new_input_embeds, new_labels
|
|
|
|
|
|
|
class QH360_VLConfig(LlamaConfig):
    """``LlamaConfig`` subclass identified by ``model_type = "QH_360VL"``.

    HoneybeeVisualProjectorConfig.from_pretrained checks for this model type to
    extract the nested projector config.
    """

    model_type = "QH_360VL"
|
|
|
|
|
class QH360_VL_LlamaModel(QH360_VL_MetaModel, LlamaModel):
    """``LlamaModel`` extended with the vision tower and projectors of QH360_VL_MetaModel."""

    config_class = QH360_VLConfig

    def __init__(self, config: LlamaConfig):
        # Cooperative init: QH360_VL_MetaModel.__init__ runs first and forwards to LlamaModel.
        super().__init__(config)
|
|
|
|
|
class QH360_VL_LlamaForCausalLM(LlamaForCausalLM, QH360_VL_MetaForCausalLM):
    """Llama causal LM that accepts image inputs (360VL).

    ``forward`` first merges image features into the token embeddings with
    ``prepare_inputs_labels_for_multimodal`` and then runs the language model
    and LM head.
    """

    config_class = QH360_VLConfig

    def __init__(self, config):
        # Deliberately skip LlamaForCausalLM.__init__, which would build a plain
        # LlamaModel, and initialise from the next class in the MRO instead, so
        # that ``self.model`` can be the multimodal model below.
        super(LlamaForCausalLM, self).__init__(config)
        # The attention backend is taken from ``config`` as given (e.g. via
        # ``attn_implementation=`` in ``from_pretrained``); it is not overridden here.
        self.model = QH360_VL_LlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # transformers convention: initialise weights / final setup.
        self.post_init()

    def get_model(self):
        """Return the multimodal base model (required by QH360_VL_MetaForCausalLM)."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the multimodal LM.

        ``images`` go to ``prepare_inputs_labels_for_multimodal``, which may
        replace ``input_ids`` by ``inputs_embeds`` and extend ``attention_mask`` and
        ``labels``. When ``labels`` is given, the loss is next-token cross-entropy.
        Label positions equal to -100 (IGNORE_INDEX) are skipped by
        ``CrossEntropyLoss``'s default ``ignore_index``.

        Returns:
            ``CausalLMOutputWithPast``, or a tuple when ``return_dict`` is false.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Labels may live on another device under model parallelism.
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build the model inputs for one generation step.

        Once a cache exists, only the last token id is fed. ``inputs_embeds`` are
        used only on the first step (no cache). ``images`` are forwarded from
        ``kwargs`` so that ``forward`` can merge them.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs

    def build_conversation_input_ids(
        self,
        tokenizer: "PreTrainedTokenizer",
        query: str,
        image = None,
        image_processor=None,
    ):
        """Tokenise a single-turn chat prompt with one image, and preprocess the image.

        The user turn is ``"<|reserved_special_token_44|>\\n" + query``. The id of
        that placeholder (hard-coded as 128049) is replaced by IMAGE_TOKEN_INDEX, so
        ``prepare_inputs_labels_for_multimodal`` splices the image features in at
        that position.

        Returns:
            dict with ``'input_ids'`` of shape ``(1, L)`` and ``'image'``, the output of
            :meth:`process_images_slid_window` with a leading dimension of 1.

        Raises:
            ValueError: if the tokenised prompt does not contain id 128049
                (raised by ``list.index``).
        """
        # Token id of "<|reserved_special_token_44|>" in the tokenizer this model is
        # used with (presumably the Llama-3 tokenizer — TODO confirm).
        image_placeholder_id = 128049

        input_msg = [
            {
                "role": "system",
                "content": "You are a multilingual, helpful, respectful and honest assistant who can respond in the same language, depending on the language of the question. Try to be as helpful as possible while still being safe. Your answer should not contain anything that is false, unhealthy, harmful, immoral, racist, sexist, toxic, dangerous, or illegal, and if the question relates to such content, please decline to answer. Make sure your answer is socially fair and positive. If a question doesn't make any sense, or is inconsistent with the facts, explain why instead of answering the wrong answer. If you don't know the answer to a question, don't share false information."
            },
            {
                "role": "user",
                "content": "<|reserved_special_token_44|>"+ '\n' + query
            }
        ]

        input_ids = tokenizer.apply_chat_template(
            input_msg,
            add_generation_prompt=True,
            padding="longest",
            return_tensors="pt",
        )
        input_id_list = input_ids[0].tolist()
        input_id_list[input_id_list.index(image_placeholder_id)] = IMAGE_TOKEN_INDEX
        input_ids = torch.tensor(input_id_list, dtype=input_ids.dtype, device=input_ids.device)
        input_ids = input_ids.unsqueeze(0)
        image_tensor = self.process_images_slid_window(image, image_processor).unsqueeze(0)

        return {
            'input_ids': input_ids,
            'image': image_tensor,
        }

    def process_images_slid_window(self, image, image_processor, vit_is=336):
        """Build the five-view tensor for one PIL image.

        Steps:
          1. Pad the image to a square using the processor's mean colour.
          2. Resize it to ``2 * vit_is`` on each side.
          3. Preprocess it twice: once without resize/center-crop (high-res view),
             and once with the processor defaults (low-res view).
          4. Take the four corner crops of the high-res view, each as large as the
             low-res view's height, and stack them with the low-res view.

        Returns:
            Tensor ordered top-left, top-right, bottom-left, bottom-right, low-res
            view. Stacking requires all five views to share one shape, which
            presumably means the processor outputs square images — TODO confirm.
        """

        def get_proper_imgsize(pil_img, vit_is):
            # Resize to a fixed (2 * vit_is) square.
            max_w_h = vit_is * 2
            new_pil_img = pil_img.resize((max_w_h, max_w_h))
            return new_pil_img

        def tensor_crop(tensor_array, left, upper, right, lower):
            # Crop a (C, H, W) tensor using PIL-style box coordinates.
            return tensor_array[:, upper:lower, left:right]

        def image_slid_window(image, num_slid_window):
            # image is [high_res_view, low_res_view]; crop size = low-res view height.
            if num_slid_window == 5:
                image_x2, image_x1 = image[0], image[1]
                vit_is = image_x1.shape[1]
                h, w = image_x2.shape[1], image_x2.shape[2]
                image0 = tensor_crop(image_x2, 0, 0, vit_is, vit_is)
                image1 = tensor_crop(image_x2, w-vit_is, 0, w, vit_is)
                image2 = tensor_crop(image_x2, 0, h-vit_is, vit_is, h)
                image3 = tensor_crop(image_x2, w-vit_is, h-vit_is, w, h)
                return torch.stack([image0, image1, image2, image3, image_x1])
            else:
                return image

        def expand2square(pil_img, background_color):
            # Pad the shorter side (centred) so the image becomes square.
            width, height = pil_img.size
            if width == height:
                return pil_img
            elif width > height:
                result = Image.new(pil_img.mode, (width, width), background_color)
                result.paste(pil_img, (0, (width - height) // 2))
                return result
            else:
                result = Image.new(pil_img.mode, (height, height), background_color)
                result.paste(pil_img, ((height - width) // 2, 0))
                return result

        num_slid_window = 5

        image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
        image = get_proper_imgsize(image, vit_is)
        image_x2 = image_processor.preprocess(image, return_tensors='pt', do_resize=False, do_center_crop=False)['pixel_values'][0]
        image_x1 = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        image = [image_x2, image_x1]
        image = image_slid_window(image, num_slid_window)

        return image
|
|