|
import copy |
|
from typing import Optional, Tuple, Union |
|
|
|
import torch |
|
from torch import nn |
|
from torch.nn import CrossEntropyLoss |
|
from transformers.models.t5 import modeling_t5 |
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
from transformers.utils import ( |
|
add_start_docstrings_to_model_forward, |
|
logging, |
|
replace_return_docstrings, |
|
) |
|
|
|
from decoder_only_t5.config import DecoderOnlyT5Config |
|
|
|
|
|
# Config class name interpolated into generated docstrings
# (passed as ``config_class`` to ``replace_return_docstrings`` below).
_CONFIG_FOR_DOC = "DecoderOnlyT5Config"

# Module-level logger obtained through transformers' logging utilities.
logger = logging.get_logger(__name__)
|
|
|
|
|
class DecoderOnlyT5LayerFF(modeling_t5.T5LayerFF):
    """Feed-forward sub-layer of a decoder-only T5 block.

    Sequential mode (``config.parallel_layers`` false): pre-layer-norm, MLP,
    dropout and residual add, i.e.
    ``hidden_states + dropout(mlp(layer_norm(hidden_states)))``.

    Parallel mode: the enclosing block already normalizes the input and adds
    the (single) residual after summing the attention and MLP branches, so
    this layer returns only the bare MLP output. Adding a residual here as
    well would inject the normalized input into the block output a second
    time and apply dropout twice on the MLP branch.
    """

    def __init__(self, config: DecoderOnlyT5Config):
        # Bypass T5LayerFF.__init__ so the layer norm can be replaced.
        super(modeling_t5.T5LayerFF, self).__init__()
        if config.is_gated_act:
            self.DenseReluDense = modeling_t5.T5DenseGatedActDense(config)
        else:
            self.DenseReluDense = modeling_t5.T5DenseActDense(config)

        self.parallel_layers = config.parallel_layers
        if not config.parallel_layers:
            self.layer_norm = modeling_t5.T5LayerNorm(
                config.d_model, eps=config.layer_norm_epsilon
            )
        else:
            # In parallel mode the block normalizes the shared input itself.
            self.layer_norm = nn.Identity()
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        """Apply the feed-forward sub-layer.

        Args:
            hidden_states: input activations (already layer-normed by the
                caller in parallel mode).

        Returns:
            The MLP output alone in parallel mode; otherwise the input plus
            the dropped-out MLP output of the layer-normed input.
        """
        if self.parallel_layers:
            # Residual and dropout are applied once by the enclosing block.
            return self.DenseReluDense(hidden_states)
        forwarded_states = self.DenseReluDense(self.layer_norm(hidden_states))
        return hidden_states + self.dropout(forwarded_states)
