internlm-xcomposer2d5-ol-7b / memory /builder_compressor.py
yhcao's picture
upload models
8e1010d
raw
history blame
6.61 kB
import os
from llava_phi import LlavaPhiForCausalLM, PhiConfig
import torch
import torch.nn as nn
import re
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
from typing import Optional, List
class PhiCompressor(nn.Module):
    """Thin ``nn.Module`` wrapper around a ``LlavaPhiForCausalLM`` compressor.

    Both the tokenizer and the model are loaded with ``from_pretrained`` from
    the same ``compressor`` location. The ``forward_*`` methods delegate to the
    wrapped model (``forward_compress`` delegates to its ``forward_token``).
    The ``dtype``/``device``/``config`` properties mirror the wrapped model.
    """

    def __init__(self, compressor, select_layer: int = 15):
        """Load the tokenizer and compressor model.

        Args:
            compressor: Location passed to ``from_pretrained`` for both the
                tokenizer and the ``LlavaPhiForCausalLM`` model.
            select_layer: Value that :meth:`forward` passes as
                ``select_layer``. Defaults to 15, the previously hard-coded
                value.
        """
        super().__init__()
        self.model_path = compressor
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        self.compressor = LlavaPhiForCausalLM.from_pretrained(self.model_path)
        self.select_layer = select_layer

    def forward_video_encoding(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        qs_ids: Optional[torch.LongTensor] = None,
        qs_mask: Optional[torch.Tensor] = None,
        time_labels: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        projector: Optional[torch.LongTensor] = None,
        select_layer: Optional[int] = None,
        return_dict: Optional[bool] = None,
    ) -> tuple:
        """Delegate to the wrapped model's ``forward_video_encoding``.

        Every argument is forwarded unchanged.

        Returns:
            tuple: ``(full_memory, full_time)`` as returned by the wrapped model.
        """
        # NOTE(review): arguments are passed positionally, so this order must
        # match the wrapped method's parameter order -- verify against llava_phi.
        full_memory, full_time = self.compressor.forward_video_encoding(
            input_ids,
            attention_mask,
            qs_ids,
            qs_mask,
            time_labels,
            position_ids,
            past_key_values,
            inputs_embeds,
            labels,
            use_cache,
            output_attentions,
            output_hidden_states,
            images,
            projector,
            select_layer,
            return_dict,
        )
        return full_memory, full_time

    def forward_question(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        qs_ids: Optional[torch.LongTensor] = None,
        qs_mask: Optional[torch.Tensor] = None,
        time_labels: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        memory: Optional[torch.FloatTensor] = None,
        projector: Optional[torch.LongTensor] = None,
        select_layer: Optional[int] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """Delegate to the wrapped model's ``forward_question``.

        Every argument is forwarded unchanged. ``memory`` takes the slot that
        ``images`` occupies in the other ``forward_*`` methods.

        Returns:
            The ``qs_token`` value produced by the wrapped model.
        """
        # NOTE(review): positional forwarding -- order must match llava_phi.
        qs_token = self.compressor.forward_question(
            input_ids,
            attention_mask,
            qs_ids,
            qs_mask,
            time_labels,
            position_ids,
            past_key_values,
            inputs_embeds,
            labels,
            use_cache,
            output_attentions,
            output_hidden_states,
            memory,
            projector,
            select_layer,
            return_dict,
        )
        return qs_token

    def forward_compress(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        qs_ids: Optional[torch.LongTensor] = None,
        qs_mask: Optional[torch.Tensor] = None,
        time_labels: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        projector: Optional[torch.LongTensor] = None,
        select_layer: Optional[int] = None,
        return_dict: Optional[bool] = None,
    ) -> tuple:
        """Delegate to the wrapped model's ``forward_token``.

        Every argument is forwarded unchanged.

        Returns:
            tuple: ``(compress_tokens, loss, similarity)`` as returned by the
            wrapped model.
        """
        # NOTE(review): positional forwarding -- order must match llava_phi.
        compress_tokens, loss, similarity = self.compressor.forward_token(
            input_ids,
            attention_mask,
            qs_ids,
            qs_mask,
            time_labels,
            position_ids,
            past_key_values,
            inputs_embeds,
            labels,
            use_cache,
            output_attentions,
            output_hidden_states,
            images,
            projector,
            select_layer,
            return_dict,
        )
        return compress_tokens, loss, similarity

    def forward(self, clips, seqs, compress_mask, qs, qs_mask, time_labels):
        """Run :meth:`forward_compress` using ``self.select_layer``.

        Args:
            clips: Passed as ``images``.
            seqs: Passed as ``input_ids``.
            compress_mask: Passed as ``attention_mask``.
            qs: Passed as ``qs_ids``.
            qs_mask: Passed as ``qs_mask``.
            time_labels: Passed as ``time_labels``.

        Returns:
            tuple: ``(compress_tokens, loss, similarity)`` from
            :meth:`forward_compress`.
        """
        return self.forward_compress(
            input_ids=seqs,
            attention_mask=compress_mask,
            qs_ids=qs,
            qs_mask=qs_mask,
            images=clips,
            select_layer=self.select_layer,
            time_labels=time_labels,
        )

    @property
    def dummy_feature(self):
        """Zero tensor of shape ``(1, hidden_size)`` on the model's device and dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Dtype reported by the wrapped model."""
        return self.compressor.dtype

    @property
    def device(self):
        """Device reported by the wrapped model."""
        return self.compressor.device

    @property
    def config(self):
        """Config object of the wrapped model."""
        return self.compressor.config

    @property
    def hidden_size(self):
        """``hidden_size`` taken from the wrapped model's config."""
        return self.config.hidden_size
def build_compressor(compressor_cfg):
    """Build a :class:`PhiCompressor` from a config object.

    The compressor location is read from ``compressor_cfg.mm_compressor``.
    If that attribute is absent, it is read from ``compressor_cfg.compressor``.

    Args:
        compressor_cfg: Any object that may expose ``mm_compressor`` or
            ``compressor``.

    Returns:
        PhiCompressor: Wrapper loaded from the configured path.

    Raises:
        ValueError: If no compressor is configured, or the configured value is
            not an existing filesystem path.
    """
    compressor = getattr(compressor_cfg, 'mm_compressor', getattr(compressor_cfg, 'compressor', None))
    # os.path.exists(None) raises TypeError, so a missing setting is checked
    # explicitly. It then surfaces as the same ValueError as an unknown value.
    if compressor is not None and os.path.exists(compressor):
        return PhiCompressor(compressor)
    raise ValueError(f'Unknown compressor: {compressor}')
def build_compress_projector(config):
    """Build the projector from ``compress_hidden_size`` to ``hidden_size``.

    ``config.compress_projector_type`` (default ``'linear'``) selects:

    * ``'linear'``: one ``nn.Linear(compress_hidden_size, hidden_size)``.
    * ``'mlp{N}x_gelu'``: that same first ``nn.Linear``, followed by N-1
      ``(nn.GELU(), nn.Linear(hidden_size, hidden_size))`` pairs (none when
      N <= 1), wrapped in ``nn.Sequential``.
    * ``'identity'``: a parameter-free module that returns its input
      unchanged. It presumably expects compress_hidden_size == hidden_size;
      not checked here.

    Args:
        config: Object exposing ``compress_hidden_size``, ``hidden_size`` and
            optionally ``compress_projector_type``.

    Returns:
        nn.Module: The projector module.

    Raises:
        ValueError: For any other projector type.
    """
    projector_type = getattr(config, 'compress_projector_type', 'linear')
    if projector_type == 'linear':
        return nn.Linear(config.compress_hidden_size, config.hidden_size)

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.compress_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        # This module neither defines nor imports an `IdentityMap` class.
        # torch's built-in nn.Identity is the equivalent parameter-free
        # pass-through.
        return nn.Identity()

    raise ValueError(f'Unknown projector type: {projector_type}')