# Spaces: Paused (page-header residue captured along with this file; not part of the module)
from typing import Optional

import numpy as np
import torch
from torch import nn
from transformers import GPT2Config, GPT2LMHeadModel
from transformers.modeling_utils import ModuleUtilsMixin

from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py
class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin):
    """
    Text decoder model for an image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
    generate text from the UniDiffuser image-text embedding.

    The model wraps a `GPT2LMHeadModel` and, when `prefix_hidden_dim` is set, a pair of linear layers
    (`encode_prefix` / `decode_prefix`) that map the incoming prefix embedding to a hidden size and back to the
    GPT-2 embedding size. When `prefix_hidden_dim` is `None` both layers are identities.

    Parameters:
        prefix_length (`int`):
            Max number of prefix tokens that will be supplied to the model.
        prefix_inner_dim (`int`):
            The hidden size of the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the
            CLIP text encoder.
        prefix_hidden_dim (`int`, *optional*):
            Hidden dim of the MLP if we encode the prefix. Must be given when `prefix_inner_dim != n_embd`.
        vocab_size (`int`, *optional*, defaults to 50257):
            Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
        n_positions (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        n_embd (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_inner (`int`, *optional*, defaults to None):
            Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
        activation_function (`str`, *optional*, defaults to `"gelu_new"`):
            Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
        resid_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        embd_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the embeddings.
        attn_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon to use in the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        scale_attn_weights (`bool`, *optional*, defaults to `True`):
            Scale attention weights by dividing by sqrt(hidden_size).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
            Whether to additionally scale attention weights by `1 / (layer_idx + 1)`.
        reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
            Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
            dot-product/softmax to float() when training with mixed precision.
    """

    _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"]

    @register_to_config
    def __init__(
        self,
        prefix_length: int,
        prefix_inner_dim: int,
        prefix_hidden_dim: Optional[int] = None,
        vocab_size: int = 50257,  # Start of GPT2 config args
        n_positions: int = 1024,
        n_embd: int = 768,
        n_layer: int = 12,
        n_head: int = 12,
        n_inner: Optional[int] = None,
        activation_function: str = "gelu_new",
        resid_pdrop: float = 0.1,
        embd_pdrop: float = 0.1,
        attn_pdrop: float = 0.1,
        layer_norm_epsilon: float = 1e-5,
        initializer_range: float = 0.02,
        scale_attn_weights: bool = True,
        use_cache: bool = True,
        scale_attn_by_inverse_layer_idx: bool = False,
        reorder_and_upcast_attn: bool = False,
    ):
        super().__init__()

        self.prefix_length = prefix_length

        # Without a hidden projection the prefix is fed to GPT-2 unchanged, so its size must already be `n_embd`.
        if prefix_inner_dim != n_embd and prefix_hidden_dim is None:
            raise ValueError(
                f"`prefix_hidden_dim` cannot be `None` when `prefix_inner_dim`: {prefix_inner_dim} and"
                f" `n_embd`: {n_embd} are not equal."
            )

        self.prefix_inner_dim = prefix_inner_dim
        self.prefix_hidden_dim = prefix_hidden_dim

        # prefix_inner_dim -> prefix_hidden_dim (identity when no hidden dim is configured)
        self.encode_prefix = (
            nn.Linear(self.prefix_inner_dim, self.prefix_hidden_dim)
            if self.prefix_hidden_dim is not None
            else nn.Identity()
        )
        # prefix_hidden_dim -> n_embd (identity when no hidden dim is configured)
        self.decode_prefix = (
            nn.Linear(self.prefix_hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity()
        )

        gpt_config = GPT2Config(
            vocab_size=vocab_size,
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            n_inner=n_inner,
            activation_function=activation_function,
            resid_pdrop=resid_pdrop,
            embd_pdrop=embd_pdrop,
            attn_pdrop=attn_pdrop,
            layer_norm_epsilon=layer_norm_epsilon,
            initializer_range=initializer_range,
            scale_attn_weights=scale_attn_weights,
            use_cache=use_cache,
            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
            reorder_and_upcast_attn=reorder_and_upcast_attn,
        )
        self.transformer = GPT2LMHeadModel(gpt_config)

    def forward(
        self,
        input_ids: torch.Tensor,
        prefix_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        """
        Run the GPT-2 language model on the prefix embedding followed by the embedded text tokens.

        Args:
            input_ids (`torch.Tensor` of shape `(N, max_seq_len)`):
                Text tokens to use for inference.
            prefix_embeds (`torch.Tensor` of shape `(N, prefix_length, prefix_inner_dim)`):
                Prefix embedding to prepend to the embedded tokens.
            attention_mask (`torch.Tensor` of shape `(N, prefix_length + max_seq_len)`, *optional*):
                Attention mask forwarded to the GPT-2 model together with the concatenated prefix and token
                embeddings.
            labels (`torch.Tensor`, *optional*):
                Labels to use for language modeling. Only used as a flag: when not `None`, the labels actually
                passed to GPT-2 are `input_ids` with `prefix_length` zero tokens prepended.

        Returns:
            The `GPT2LMHeadModel` output, or the tuple `(output, hidden)` when `prefix_hidden_dim` is set, where
            `hidden` is the output of `encode_prefix`.
        """
        embedding_text = self.transformer.transformer.wte(input_ids)
        hidden = self.encode_prefix(prefix_embeds)
        prefix_embeds = self.decode_prefix(hidden)
        embedding_cat = torch.cat((prefix_embeds, embedding_text), dim=1)

        if labels is not None:
            # Align the labels with the concatenated sequence by filling the prefix positions with dummy tokens.
            dummy_token = self.get_dummy_token(input_ids.shape[0], input_ids.device)
            labels = torch.cat((dummy_token, input_ids), dim=1)
        out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=attention_mask)

        if self.prefix_hidden_dim is not None:
            return out, hidden
        return out

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Return an all-zero `int64` tensor of shape `(batch_size, prefix_length)` on `device`."""
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def encode(self, prefix):
        """Apply `encode_prefix` to `prefix` (identity when `prefix_hidden_dim` is `None`)."""
        return self.encode_prefix(prefix)

    @torch.no_grad()
    def generate_captions(self, features, eos_token_id, device):
        """
        Generate caption tokens given text embedding features, one beam search per batch element.

        Args:
            features (`torch.Tensor` of shape `(B, L, D)`):
                Text embedding features to generate captions from.
            eos_token_id (`int`):
                The token ID of the EOS token for the text decoder model.
            device:
                Device to perform text generation on.

        Returns:
            `Tuple(torch.Tensor, torch.Tensor)`: The best-scoring token sequence for each batch element, stacked
            into a `(B, T)` tensor, and a `(B,)` tensor with the corresponding sequence lengths. Sequences shorter
            than the longest one in the batch are right-padded with token id 0.
        """
        generated_tokens = []
        generated_seq_lengths = []
        for feature in torch.split(features, 1, dim=0):
            feature = self.decode_prefix(feature.to(device))  # back to the clip feature
            # Only support beam search for now
            output_tokens, seq_lengths = self.generate_beam(
                input_embeds=feature, device=device, eos_token_id=eos_token_id
            )
            # Beams come back sorted by score, so index 0 is the best candidate.
            generated_tokens.append(output_tokens[0])
            generated_seq_lengths.append(seq_lengths[0])

        # Beam search stops as soon as every beam hit EOS, so different samples can yield different lengths and a
        # plain `torch.stack` would fail. Pad with 0, the same token stopped beams are extended with.
        generated_tokens = nn.utils.rnn.pad_sequence(generated_tokens, batch_first=True, padding_value=0)
        generated_seq_lengths = torch.stack(generated_seq_lengths)
        return generated_tokens, generated_seq_lengths

    @torch.no_grad()
    def generate_beam(
        self,
        input_ids=None,
        input_embeds=None,
        device=None,
        beam_size: int = 5,
        entry_length: int = 67,
        temperature: float = 1.0,
        eos_token_id: Optional[int] = None,
    ):
        """
        Generates text using the given tokenizer and text prompt or token embedding via beam search. This
        implementation is based on the beam search implementation from the [original UniDiffuser
        code](https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py#L89).

        Args:
            eos_token_id (`int`, *optional*):
                The token ID of the EOS token for the text decoder model. When `None`, no beam is ever marked as
                stopped and the search runs for `entry_length` iterations.
            input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
                Tokenizer indices of input sequence tokens in the vocabulary. One of `input_ids` and `input_embeds`
                must be supplied.
            input_embeds (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                An embedded representation to directly pass to the transformer as a prefix for beam search. One of
                `input_ids` and `input_embeds` must be supplied.
            device:
                The device to perform beam search on. Defaults to the device of the input embeddings.
            beam_size (`int`, *optional*, defaults to `5`):
                The number of best states to store during beam search.
            entry_length (`int`, *optional*, defaults to `67`):
                The number of iterations to run beam search.
            temperature (`float`, *optional*, defaults to 1.0):
                The temperature to use when performing the softmax over logits from the decoding model. Non-positive
                values are treated as 1.0.

        Returns:
            `Tuple(torch.Tensor, torch.Tensor)`: A tuple of tensors where the first element is a tensor of generated
            token sequences sorted by score in descending order, and the second element is the sequence lengths
            corresponding to those sequences.

        Raises:
            ValueError: If neither `input_ids` nor `input_embeds` is supplied.
        """
        if input_embeds is not None:
            generated = input_embeds
        elif input_ids is not None:
            generated = self.transformer.transformer.wte(input_ids)
        else:
            raise ValueError("One of `input_ids` and `input_embeds` must be supplied.")

        if device is None:
            # Keep the bookkeeping tensors on the same device as the model inputs.
            device = generated.device

        # Generates text until stop_token is reached using beam search with the desired beam size.
        stop_token_index = eos_token_id
        tokens = None
        scores = None
        seq_lengths = torch.ones(beam_size, device=device, dtype=torch.int)
        is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
        effective_temperature = temperature if temperature > 0 else 1.0

        for _ in range(entry_length):
            outputs = self.transformer(inputs_embeds=generated)
            # Log-probabilities of the next token for every beam.
            logits = outputs.logits[:, -1, :] / effective_temperature
            logits = logits.softmax(-1).log()

            if scores is None:
                # First step: seed the beams with the `beam_size` most likely first tokens.
                scores, next_tokens = logits.topk(beam_size, -1)
                generated = generated.expand(beam_size, *generated.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                tokens = next_tokens
            else:
                # A stopped beam may only be extended with token 0, at no cost to its score.
                logits[is_stopped] = -float(np.inf)
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                # Rank all (beam, token) candidates by length-averaged log-probability.
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                # Flat index -> (source beam, token id).
                next_tokens_source = next_tokens // scores_sum.shape[1]
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % scores_sum.shape[1]
                next_tokens = next_tokens.unsqueeze(1)
                tokens = tokens[next_tokens_source]
                tokens = torch.cat((tokens, next_tokens), dim=1)
                generated = generated[next_tokens_source]
                # Undo the length averaging so `scores` is again a summed log-probability.
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]

            next_token_embed = self.transformer.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
            generated = torch.cat((generated, next_token_embed), dim=1)
            if stop_token_index is not None:
                is_stopped = is_stopped | next_tokens.eq(stop_token_index).squeeze()
            if is_stopped.all():
                break

        scores = scores / seq_lengths
        order = scores.argsort(descending=True)
        # All beams share one token tensor, so they already have a common length.
        output_texts = tokens[order]
        # Lengths are handed back on the CPU, as callers of the previous implementation received them.
        seq_lengths = seq_lengths[order].cpu()
        return output_texts, seq_lengths