|
|
|
|
|
|
|
class T5DecoderOnlyRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) cos/sin tables with a growable cache.

    The tables have shape ``(seq_len, dim)``: frequencies
    ``base ** (-2i / dim)`` are multiplied by every position and duplicated
    along the last axis before taking ``cos`` / ``sin``.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        # Derived from constructor arguments, so it is not saved in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Pre-build the tables for the initial maximum length.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build ``cos_cached`` / ``sin_cached`` for positions ``0..seq_len-1``."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Duplicate so the table spans the full head dimension (matches rotate_half).
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: tensor whose dtype (and device, when the cache must grow) the
                tables are produced in.
            seq_len: number of positions required. Defaults to
                ``x.shape[-2]`` (the sequence axis for inputs laid out as
                ``(batch, heads, seq, head_dim)``). Previously ``None`` caused
                a ``TypeError`` in the comparison below.

        Returns:
            Tuple ``(cos, sin)``, each of shape ``(seq_len, dim)``.
        """
        if seq_len is None:
            seq_len = x.shape[-2]

        # Grow the cache lazily when a longer sequence is requested.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
|
|
|
|
|
def rotate_half(x):
    """Rotate the last dimension of ``x`` by half its width.

    The last axis is split at ``x.shape[-1] // 2`` into ``(a, b)`` and the
    result is ``concat(-b, a)``; for odd widths ``b`` is the longer half.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((back.neg(), front), dim=-1)
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Rotate query and key tensors with rotary position embeddings.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): Cosine table, indexed by position along its first axis.
        sin (`torch.Tensor`): Sine table, indexed like ``cos``.
        position_ids (`torch.Tensor`):
            Positions of the tokens in ``q`` / ``k``; e.g. offset positions
            when decoding with a KV cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into the gathered ``cos[position_ids]`` /
            ``sin[position_ids]`` so they broadcast against ``q`` and ``k``.
            For gathered tables of shape ``[batch, seq, head_dim]``, use 1 when
            ``q``/``k`` are ``[batch, heads, seq, head_dim]`` and 2 when they
            are ``[batch, seq, heads, head_dim]``.

    Returns:
        `tuple(torch.Tensor)`: ``(q_embed, k_embed)``, each computed as
        ``t * cos + rotate_half(t) * sin``.
    """
    cos_at_pos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_at_pos = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(tensor):
        return (tensor * cos_at_pos) + (rotate_half(tensor) * sin_at_pos)

    return _rotate(q), _rotate(k)
|
|
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along dimension 1.

    Same result as ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    ``(batch, num_key_value_heads, seqlen, head_dim)`` becomes
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. With
    ``n_rep == 1`` the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # expand() creates a broadcast view; reshape() then materializes it.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, head_dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
|
class DecoderOnlyT5Attention(modeling_t5.T5Attention):
    """
    Supports both multi-head and multi-query attention.
    https://arxiv.org/abs/1911.02150
    https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/components/attention/dense_attention.py#L292

    Keys/values are projected to ``n_kv_heads`` heads (1 with
    ``config.multi_query_attention``, otherwise ``n_heads``) and repeated to
    ``n_heads`` only for the score computation. Optional rotary embeddings
    are applied to queries and keys.
    """

    def __init__(self, config: DecoderOnlyT5Config, has_relative_attention_bias=False):
        # Bypass T5Attention.__init__: K/V projections here are sized for
        # ``n_kv_heads`` rather than ``n_heads``.
        super(modeling_t5.T5Attention, self).__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.n_kv_heads = 1 if config.multi_query_attention else self.n_heads
        # Number of query heads sharing each key/value head.
        self.n_kv_groups = self.n_heads // self.n_kv_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        self.kv_inner_dim = self.n_kv_heads * self.key_value_proj_dim
        if config.use_rotary_embedding:
            self.rotary_embedding = T5DecoderOnlyRotaryEmbedding(
                self.key_value_proj_dim,
                max_position_embeddings=config.relative_attention_max_distance,
                base=config.rotary_embedding_max_timescale,
            )
        else:
            self.rotary_embedding = None

        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.kv_inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.kv_inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(
                self.relative_attention_num_buckets, self.n_heads
            )
        self.pruned_heads = set()
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        position_ids=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

        Args:
            hidden_states: query-side input; its first two dims are read as
                ``(batch_size, seq_length)``.
            mask: additive attention mask, added to a freshly computed
                position bias.
            position_bias: bias from an earlier layer; computed here when None.
            position_ids: positions for rotary embeddings. For self-attention
                they default to ``past_length .. past_length + seq_length - 1``.
            past_key_value: ``(keys, values)`` cache with ``n_kv_heads`` heads,
                sequence on dim 2.

        Returns:
            Tuple ``(attn_output, present_key_value_state, position_bias)``
            plus ``attn_weights`` when ``output_attentions`` is set.
            ``present_key_value_state`` is None unless ``self.is_decoder`` and
            ``use_cache``; it holds keys/values with ``n_kv_heads`` heads,
            i.e. before they are repeated for multi-query attention.
        """
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            if len(past_key_value) != 2:
                raise ValueError(
                    f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
                )
            real_seq_length += (
                past_key_value[0].shape[2] if query_length is None else query_length
            )

        key_length = (
            real_seq_length if key_value_states is None else key_value_states.shape[1]
        )

        def shape(states, n_heads):
            """(batch, seq, n_heads * d) -> (batch, n_heads, seq, d)"""
            return states.view(
                batch_size, -1, n_heads, self.key_value_proj_dim
            ).transpose(1, 2)

        def unshape(states):
            """(batch, n_heads, seq, d) -> (batch, seq, inner_dim)"""
            return (
                states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
            )

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attention: project the layer input
                hidden_states = shape(proj_layer(hidden_states), self.n_kv_heads)
            elif past_key_value is None:
                # cross-attention without cache: project the encoder states
                hidden_states = shape(proj_layer(key_value_states), self.n_kv_heads)
            # cross-attention with cache: input is returned untouched and
            # replaced by the cached states in concat_past_key_value
            return hidden_states

        def concat_past_key_value(hidden_states, past_key_value, key_value_states):
            if key_value_states is None:
                # self-attention: append new positions to the cache
                hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
            elif past_key_value.shape[2] != key_value_states.shape[1]:
                raise NotImplementedError(
                    "cross attention with RoPE and past KV is not implemented"
                )
            else:
                # cross-attention: reuse the cached encoder projections
                hidden_states = past_key_value
            return hidden_states

        query_states = shape(self.q(hidden_states), self.n_heads)

        key_states = project(hidden_states, self.k, key_value_states, past_key_value)
        value_states = project(hidden_states, self.v, key_value_states, past_key_value)

        if self.rotary_embedding is not None:
            past_length = (
                past_key_value[0].shape[-2] if past_key_value is not None else 0
            )
            kv_seq_len = key_states.shape[-2] + past_length
            if position_ids is None and key_value_states is None:
                # Queries occupy the last ``seq_length`` positions after the cache.
                position_ids = torch.arange(
                    past_length,
                    past_length + seq_length,
                    dtype=torch.long,
                    device=hidden_states.device,
                ).unsqueeze(0)
            cos, sin = self.rotary_embedding(query_states, seq_len=kv_seq_len)
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin, position_ids
            )

        if past_key_value is not None:
            key_states = concat_past_key_value(
                key_states,
                past_key_value[0],
                key_value_states,
            )
            value_states = concat_past_key_value(
                value_states,
                past_key_value[1],
                key_value_states,
            )

        # Cache K/V *before* repeating them to n_heads. The next decoding step
        # concatenates fresh n_kv_heads-head keys onto this cache, which needs
        # matching head counts (previously broke multi-query attention with
        # use_cache and also wasted n_heads/n_kv_heads times the memory).
        present_key_value_state = (
            (key_states, value_states) if (self.is_decoder and use_cache) else None
        )

        key_states = repeat_kv(key_states, self.n_kv_groups)
        value_states = repeat_kv(value_states, self.n_kv_groups)

        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, real_seq_length, key_length),
                    device=scores.device,
                    dtype=scores.dtype,
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(
                    real_seq_length, key_length, device=scores.device
                )

            # With a cache only the rows for the current queries are needed.
            if past_key_value is not None:
                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

            # Fold the additive mask into the bias that is handed to later layers.
            if mask is not None:
                position_bias = position_bias + mask

        if self.pruned_heads:
            mask = torch.ones(position_bias.shape[1])
            mask[list(self.pruned_heads)] = 0
            position_bias_masked = position_bias[:, mask.bool()]
        else:
            position_bias_masked = position_bias

        scores += position_bias_masked
        # Softmax in float32 for stability, then cast back.
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )

        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = unshape(torch.matmul(attn_weights, value_states))
        attn_output = self.o(attn_output)

        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
