import os

import torch
from torch import nn
from peft import LoraConfig, get_peft_model


class phEYE(nn.Module):
    def __init__(
        self,
        vision_encoder: nn.Module,
        lang_encoder: nn.Module,
        vis_dim: int,
        dtype: torch.dtype,
        cross_attn_every_n_layers: int = 1,
        gradient_checkpointing: bool = False,
        reduce_factor: int = 1,
        from_layer: int = 0,
    ):
""" | |
Args: | |
vision_encoder (nn.Module): module with OpenCLIP model | |
lang_encoder (nn.Module): HF causal language model | |
vis_dim (int): Dimension of the visual features. | |
Visual features are projected to match this shape along the last dimension. | |
cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1. | |
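
        Example (illustrative; `clip_encoder` and `phi` stand in for the
        project-specific encoder modules and are not defined in this file):
            >>> model = phEYE(
            ...     vision_encoder=clip_encoder,
            ...     lang_encoder=phi,
            ...     vis_dim=1024,
            ...     dtype=torch.float16,
            ...     cross_attn_every_n_layers=2,
            ... )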
""" | |
        super().__init__()
        self.vis_dim = vis_dim
        if hasattr(lang_encoder.config, "d_model"):
            self.lang_dim = lang_encoder.config.d_model  # mpt uses d_model
        else:
            self.lang_dim = lang_encoder.config.hidden_size
        self.vision_encoder = vision_encoder
        self.lang_encoder = lang_encoder
        self.lang_encoder.init_pheye(
            lang_hidden_size=self.lang_dim,
            vis_hidden_size=self.vis_dim,
            cross_attn_every_n_layers=cross_attn_every_n_layers,
            gradient_checkpointing=gradient_checkpointing,
            reduce_factor=reduce_factor,
            from_layer=from_layer,
            dtype=dtype,
        )
        self._use_gradient_checkpointing = gradient_checkpointing

    def forward(
        self,
        vision_x: list,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        clear_conditioned_layers: bool = True,
        past_key_values=None,
        use_cache: bool = False,
        device="cpu",
        is_textcaps: bool = False,
    ):
""" | |
Forward pass of phEYE. | |
Args: | |
vision_x (list): Vision input | |
shape (B, C, H, W) | |
lang_x (torch.Tensor): Language input ids | |
shape (B, txt_seq) | |
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. | |
labels (torch.Tensor, optional): Labels. Defaults to None. | |
clear_conditioned_layers: if True, clear the conditioned layers | |
once the foward pass is completed. Set this to false if the | |
same set of images will be reused in another subsequent | |
forward pass. | |
past_key_values: pre-computed values to pass to language model. | |
See past_key_values documentation in Hugging Face | |
CausalLM models. | |
use_cache: whether to use cached key values. See use_cache | |
documentation in Hugging Face CausalLM models. | |
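
        Example (illustrative; `images`, `input_ids`, and `pad_token_id` are
        assumed to come from the project's preprocessing code):
            >>> out = model(
            ...     vision_x=images,
            ...     lang_x=input_ids,
            ...     attention_mask=input_ids.ne(pad_token_id),
            ...     labels=input_ids,
            ...     device="cuda",
            ... )
            >>> out.loss.backward()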
""" | |
        assert (
            self.lang_encoder.initialized_pheye
        ), "Wrapper layers are not initialized. Please call `init_pheye` first."
        assert (
            self.lang_encoder._use_cached_vision_x or vision_x is not None
        ), "Must provide either vision_x or have precached media using cache_media()."

        if self.lang_encoder._use_cached_vision_x:
            # Case: use cached media; vision_x should already be cached and other
            # vision-related inputs should not be provided.
            assert (
                vision_x is None
            ), "Expect vision_x to be None when media has been cached using cache_media(). Try uncache_media() first."
            assert self.lang_encoder.is_conditioned()
        else:
            # Case: do not use caching (i.e. this is a standard forward pass).
            self._encode_vision_x(vision_x=vision_x, device=device, is_textcaps=is_textcaps)

        output = self.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
        )

        if clear_conditioned_layers:
            self.lang_encoder.clear_conditioned_layers()

        return output

    def generate(
        self,
        vision_x: list,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        device="cpu",
        **kwargs,
    ):
""" | |
Generate text conditioned on vision and language inputs. | |
Args: | |
vision_x (list): Vision input | |
shape (B, C, H, W) | |
images in the same chunk are collated along T_img, and frames are collated along F | |
currently only F=1 is supported (single-frame videos) | |
lang_x (torch.Tensor): Language input | |
shape (B, T_txt) | |
**kwargs: see generate documentation in Hugging Face CausalLM models. Some notable kwargs: | |
max_length (int, optional): Maximum length of the output. Defaults to None. | |
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. | |
num_beams (int, optional): Number of beams. Defaults to 1. | |
max_new_tokens (int, optional): Maximum new tokens. Defaults to None. | |
temperature (float, optional): Temperature. Defaults to 1.0. | |
top_k (int, optional): Top k. Defaults to 50. | |
top_p (float, optional): Top p. Defaults to 1.0. | |
no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0. | |
length_penalty (float, optional): Length penalty. Defaults to 1.0. | |
num_return_sequences (int, optional): Number of return sequences. Defaults to 1. | |
do_sample (bool, optional): Do sample. Defaults to False. | |
early_stopping (bool, optional): Early stopping. Defaults to False. | |
Returns: | |
torch.Tensor: lang_x with generated tokens appended to it | |
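
        Example (illustrative; `images`, `input_ids`, and `tokenizer` are
        assumed to come from the project's preprocessing code):
            >>> generated = model.generate(
            ...     vision_x=images,
            ...     lang_x=input_ids,
            ...     attention_mask=input_ids.ne(tokenizer.pad_token_id),
            ...     max_new_tokens=32,
            ...     num_beams=3,
            ...     device="cuda",
            ... )
            >>> tokenizer.batch_decode(generated, skip_special_tokens=True)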
""" | |
        num_beams = kwargs.pop("num_beams", 1)
        self.lang_encoder._use_cached_vision_x = True
        # HF generate expands the batch by num_beams, so the vision features are
        # repeated to match the expanded batch.
        self._encode_vision_x(vision_x=vision_x, device=device, repeat=num_beams)
        output = self.lang_encoder.generate(
            input_ids=lang_x,
            attention_mask=attention_mask,
            num_beams=num_beams,
            **kwargs,
        )
        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_vision_x = False
        return output

    def _encode_vision_x(self, vision_x: list, device="cpu", repeat=1, is_textcaps=False):
        """
        Compute vision features by passing images through the vision encoder and
        conditioning the language model on them.

        Args:
            vision_x (list): Vision input; each entry has shape (B, C, H, W).
            repeat (int): Number of times to repeat each image's features along
                the batch dimension (e.g. one copy per beam during generation).
            is_textcaps (bool): If True, assume each image appears five times in
                a row in the batch and encode it only once.
        """
        if is_textcaps:
            vision_x = vision_x[::5]
            repeat = 5
        vision_x = self.vision_encoder(vision_x, device=device)
        if repeat > 1:
            vision_x = vision_x.repeat_interleave(repeat, dim=0)
        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_vis_x(vision_x)

    def cache_media(self, vision_x: list, device="cpu"):
        """
        Cache vision_x features from a list of images for log-likelihood evaluation.
        This is not meant to be used to cache things for generate().

        Args:
            vision_x (list): Vision input; each entry has shape (B, C, H, W).
        """
        self._encode_vision_x(vision_x=vision_x, device=device)
        self.lang_encoder._use_cached_vision_x = True

    def uncache_media(self):
        """
        Clear all conditioning.
        """
        self.lang_encoder.clear_conditioned_layers()
        self.lang_encoder._use_cached_vision_x = False

    def save_model(self, _path):
        # Create the target directory if needed and join paths portably.
        os.makedirs(_path, exist_ok=True)
        torch.save(self.vision_encoder.state_dict(), os.path.join(_path, "vision_encoder.pt"))
        torch.save(self.lang_encoder.state_dict(), os.path.join(_path, "lang_encoder.pt"))

    def add_lora_decoder(self):
        config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"],
            lora_dropout=0.05,
            bias="none",
        )
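        # The target module names above match Phi-style decoder blocks (the
        # attention projections plus the MLP's fc1/fc2); they would need to be
        # adjusted if the underlying language model uses different layer names.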
        self.lang_encoder.old_decoder_blocks = get_peft_model(self.lang_encoder.old_decoder_blocks, config)

    def merge_and_unload(self):
        self.lang_encoder.old_decoder_blocks = self.lang_encoder.old_decoder_blocks.merge_and_unload()
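

# Minimal usage sketch (illustrative only): `build_vision_encoder` and
# `build_lang_encoder` are hypothetical project helpers that would return the
# wrapped OpenCLIP vision module and a causal LM exposing `init_pheye`; they
# are not defined in this file, and vis_dim=1024 is an assumed value.
#
#     vision_encoder = build_vision_encoder()
#     lang_encoder = build_lang_encoder()
#     model = phEYE(vision_encoder, lang_encoder, vis_dim=1024, dtype=torch.float16)
#     model.add_lora_decoder()      # attach LoRA adapters for fine-tuning
#     ...                           # training loop
#     model.merge_and_unload()      # fold LoRA weights back into the base blocks
#     model.save_model("checkpoints/pheye")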