|
from typing import List, Optional, Tuple, Union |
|
import os
import warnings

import torch
|
import torch.nn as nn |
|
|
|
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, AutoModelForCausalLM, AutoTokenizer
|
from transformers.modeling_utils import ContextManagers, no_init_weights |
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
from transformers.generation.utils import GenerateOutput |
|
from .configuration_apollo import ApolloConfig |
|
|
|
from .vision_tower import ApolloVisionTower |
|
from .mm_connector import MMConnector |
|
|
|
IGNORE_INDEX = -100 |
|
X_TOKEN_INDEX = -200 |
|
|
|
|
|
def get_model_config(config): |
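    """Resolve the llm / vision_tower / mm_connector sub-configs referenced by `config`.

    The root directory is taken from `config._name_or_path` (or `config.resume_path`);
    sub-configs stored as dicts or nested configs are resolved to sub-directories named
    after the key with the "_cfg" suffix stripped, and each is loaded with AutoConfig."""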
|
default_keys = ["llm_cfg", "vision_tower_cfg", "mm_connector_cfg"] |
|
if hasattr(config, "_name_or_path") and len(config._name_or_path) >= 2: |
|
root_path = config._name_or_path |
|
else: |
|
root_path = config.resume_path |
|
|
|
return_pths = [] |
|
for key in default_keys: |
|
cfg = getattr(config, key, None) |
|
if isinstance(cfg, dict): |
|
try: |
|
return_pths.append(os.path.join(root_path, key[:-4])) |
|
            except Exception:
|
raise ValueError(f"Cannot find resume path in config for {key}!") |
|
elif isinstance(cfg, PretrainedConfig): |
|
return_pths.append(os.path.join(root_path, key[:-4])) |
|
elif isinstance(cfg, str): |
|
return_pths.append(cfg) |
|
|
|
return_list = [] |
|
for pth in return_pths: |
|
return_list.append(AutoConfig.from_pretrained(pth, trust_remote_code=True)) |
|
|
|
return return_list |
|
|
|
|
|
def build_llm_and_tokenizer( |
|
    llm_cfg: PretrainedConfig,
|
config: PretrainedConfig, |
|
attn_implementation=None, |
|
model_max_length=None, |
|
*args, |
|
**kwargs, |
|
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
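    """Load the causal LM and its tokenizer from `llm_cfg._name_or_path`, casting the model
    to `config.model_dtype`, and return the (model, tokenizer) pair."""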
|
llm_arch = getattr(llm_cfg, "architectures")[0].lower() |
|
|
|
    llm_path = llm_cfg._name_or_path
    # Resolve the dtype string (e.g. "torch.float16") to a torch dtype without using eval.
    torch_dtype = getattr(torch, config.model_dtype.split(".")[-1])
    llm = AutoModelForCausalLM.from_pretrained(
        llm_path, *args, config=llm_cfg, torch_dtype=torch_dtype, **kwargs
    )
|
|
|
tokenizer = AutoTokenizer.from_pretrained( |
|
llm_path, |
|
model_max_length=llm_cfg.model_max_length, |
|
padding_side="right", |
|
use_fast=False, |
|
legacy=False, |
|
**kwargs |
|
) |
|
|
|
|
|
return llm, tokenizer |
|
|
|
|
|
class ApolloForCausalLM(PreTrainedModel): |
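    """Apollo vision-language model: a causal LM wrapped together with a vision tower and a
    multimodal connector that projects vision features into the LLM embedding space."""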
|
def __init__(self, config: ApolloConfig, *args, **kwargs): |
|
super().__init__(config) |
|
llm_cfg, vision_tower_cfg, mm_connector_cfg = get_model_config(config) |
|
model_dtype = getattr(config, "model_dtype", "torch.float16") |
|
if not hasattr(config, "model_dtype"): |
|
warnings.warn("model_dtype not found in config, defaulting to torch.float16.") |
|
config.model_dtype = model_dtype |
|
|
|
|
|
self.lm_head = nn.Linear(llm_cfg.hidden_size, config.vocab_size, bias=False) |
|
self.vision_tower = ApolloVisionTower(config, vision_tower_cfg) |
|
self.mm_connector = MMConnector.from_pretrained(mm_connector_cfg._name_or_path) |
|
self.llm, self.tokenizer = build_llm_and_tokenizer(llm_cfg, config, *args, **kwargs) |
|
self.post_init() |
|
self.is_loaded = True |
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[List[torch.FloatTensor]] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
vision_input: Optional[List[torch.FloatTensor]] = None, |
|
data_types: Optional[List[str]] = None, |
|
return_dict: Optional[bool] = None, |
|
cache_position=None, |
|
) -> Union[Tuple, CausalLMOutputWithPast]: |
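        """Standard causal-LM forward pass. If `inputs_embeds` is not provided, vision inputs
        are encoded and spliced into the token embeddings first (see
        `prepare_inputs_labels_for_multimodal`) before delegating to the wrapped LLM."""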
|
|
|
if inputs_embeds is None: |
|
( |
|
input_ids, |
|
position_ids, |
|
attention_mask, |
|
past_key_values, |
|
inputs_embeds, |
|
labels |
|
) = self.prepare_inputs_labels_for_multimodal( |
|
input_ids, |
|
position_ids, |
|
attention_mask, |
|
past_key_values, |
|
labels, |
|
vision_input, |
|
data_types |
|
) |
|
|
|
return self.get_llm().forward( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values, |
|
inputs_embeds=inputs_embeds, |
|
labels=labels, |
|
use_cache=use_cache, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
@torch.no_grad() |
|
def generate( |
|
self, |
|
inputs: Optional[torch.Tensor] = None, |
|
vision_input: Optional[List[torch.Tensor]] = None, |
|
data_types: Optional[List[str]] = None, |
|
**kwargs, |
|
) -> Union[GenerateOutput, torch.LongTensor]: |
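        """Generate text conditioned on optional vision inputs by pre-computing the merged
        multimodal `inputs_embeds` and delegating to the wrapped LLM's `generate`."""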
|
position_ids = kwargs.pop("position_ids", None) |
|
attention_mask = kwargs.pop("attention_mask", None) |
|
if "inputs_embeds" in kwargs: |
|
raise NotImplementedError("`inputs_embeds` is not supported") |
|
|
|
if vision_input is not None: |
|
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal( |
|
inputs, position_ids, attention_mask, None, None, vision_input, data_types=data_types) |
|
else: |
|
            inputs_embeds = self.get_llm().model.embed_tokens(inputs)
|
|
|
return self.get_llm().generate(position_ids=position_ids, attention_mask=attention_mask, |
|
inputs_embeds=inputs_embeds, **kwargs) |
|
|
|
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): |
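        """Delegate to the wrapped LLM's `prepare_inputs_for_generation`, then re-attach
        `vision_input` / `data_types` so they survive into subsequent forward calls."""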
|
vision_input = kwargs.pop("vision_input", None) |
|
data_types = kwargs.pop("data_types", None) |
|
inputs = self.get_llm().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, |
|
inputs_embeds=inputs_embeds, **kwargs) |
|
if vision_input is not None: |
|
inputs["vision_input"] = vision_input |
|
if data_types is not None: |
|
inputs["data_types"] = data_types |
|
return inputs |
|
|
|
@classmethod |
|
def from_pretrained( |
|
cls, |
|
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], |
|
*model_args, |
|
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, |
|
cache_dir: Optional[Union[str, os.PathLike]] = None, |
|
ignore_mismatched_sizes: bool = False, |
|
force_download: bool = False, |
|
local_files_only: bool = False, |
|
token: Optional[Union[str, bool]] = None, |
|
revision: str = "main", |
|
use_safetensors: bool = None, |
|
**kwargs, |
|
): |
|
|
|
return cls.load_pretrained( |
|
pretrained_model_name_or_path, |
|
*model_args, |
|
config=config, |
|
cache_dir=cache_dir, |
|
ignore_mismatched_sizes=ignore_mismatched_sizes, |
|
force_download=force_download, |
|
local_files_only=local_files_only, |
|
token=token, |
|
revision=revision, |
|
use_safetensors=use_safetensors, |
|
**kwargs, |
|
) |
|
|
|
def get_llm(self): |
|
return self.llm |
|
|
|
def get_vision_tower(self): |
|
return self.vision_tower |
|
|
|
def get_mm_connector(self): |
|
return self.mm_connector |
|
|
|
@classmethod |
|
def load_pretrained(cls, model_path_or_config, *args, **kwargs): |
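        """Instantiate the model from a checkpoint path or an `ApolloConfig` inside a
        no-init-weights context; sub-modules load their own pretrained weights in `__init__`."""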
|
kwargs.pop("config", None) |
|
|
|
if isinstance(model_path_or_config, str): |
|
config = AutoConfig.from_pretrained(model_path_or_config, trust_remote_code=True, **kwargs) |
|
elif isinstance(model_path_or_config, ApolloConfig): |
|
config = model_path_or_config |
|
else: |
|
            raise NotImplementedError(
                f"Unsupported config type {type(model_path_or_config)}; expected a checkpoint path (str) or an ApolloConfig instance."
            )
|
|
|
model_dtype = getattr(config, "model_dtype", "torch.float16") |
|
if not hasattr(config, "model_dtype"): |
|
warnings.warn("model_dtype not found in config, defaulting to torch.float16.") |
|
config.model_dtype = model_dtype |
|
|
|
with ContextManagers([no_init_weights(_enable=True), ]): |
|
vlm = cls(config, *args, **kwargs) |
|
|
|
        if hasattr(vlm, "llm") and hasattr(vlm, "vision_tower") and hasattr(vlm, "mm_connector") and vlm.is_loaded:
            return vlm
        raise RuntimeError("Loading the model failed: llm, vision_tower, or mm_connector was not initialized.")
|
|
|
def _encode_mm(self, x): |
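        """Run one chunk of vision inputs through the vision tower and project it with the mm connector."""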
|
x = self.get_vision_tower()(x) |
|
x = self.mm_connector(x) |
|
return x |
|
|
|
def encode_mm_minibatch(self, x): |
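        """Encode vision inputs in chunks of `self.config.encode_batch_size`.

        `x` is a list with one entry per sample; each entry holds one tensor per vision
        encoder. Per-encoder tensors are concatenated across samples, split into fixed-size
        chunks, encoded via `_encode_mm`, then split back per sample and flattened to
        (num_tokens, hidden_dim)."""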
|
split_sizes = [x_s[0].shape[0] for x_s in x] |
|
x = [torch.split(torch.cat([x_s[i] for x_s in x], dim=0), self.config.encode_batch_size) for i in |
|
range(self.get_vision_tower().num_vision_encoders)] |
|
swapped_x = [] |
|
for i in range(len(x[0])): |
|
swapped_x.append([x_s[i] for x_s in x]) |
|
|
|
features = [] |
|
for xx in swapped_x: |
|
xx = self._encode_mm(xx) |
|
features.append(xx) |
|
x = torch.cat(features, dim=0) |
|
x = torch.split(x, split_sizes, dim=0) |
|
return [xx.contiguous().view(-1, xx.shape[2]) for xx in x] |
|
|
|
def prepare_inputs_labels_for_multimodal( |
|
self, input_ids, position_ids, attention_mask, past_key_values, labels, vision_input, data_types |
|
): |
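        """Splice vision embeddings into the token embedding sequence.

        Every X_TOKEN_INDEX placeholder in `input_ids` is replaced by the corresponding
        vision feature, labels at vision positions are set to IGNORE_INDEX, and
        attention_mask / position_ids are rebuilt for the new sequence lengths. Returns
        (input_ids=None, position_ids, attention_mask, past_key_values, inputs_embeds, labels)."""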
|
vision_tower = self.get_vision_tower() |
|
if vision_tower is None or vision_input is None or input_ids.shape[1] == 1: |
|
if ( |
|
past_key_values is not None |
|
and vision_tower is not None |
|
and vision_input is not None |
|
and input_ids.shape[1] == 1 |
|
): |
|
target_shape = past_key_values[-1][-1].shape[-2] + 1 |
|
attention_mask = torch.cat( |
|
( |
|
attention_mask, |
|
torch.ones( |
|
( |
|
attention_mask.shape[0], |
|
target_shape - attention_mask.shape[1], |
|
), |
|
dtype=attention_mask.dtype, |
|
device=attention_mask.device, |
|
), |
|
), |
|
dim=1, |
|
) |
|
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 |
|
return ( |
|
input_ids, |
|
position_ids, |
|
attention_mask, |
|
past_key_values, |
|
None, |
|
labels, |
|
) |
|
|
|
        '''
        vision_input is a list of tuples and data_types is a list of strings:
        data_types = ['image', 'video', 'video', ..., 'text']
        (for one video encoder and two image encoders)
        vision_input =
        [
            [image(1, T, C, H, W), image(1, T, C, H, W), image(1, T, C, H, W)],
            [video(Nc1, C, T, H, W), video(Nc1, T, C, H, W), video(Nc1, T, C, H, W)],
            [video(Nc2, C, T, H, W), video(Nc2, T, C, H, W), video(Nc2, T, C, H, W)],
        ]
        -> video encoders typically expect (C, T, H, W); image encoders expect (C, H, W).
        '''
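        # Illustrative sketch (hypothetical shapes, three encoders) for one video sample:
        #   vision_input = [[torch.randn(2, 3, 8, 224, 224),    # encoder 0: (Nc, C, T, H, W)
        #                    torch.randn(2, 8, 3, 224, 224),    # encoder 1: (Nc, T, C, H, W)
        #                    torch.randn(2, 8, 3, 336, 336)]]   # encoder 2: (Nc, T, C, H, W)
        #   data_types = ['video']
        # encode_mm_minibatch(vision_input) then yields one feature row per X_TOKEN_INDEX
        # placeholder in input_ids; those placeholders are replaced by the feature rows below.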
|
|
|
merged_mm_features = self.encode_mm_minibatch(vision_input) |
|
|
|
if not getattr(self.config, "tune_language_model", True) and getattr(self.config, "use_mm_start_end", False): |
|
raise NotImplementedError |
|
|
|
|
|
|
|
|
|
|
|
_labels = labels |
|
_position_ids = position_ids |
|
_attention_mask = attention_mask |
|
if attention_mask is None: |
|
attention_mask = torch.ones_like(input_ids, dtype=torch.bool) |
|
else: |
|
attention_mask = attention_mask.bool() |
|
if position_ids is None: |
|
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) |
|
if labels is None: |
|
labels = torch.full_like(input_ids, IGNORE_INDEX) |
|
|
|
|
|
input_ids_copy = input_ids.clone() |
|
|
|
input_ids_copy[input_ids_copy == X_TOKEN_INDEX] = 0 |
|
input_embeds = self.get_llm().model.embed_tokens(input_ids_copy) |
|
|
|
input_ids = [ |
|
cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask) |
|
] |
|
input_embeds_1 = [ |
|
cur_input_embeds[cur_attention_mask] |
|
for cur_input_embeds, cur_attention_mask in zip(input_embeds, attention_mask) |
|
] |
|
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] |
|
|
|
new_labels = [] |
|
new_input_embeds = [] |
|
|
|
|
|
for batch_idx, (cur_labels, cur_input_ids, mm_features) in enumerate( |
|
zip(labels, input_ids, merged_mm_features)): |
|
cur_input_ids = input_ids[batch_idx] |
|
num_mm = (cur_input_ids == X_TOKEN_INDEX).sum() |
|
if num_mm == 0: |
|
cur_input_embeds_1 = input_embeds_1[batch_idx] |
|
cur_input_embeds = torch.cat([cur_input_embeds_1, mm_features[0:0]], dim=0) |
|
new_input_embeds.append(cur_input_embeds) |
|
new_labels.append(cur_labels) |
|
|
|
continue |
|
|
|
            if mm_features.shape[0] != num_mm:
                print(data_types[batch_idx])
            assert num_mm == len(mm_features), (
                f"Error in {data_types[batch_idx]}: {num_mm} != {len(mm_features)}; "
                "the number of vision tokens and vision embeddings must match!"
            )
|
|
|
cur_input_embeds = input_embeds_1[batch_idx] |
|
image_token_indices = ( |
|
[-1] + torch.where(cur_input_ids == X_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] |
|
) |
|
cur_input_ids_noim = [] |
|
cur_labels = labels[batch_idx] |
|
cur_labels_noim = [] |
|
cur_input_embeds_no_im = [] |
|
for i in range(len(image_token_indices) - 1): |
|
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1: image_token_indices[i + 1]]) |
|
cur_labels_noim.append(cur_labels[image_token_indices[i] + 1: image_token_indices[i + 1]]) |
|
cur_input_embeds_no_im.append(cur_input_embeds[image_token_indices[i] + 1: image_token_indices[i + 1]]) |
|
|
|
cur_new_input_embeds = [] |
|
cur_new_labels = [] |
|
for i in range(num_mm + 1): |
|
cur_new_input_embeds.append(cur_input_embeds_no_im[i]) |
|
|
|
cur_new_labels.append(cur_labels_noim[i]) |
|
if i < num_mm: |
|
cur_image_features = mm_features[i:i + 1] |
|
cur_new_input_embeds.append(cur_image_features) |
|
|
|
cur_new_labels.append( |
|
torch.full( |
|
(cur_image_features.shape[0],), |
|
IGNORE_INDEX, |
|
device=cur_labels.device, |
|
dtype=cur_labels.dtype, |
|
) |
|
) |
|
|
|
cur_new_input_embeds = torch.cat(cur_new_input_embeds) |
|
cur_new_labels = torch.cat(cur_new_labels) |
|
|
|
new_input_embeds.append(cur_new_input_embeds) |
|
new_labels.append(cur_new_labels) |
|
|
|
|
|
tokenizer_model_max_length = getattr(self.get_llm().config, "tokenizer_model_max_length", None) |
|
if tokenizer_model_max_length is not None: |
|
if any(len(x) > tokenizer_model_max_length for x in new_input_embeds): |
|
                print("Inputs truncated!")
|
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] |
|
new_labels = [x[:tokenizer_model_max_length] for x in new_labels] |
|
|
|
max_len = max(x.shape[0] for x in new_input_embeds) |
|
batch_size = len(new_input_embeds) |
|
|
|
new_input_embeds_padded = [] |
|
new_labels_padded = torch.full( |
|
(batch_size, max_len), |
|
IGNORE_INDEX, |
|
dtype=new_labels[0].dtype, |
|
device=new_labels[0].device, |
|
) |
|
attention_mask = torch.zeros( |
|
(batch_size, max_len), |
|
dtype=attention_mask.dtype, |
|
device=attention_mask.device, |
|
) |
|
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) |
|
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): |
|
cur_len = cur_new_embed.shape[0] |
|
if getattr(self.get_llm().config, "tokenizer_padding_side", "right") == "left": |
|
new_input_embeds_padded.append( |
|
torch.cat( |
|
( |
|
torch.zeros( |
|
(max_len - cur_len, cur_new_embed.shape[1]), |
|
dtype=cur_new_embed.dtype, |
|
device=cur_new_embed.device, |
|
), |
|
cur_new_embed, |
|
), |
|
dim=0, |
|
) |
|
) |
|
if cur_len > 0: |
|
new_labels_padded[i, -cur_len:] = cur_new_labels |
|
attention_mask[i, -cur_len:] = True |
|
position_ids[i, -cur_len:] = torch.arange( |
|
0, cur_len, dtype=position_ids.dtype, device=position_ids.device |
|
) |
|
else: |
|
new_input_embeds_padded.append( |
|
torch.cat( |
|
( |
|
cur_new_embed, |
|
torch.zeros( |
|
(max_len - cur_len, cur_new_embed.shape[1]), |
|
dtype=cur_new_embed.dtype, |
|
device=cur_new_embed.device, |
|
), |
|
), |
|
dim=0, |
|
) |
|
) |
|
if cur_len > 0: |
|
new_labels_padded[i, :cur_len] = cur_new_labels |
|
attention_mask[i, :cur_len] = True |
|
position_ids[i, :cur_len] = torch.arange( |
|
0, cur_len, dtype=position_ids.dtype, device=position_ids.device |
|
) |
|
|
|
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) |
|
|
|
if _labels is None: |
|
new_labels = None |
|
else: |
|
new_labels = new_labels_padded |
|
|
|
if _attention_mask is None: |
|
attention_mask = None |
|
else: |
|
attention_mask = attention_mask.to(dtype=_attention_mask.dtype) |
|
|
|
if _position_ids is None: |
|
position_ids = None |
|
|
|
return ( |
|
None, |
|
position_ids, |
|
attention_mask, |
|
past_key_values, |
|
new_input_embeds, |
|
new_labels, |
|
) |
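# Minimal usage sketch (commented out). The checkpoint path and the preprocessing that
# produces `vision_input` are hypothetical and handled outside this module. Note that the
# prompt's input_ids must contain one X_TOKEN_INDEX (-200) placeholder per vision embedding
# produced for the sample.
#
#   model = ApolloForCausalLM.from_pretrained("./checkpoints/apollo", trust_remote_code=True)
#   tokenizer = model.tokenizer
#   text_ids = tokenizer("Describe the video.", return_tensors="pt").input_ids
#   # prepend the vision placeholders (illustrative only; num_vision_tokens depends on the
#   # vision tower / connector output for this sample)
#   input_ids = torch.cat([torch.full((1, num_vision_tokens), X_TOKEN_INDEX), text_ids], dim=1)
#   output_ids = model.generate(input_ids, vision_input=vision_input, data_types=["video"])
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))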
|
|