|
|
|
|
|
class DecoderOnlyT5LayerSelfAttention(modeling_t5.T5LayerSelfAttention):
    """Self-attention sub-layer built on ``DecoderOnlyT5Attention``.

    Sequential mode: ``hidden_states + dropout(attention(layer_norm(hidden_states)))``.
    Parallel mode (``config.parallel_layers``): the input goes to the
    attention unchanged and the bare attention output is returned; the
    caller handles normalization, dropout and the residual.
    """

    def __init__(self, config, has_relative_attention_bias=False):
        # Skip T5LayerSelfAttention.__init__ so the custom attention is used.
        super(modeling_t5.T5LayerSelfAttention, self).__init__()
        self.SelfAttention = DecoderOnlyT5Attention(
            config, has_relative_attention_bias=has_relative_attention_bias
        )
        # Built in both modes; in parallel mode DecoderOnlyT5Block applies it
        # itself via ``self.layer[0].layer_norm``.
        self.layer_norm = modeling_t5.T5LayerNorm(
            config.d_model, eps=config.layer_norm_epsilon
        )
        self.dropout = nn.Dropout(config.dropout_rate)
        self.parallel_layers = config.parallel_layers

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        position_ids=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
    ):
        """Run self-attention on ``hidden_states``.

        Returns:
            The new hidden states followed by every element the attention
            module returned after its output (present key/value state,
            position bias and, optionally, attention weights).
        """
        attn_input = (
            hidden_states if self.parallel_layers else self.layer_norm(hidden_states)
        )
        attn_results = self.SelfAttention(
            attn_input,
            mask=attention_mask,
            position_bias=position_bias,
            position_ids=position_ids,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_out = attn_results[0]
        if self.parallel_layers:
            # Residual and dropout are applied by the enclosing block.
            updated = attn_out
        else:
            updated = hidden_states + self.dropout(attn_out)
        return (updated, *attn_results[1:])
|
|
|
|
|
class DecoderOnlyT5Block(modeling_t5.T5Block):
    """One layer of the decoder-only T5 stack.

    ``self.layer`` holds, in order:
      * index 0: ``DecoderOnlyT5LayerSelfAttention``;
      * index 1 (only when ``config.is_decoder``): an ``nn.Identity``
        placeholder if ``config.is_decoder_only``, otherwise a
        ``T5LayerCrossAttention``;
      * index -1: the ``DecoderOnlyT5LayerFF`` feed-forward sub-layer.

    With ``config.parallel_layers`` the attention and feed-forward branches
    both read ``self.layer[0].layer_norm(hidden_states)``; their outputs are
    summed, scaled by ``2**-0.5`` and added to the residual once. Otherwise
    the sub-layers run one after another.
    """

    def __init__(self, config, has_relative_attention_bias=False):
        # Bypass T5Block.__init__ so the custom sub-layers can be installed.
        super(modeling_t5.T5Block, self).__init__()
        self.is_decoder = config.is_decoder
        self.is_decoder_only = config.is_decoder_only
        self.layer = nn.ModuleList()
        self.layer.append(
            DecoderOnlyT5LayerSelfAttention(
                config, has_relative_attention_bias=has_relative_attention_bias
            )
        )
        if self.is_decoder:
            if config.is_decoder_only:
                # Placeholder: presumably keeps sub-layer indices (and thus
                # parameter names) aligned with a standard T5 decoder block.
                self.layer.append(nn.Identity())
            else:
                self.layer.append(modeling_t5.T5LayerCrossAttention(config))
        self.parallel_layers = config.parallel_layers
        self.layer.append(DecoderOnlyT5LayerFF(config))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        position_ids=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
    ):
        """Apply self-attention, optional cross-attention and feed-forward.

        ``return_dict`` is accepted for signature compatibility but unused.

        Returns:
            ``(hidden_states,)`` followed, when ``use_cache`` is set, by the
            present key/value state, and then by the extra attention outputs:
            the self-attention position bias (plus its attention weights when
            ``output_attentions``), and, only when cross-attention ran, the
            elements the cross-attention layer returned after its first two.
        """
        if past_key_value is not None:
            if not self.is_decoder:
                logger.warning(
                    "`past_key_values` is passed to the encoder. Please make sure this is intended."
                )
            # Self-attention contributes (key, value); cross-attention adds two more.
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4

            if len(past_key_value) != expected_num_past_key_values:
                raise ValueError(
                    f"There should be {expected_num_past_key_values} past states. "
                    f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                    f"Got {len(past_key_value)} past key / value states"
                )

            self_attn_past_key_value = past_key_value[:2]
            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        ff_layer = self.layer[-1]
        if self.parallel_layers:
            # Parallel layers: both branches see the same normalized input,
            # using the self-attention sub-layer's layer norm. The FF branch
            # is expected to return the bare MLP output; the residual is
            # added once at the end of this method.
            x = self.layer[0].layer_norm(hidden_states)
            ff_output = ff_layer(x)
        else:
            x = hidden_states

        self_attention_outputs = self.layer[0](
            x,
            attention_mask=attention_mask,
            position_bias=position_bias,
            position_ids=position_ids,
            layer_head_mask=layer_head_mask,
            past_key_value=self_attn_past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        x, present_key_value_state = self_attention_outputs[:2]
        # Position bias and, optionally, attention weights.
        attention_outputs = self_attention_outputs[
            2:
        ]

        # Clamp inf values so fp16 activations stay finite.
        if x.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(x).any(),
                torch.finfo(x.dtype).max - 1000,
                torch.finfo(x.dtype).max,
            )
            x = torch.clamp(x, min=-clamp_value, max=clamp_value)

        do_cross_attention = (
            self.is_decoder
            and not self.is_decoder_only
            and encoder_hidden_states is not None
        )
        if do_cross_attention:
            # With a self-attention cache, the cross-attention query length
            # is the total (past + current) length of the cached keys.
            if present_key_value_state is not None:
                query_length = present_key_value_state[0].shape[2]
            else:
                query_length = None

            cross_attention_outputs = self.layer[1](
                x,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                query_length=query_length,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            x = cross_attention_outputs[0]

            # Clamp inf values so fp16 activations stay finite.
            if x.dtype == torch.float16:
                clamp_value = torch.where(
                    torch.isinf(x).any(),
                    torch.finfo(x.dtype).max - 1000,
                    torch.finfo(x.dtype).max,
                )
                x = torch.clamp(x, min=-clamp_value, max=clamp_value)

            # Append the cross-attention cache to the self-attention cache.
            if present_key_value_state is not None:
                present_key_value_state = (
                    present_key_value_state + cross_attention_outputs[1]
                )

            # Keep the cross-attention bias / weights for the caller.
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        if self.parallel_layers:
            # Sum of both branches, scaled by 1/sqrt(2), then a single
            # dropout + residual add.
            x = x + ff_output
            x *= 2**-0.5
            hidden_states = hidden_states + self.layer[0].dropout(x)
        else:
            hidden_states = ff_layer(x)

        # Clamp inf values so fp16 activations stay finite.
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(
                hidden_states, min=-clamp_value, max=clamp_value
            )

        outputs = (hidden_states,)

        # The present key/value slot exists only when caching; the stack
        # re-inserts a None placeholder when use_cache is False.
        if use_cache:
            outputs = outputs + (present_key_value_state,) + attention_outputs
        else:
            outputs = outputs + attention_outputs

        return outputs
|
|
|
|
|
class DecoderOnlyT5Stack(modeling_t5.T5Stack):
    """Token embedding + ``DecoderOnlyT5Block`` layers + final layer norm.

    Differs from ``T5Stack`` in that blocks receive ``position_ids`` (used by
    rotary embeddings), only the first block may own a relative-attention-bias
    table (and only if ``config.has_relative_attention_bias``), and the final
    layer norm is an identity when ``config.parallel_layers`` is set.
    """

    def __init__(self, config, embed_tokens=None):
        # Skip T5Stack.__init__, which would build standard T5 blocks.
        super(modeling_t5.T5Stack, self).__init__(config)

        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder

        self.block = nn.ModuleList(
            [
                DecoderOnlyT5Block(
                    config,
                    has_relative_attention_bias=(
                        config.has_relative_attention_bias and bool(i == 0)
                    ),
                )
                for i in range(config.num_layers)
            ]
        )
        if not config.parallel_layers:
            self.final_layer_norm = modeling_t5.T5LayerNorm(
                config.d_model, eps=config.layer_norm_epsilon
            )
        else:
            self.final_layer_norm = nn.Identity()
        self.dropout = nn.Dropout(config.dropout_rate)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Embed the input, run every block, then apply the final norm.

        Args:
            input_ids / inputs_embeds: exactly one must be provided.
            position_ids: positions for rotary embeddings; default
                ``past_length .. past_length + seq_length - 1`` as a
                ``(1, seq_length)`` tensor.
            attention_mask: 2-D padding mask; defaults to ones over the past
                plus current tokens.
            past_key_values: per-layer cached ``(key, value, ...)`` tuples
                whose tensors have the sequence on dim 2, or None.

        Returns:
            ``BaseModelOutputWithPastAndCrossAttentions``, or (when
            ``return_dict`` is false) a tuple of its non-None fields in the
            order ``(last_hidden_state, past_key_values, hidden_states,
            attentions, cross_attentions)``. ``cross_attentions`` is only
            collected when cross-attention actually runs.
        """
        # Model parallel
        if self.model_parallel:
            torch.cuda.set_device(self.first_device)
            self.embed_tokens = self.embed_tokens.to(self.first_device)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
            )

        # Cached prefix length (cache tensors keep the sequence on dim 2).
        past_key_values_length = (
            0 if past_key_values is None else past_key_values[0][0].shape[2]
        )

        if position_ids is None:
            # Derive the length from ``input_shape`` so this also works when
            # only ``inputs_embeds`` is given (``input_ids`` is None then).
            seq_length = input_shape[-1]
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            if self.embed_tokens is None:
                raise ValueError(
                    "You have to initialize the model with valid token embeddings"
                )
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        # The padding mask must cover cached plus current positions.
        mask_seq_length = past_key_values_length + seq_length

        if use_cache is True:
            if not self.is_decoder:
                raise ValueError(
                    f"`use_cache` can only be set to `True` if {self} is used as a decoder"
                )

        if attention_mask is None:
            attention_mask = torch.ones(
                batch_size, mask_seq_length, device=inputs_embeds.device
            )
        if (
            self.is_decoder
            and encoder_attention_mask is None
            and encoder_hidden_states is not None
        ):
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = torch.ones(
                batch_size,
                encoder_seq_length,
                device=inputs_embeds.device,
                dtype=torch.long,
            )

        # One (empty) cache slot per block when no cache is given.
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        # Broadcastable additive mask (causal + padding for decoders).
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, input_shape
        )

        # encoder_attention_mask is guaranteed non-None here (defaulted above).
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_extended_attention_mask = self.invert_attention_mask(
                encoder_attention_mask
            )
        else:
            encoder_extended_attention_mask = None

        # Mirrors DecoderOnlyT5Block's ``do_cross_attention``: only then do the
        # block outputs contain cross-attention entries (indices 3/4/5).
        has_cross_attention = (
            self.is_decoder
            and not self.config.is_decoder_only
            and encoder_hidden_states is not None
        )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Prepare head masks if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(
            cross_attn_head_mask, self.config.num_layers
        )
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = (
            () if (output_attentions and has_cross_attention) else None
        )
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, (layer_module, past_key_value) in enumerate(
            zip(self.block, past_key_values)
        ):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]

            # Model parallel: keep every per-layer input on the layer's device.
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                # This is the mask actually passed to the blocks.
                extended_attention_mask = extended_attention_mask.to(
                    hidden_states.device
                )
                position_ids = position_ids.to(hidden_states.device)
                if position_bias is not None:
                    position_bias = position_bias.to(hidden_states.device)
                if encoder_hidden_states is not None:
                    encoder_hidden_states = encoder_hidden_states.to(
                        hidden_states.device
                    )
                if encoder_extended_attention_mask is not None:
                    encoder_extended_attention_mask = (
                        encoder_extended_attention_mask.to(hidden_states.device)
                    )
                if encoder_decoder_position_bias is not None:
                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(
                        hidden_states.device
                    )
                if layer_head_mask is not None:
                    layer_head_mask = layer_head_mask.to(hidden_states.device)
                if cross_attn_layer_head_mask is not None:
                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(
                        hidden_states.device
                    )
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Arguments are positional and must follow the exact order of
                # DecoderOnlyT5Block.forward (position_ids comes 4th).
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.forward,
                    hidden_states,
                    extended_attention_mask,
                    position_bias,
                    position_ids,
                    encoder_hidden_states,
                    encoder_extended_attention_mask,
                    encoder_decoder_position_bias,
                    layer_head_mask,
                    cross_attn_layer_head_mask,
                    None,  # past_key_value
                    use_cache,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                    position_ids=position_ids,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_extended_attention_mask,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            # Blocks omit the cache slot when not caching; re-insert a None
            # so the indices below are the same in both cases:
            # (hidden, present_kv, self-attn bias, [self-attn weights],
            #  [cross-attn bias], [cross-attn weights])
            if use_cache is False:
                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

            hidden_states, present_key_value_state = layer_outputs[:2]

            # The bias from this layer (mask already folded in) is reused by
            # the next layers instead of being recomputed.
            position_bias = layer_outputs[2]
            if has_cross_attention:
                encoder_decoder_position_bias = layer_outputs[
                    4 if output_attentions else 3
                ]

            if use_cache:
                present_key_value_states = present_key_value_states + (
                    present_key_value_state,
                )

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[3],)
                if has_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

            # Model parallel: hand off to the next device after its last layer.
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add the final hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return modeling_t5.BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
|
|
|
|
|
class DecoderOnlyT5Model(modeling_t5.T5ForConditionalGeneration):
    """Decoder-only T5 language model: ``DecoderOnlyT5Stack`` plus an LM head.

    Reuses ``T5ForConditionalGeneration`` plumbing but has no encoder.
    ``config.num_layers`` (encoder depth) must be 0; the decoder depth comes
    from ``config.num_decoder_layers``.
    """

    def __init__(self, config: DecoderOnlyT5Config):
        # Skip T5ForConditionalGeneration.__init__, which would build an encoder.
        super(modeling_t5.T5ForConditionalGeneration, self).__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        # Explicit check (not ``assert``) so it is not stripped under ``python -O``.
        if self.config.num_layers != 0:
            raise ValueError("Decoder only model cannot have encoder layers")
        self.encoder = None

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = DecoderOnlyT5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def _tie_weights(self):
        """Tie the stack's token embeddings to ``self.shared`` when configured."""
        if not self.config.tie_word_embeddings:
            return
        if self.encoder is not None:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
        if self.decoder is not None:
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

    # The bare "Returns:" line in the docstring is the placeholder that
    # ``replace_return_docstrings`` substitutes with the output documentation.
    @add_start_docstrings_to_model_forward(modeling_t5.T5_INPUTS_DOCSTRING)
    @replace_return_docstrings(
        output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Target token ids for the language-modeling loss, compared position by position with the logits
            (`logits[:, i]` against `labels[:, i]`). Labels are **not** shifted internally, so pass targets that
            are already offset relative to `input_ids`. Indices should be in `[-100, 0, ..., config.vocab_size - 1]`;
            positions set to `-100` are ignored (masked).

        Returns:
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Model parallel: run the decoder starting on its first device.
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            if input_ids is not None:
                input_ids = input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)

        outputs = self.decoder(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            head_mask=None,
            cross_attn_head_mask=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Model parallel: bring the LM head and activations together.
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            self.lm_head = self.lm_head.to(self.decoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # With tied embeddings, rescale by d_model**-0.5 before
            # projecting onto the vocabulary.
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # Move labels to the logits' device (needed for model parallelism).
            labels = labels.to(lm_logits.device)
            # Unshifted, position-aligned loss over every token.
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|