|
|
|
from typing import Optional, Tuple, Union |
|
import math |
|
import copy |
|
import sys |
|
from importlib import import_module |
|
|
|
import torch |
|
from torch import nn, Tensor |
|
import torch.nn.init as init |
|
from torch.nn import functional as F |
|
from transformers.modeling_outputs import CausalLMOutput |
|
from transformers import ( |
|
PreTrainedModel, |
|
PretrainedConfig, |
|
AutoConfig, |
|
AutoModel, |
|
AutoModelForCausalLM, |
|
) |
|
|
|
from transformers.utils import ( |
|
is_flash_attn_2_available, |
|
is_flash_attn_greater_or_equal_2_10, |
|
) |
|
|
|
if is_flash_attn_2_available(): |
|
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func |
|
|
|
|
|
model_type = "walsh-causal-v1" |
|
|
|
class Config(PretrainedConfig): |
|
model_type = model_type |
|
|
|
attribute_map = { |
|
"hidden_size": "d_embed", |
|
} |
|
|
|
def __init__( |
|
|
|
self, |
|
vocab_size=16000, |
|
pad_index=None, |
|
hidden_size=1024, |
|
num_attention_heads=8, |
|
num_hidden_layers=6, |
|
max_sequence_length=2048, |
|
dim_feedforward = 4096, |
|
dropout=0.1, |
|
loss_function = "causal_loss", |
|
|
|
|
|
positional_encoder_cls='.PositionalEncoder', |
|
attention_cls='.CausalSelfAttention', |
|
activation_cls='torch.nn.ReLU', |
|
feedforward_cls='.FeedforwardLayer', |
|
layer_stack_cls='.TransformerLayerStack', |
|
        layer_cls='.DeepnetLayer',
|
transformer_cls='.Transformer', |
|
norm_cls='torch.nn.LayerNorm', |
|
        embedding_cls='torch.nn.Embedding',
|
output_proj_cls='torch.nn.Linear', |
|
|
|
        # Keyword args for the positional encoder; keys match the PositionalEncoder /
        # RSWalshPositionalEncoder signatures defined below.
        positional_encoder_args={
            'd_embed': 1024,
            'max_seq': 2048,
        },
|
|
|
|
|
transformer_args=dict(), |
|
attention_args=dict(), |
|
feedforward_args=dict(), |
|
activation_args=dict(), |
|
norm_args={ |
|
'normalized_shape': 1024, |
|
}, |
|
layer_stack_args=dict(), |
|
layer_args=dict(), |
|
embedding_args=dict(), |
|
output_proj_args=dict(), |
|
|
|
**kwargs, |
|
): |
|
self.vocab_size = vocab_size |
|
self.pad_index = pad_index |
|
self.hidden_size = hidden_size |
|
self.num_attention_heads = num_attention_heads |
|
self.num_hidden_layers = num_hidden_layers |
|
self.max_sequence_length = max_sequence_length |
|
self.loss_function = loss_function |
|
|
|
self.dim_feedforward = dim_feedforward |
|
self.dropout = dropout |
|
|
|
self.positional_encoder_cls = positional_encoder_cls |
|
self.attention_cls = attention_cls |
|
self.activation_cls = activation_cls |
|
self.feedforward_cls = feedforward_cls |
|
self.layer_stack_cls = layer_stack_cls |
|
self.layer_cls = layer_cls |
|
self.transformer_cls = transformer_cls |
|
self.norm_cls = norm_cls |
|
        self.embedding_cls = embedding_cls
|
self.output_proj_cls = output_proj_cls |
|
|
|
self.positional_encoder_args = positional_encoder_args |
|
self.transformer_args = transformer_args |
|
self.attention_args = attention_args |
|
self.feedforward_args = feedforward_args |
|
self.activation_args = activation_args |
|
self.norm_args = norm_args |
|
self.layer_stack_args = layer_stack_args |
|
self.layer_args = layer_args |
|
self.embedding_args = embedding_args |
|
self.output_proj_args = output_proj_args |
|
|
|
super().__init__(**kwargs) |
|
|
|
def causal_loss(logits: Tensor, labels: Tensor, input_ids: Tensor, ignore_index=-100) -> Tensor:
    """
    Standard next-token cross-entropy: the logit at position i is scored against
    the label at position i + 1. `input_ids` is unused but kept so every loss
    function shares the same signature.
    """
|
|
|
shift_logits = logits[..., :-1, :].contiguous() |
|
shift_labels = labels[..., 1:].contiguous() |
|
|
|
loss = torch.nn.functional.cross_entropy( |
|
shift_logits.view(-1, shift_logits.size(-1)), |
|
shift_labels.view(-1), |
|
ignore_index=ignore_index, |
|
reduction='mean', |
|
) |
|
|
|
return loss.nan_to_num() |
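
# Illustrative sketch (not part of the library): the shift above pairs the logit
# at position i with the token at position i + 1, so a (batch, seq, vocab) batch
# yields seq - 1 scored predictions per row, e.g.
#
#   logits = torch.randn(2, 4, 10)
#   labels = torch.randint(0, 10, (2, 4))
#   loss = causal_loss(logits, labels, input_ids=labels)  # logits[:, :3] vs labels[:, 1:]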
|
|
|
|
|
|
|
def ditto_loss(logits: Tensor, labels: Tensor, input_ids: Tensor) -> Tensor:
    """
    DITTO-style repetition loss combined with the ordinary causal loss.

    Each row of `labels` carries metadata rather than token targets:
    labels[i] = (context_len, sentence_len, n_repeats), describing a context
    followed by a sentence that is repeated n_repeats times in `input_ids`.
    """
    batch_size, seq_len, vocab_size = logits.shape
    rep_reduce_gamma = 0.5
    ditto_weight = 1.0e5
|
|
|
probs = torch.softmax(logits, dim=-1) |
|
total_loss = None |
|
for i in range(batch_size): |
|
context_len = labels[i, 0].item() |
|
sentence_len = labels[i, 1].item() |
|
n_repeats = labels[i, 2].item() |
|
|
|
|
|
context_end = context_len |
|
sentence_start = context_len |
|
sentence_end = sentence_start + sentence_len |
|
target_start = sentence_end |
|
|
|
|
|
        # Ordinary causal loss over the context portion of the sample.
        causal_ids = input_ids[i:i+1, :context_end]
        c_loss = causal_loss(
            logits=logits[i:i+1, :context_end],
            labels=causal_ids,
            input_ids=causal_ids
        )

        # Predicted probabilities over the repeated region.
        target_probs = probs[i, target_start:, :]

        # Baseline: detached probabilities from the first occurrence of the sentence,
        # tiled to cover the repeated region.
        baseline_probs = probs[i, sentence_start:sentence_end, :].detach().repeat(n_repeats, 1)[:target_probs.size(0), :]

        # Push the probabilities on the repeats toward rep_reduce_gamma times the
        # baseline, discouraging self-reinforcing repetition (DITTO objective).
        one_minus_probs = torch.clamp((1.0 - torch.abs((target_probs - baseline_probs * rep_reduce_gamma))), min=1e-20)
        r_loss = -torch.log(one_minus_probs).mean() * ditto_weight
|
|
|
|
|
loss = c_loss + r_loss |
|
|
|
|
|
if total_loss is None: |
|
total_loss = loss |
|
else: |
|
total_loss += loss |
|
|
|
return total_loss / batch_size |
|
|
|
|
|
def get_dynamic_class(name):
    """
    Resolve a dotted class path; a leading '.' (or a bare name with no dot at all)
    refers to this module.
    """
    try:
        if '.' not in name:
            return getattr(sys.modules[__name__], name)
        module_path, class_name = name.rsplit('.', 1)
        if module_path == "":
            return getattr(sys.modules[__name__], class_name)
        module = import_module(module_path)
        return getattr(module, class_name)
    except (ImportError, AttributeError) as e:
        raise ImportError(name) from e
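
# Examples of the naming convention used by the Config defaults (illustrative):
#
#   get_dynamic_class('.CausalSelfAttention')  # leading '.'  -> class defined in this file
#   get_dynamic_class('causal_loss')           # bare name    -> also resolved in this file
#   get_dynamic_class('torch.nn.LayerNorm')    # dotted path  -> imported from another module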
|
|
|
|
|
|
|
class HFCausalModel(PreTrainedModel): |
|
config_class = Config |
|
model_type = 'Transformer' |
|
supports_gradient_checkpointing = True |
|
|
|
    _no_split_modules = ["DeepnetLayer"]
|
_supports_flash_attn_2 = True |
|
_supports_sdpa = True |
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
|
|
self.d_model = config.hidden_size |
|
self.transformer_head = self._make_transformer(config) |
|
self.loss_function = get_dynamic_class(config.loss_function) |
|
self.gradient_checkpointing = False |
|
self.post_init() |
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
token_type_ids: Optional[torch.LongTensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
**kwargs, |
|
    ) -> CausalLMOutput:
|
|
|
if self.gradient_checkpointing and self.training: |
|
gradient_checkpointing_func = self._gradient_checkpointing_func |
|
else: |
|
gradient_checkpointing_func = None |
|
|
|
logits, attentions = self.transformer_head( |
|
input_ids=input_ids, |
|
need_weights=output_attentions, |
|
gradient_checkpointing_func=gradient_checkpointing_func, |
|
) |
|
|
|
|
|
if labels is not None: |
|
loss = self.loss_function(logits=logits, labels=labels, input_ids=input_ids) |
|
else: |
|
loss = None |
|
|
|
return CausalLMOutput(loss=loss, logits=logits, attentions=attentions) |
|
|
|
|
|
def prepare_inputs_for_generation(self, input_ids, **kwargs): |
|
attention_mask = kwargs.get("attention_mask", None) |
|
model_inputs = { |
|
"input_ids": input_ids, |
|
"attention_mask": attention_mask, |
|
} |
|
return model_inputs |
|
|
|
def _make_embedding(self, config): |
|
        embedding_cls = get_dynamic_class(config.embedding_cls)
|
return embedding_cls(config.vocab_size, self.d_model, config.pad_index, **config.embedding_args) |
|
|
|
def _make_pos_encoder(self, config): |
|
pos_enc_cls = get_dynamic_class(config.positional_encoder_cls) |
|
return pos_enc_cls(**config.positional_encoder_args) |
|
|
|
def _make_output_projection(self, config): |
|
output_proj_cls = get_dynamic_class(config.output_proj_cls) |
|
return output_proj_cls(self.d_model, config.vocab_size, **config.output_proj_args) |
|
|
|
def _make_dropout(self, config): |
|
return nn.Dropout(config.dropout) |
|
|
|
def _make_activation(self, config): |
|
activation_cls = get_dynamic_class(config.activation_cls) |
|
return activation_cls(**config.activation_args) |
|
|
|
def _make_norm(self, config): |
|
norm_cls = get_dynamic_class(config.norm_cls) |
|
return norm_cls(self.d_model) |
|
|
|
def _make_self_attention(self, config): |
|
attention_cls = get_dynamic_class(config.attention_cls) |
|
|
|
match config._attn_implementation: |
|
case "flash_attention_2": |
|
if is_flash_attn_2_available(): |
|
                    if not is_flash_attn_greater_or_equal_2_10():
                        raise Exception("flash-attn >= 2.1.0 is required for flash_attention_2")
|
attn_type = "flash2" |
|
else: |
|
attn_type = "torch" |
|
case "sdpa": |
|
attn_type = "torch" |
|
case "eager": |
|
attn_type = "native" |
|
case _: |
|
raise Exception(f"Unimplemented attention type '{config._attn_implementation}'") |
|
return attention_cls( |
|
d_model=self.d_model, |
|
num_heads=config.num_attention_heads, |
|
attn_type=attn_type, |
|
**config.attention_args, |
|
) |
|
|
|
def _make_feedforward(self, config): |
|
feedforward_cls = get_dynamic_class(config.feedforward_cls) |
|
return feedforward_cls( |
|
d_model=self.d_model, |
|
feedforward_dim=config.dim_feedforward, |
|
dropout=config.dropout, |
|
activation=self._make_activation(config), |
|
**config.feedforward_args, |
|
) |
|
|
|
def _make_layer(self, config): |
|
layer_cls = get_dynamic_class(config.layer_cls) |
|
return layer_cls( |
|
d_model=self.d_model, |
|
dropout=self._make_dropout(config), |
|
attention=self._make_self_attention(config), |
|
feedforward=self._make_feedforward(config), |
|
norm1=self._make_norm(config), |
|
norm2=self._make_norm(config), |
|
**config.layer_args, |
|
) |
|
|
|
def _make_layer_stack(self, config): |
|
layer_stack_cls = get_dynamic_class(config.layer_stack_cls) |
|
return layer_stack_cls( |
|
layers=nn.ModuleList([ |
|
self._make_layer(config) for _ in range(config.num_hidden_layers) |
|
]), |
|
**config.layer_stack_args, |
|
) |
|
|
|
def _make_transformer(self, config): |
|
transformer_cls = get_dynamic_class(config.transformer_cls) |
|
return transformer_cls( |
|
d_model=self.d_model, |
|
embedding=self._make_embedding(config), |
|
positional_encoder=self._make_pos_encoder(config), |
|
layer_stack=self._make_layer_stack(config), |
|
output_projection=self._make_output_projection(config), |
|
**config.transformer_args, |
|
) |
|
|
|
@torch.no_grad() |
|
    def _init_weights(self, module):
        # Submodules initialize themselves in their own reset_parameters(); nothing to do here.
        pass
|
|
|
|
|
AutoConfig.register(model_type, Config) |
|
AutoModelForCausalLM.register(Config, HFCausalModel) |
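
# With the registrations above, the Auto classes can build this model directly from
# a Config instance in the current process. When this file ships with a checkpoint
# as custom model code, loading from the Hub typically also needs
# trust_remote_code=True (the repo id below is a hypothetical placeholder):
#
#   model = AutoModelForCausalLM.from_pretrained("your-org/your-walsh-model", trust_remote_code=True)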
|
|
|
|
|
class Transformer(nn.Module): |
|
def __init__(self, d_model, embedding, positional_encoder, layer_stack, output_projection, **kwargs): |
|
super().__init__() |
|
self.embedding = embedding |
|
self.positional_encoder = positional_encoder |
|
self.layer_stack = layer_stack |
|
self.output_projection = output_projection |
|
self.d_model = d_model |
|
self.sqrt_d_model = d_model**0.5 |
|
self.reset_parameters() |
|
|
|
def forward(self, input_ids, need_weights, gradient_checkpointing_func): |
|
x = self.positional_encoder(self.embedding(input_ids) * self.sqrt_d_model) |
|
|
|
x, attentions = self.layer_stack( |
|
x, |
|
need_weights, |
|
gradient_checkpointing_func, |
|
) |
|
|
|
|
|
logits = self.output_projection(x) |
|
return logits, attentions |
|
|
|
def reset_parameters(self): |
|
init.xavier_uniform_(self.output_projection.weight) |
|
init.constant_(self.output_projection.bias, 0.) |
|
init.normal_(self.embedding.weight, std=self.d_model**-0.5) |
|
|
|
|
|
class PositionalEncoder(nn.Module): |
|
def __init__(self, d_embed, max_seq): |
|
super().__init__() |
|
self.d_embed = d_embed |
|
self.max_seq = max_seq |
|
|
|
weight = torch.zeros(max_seq, d_embed) |
|
position = torch.arange(0, max_seq, dtype=torch.float).unsqueeze(1) |
|
div_term = torch.exp(torch.arange(0, d_embed, 2).float() * (-math.log(10000.0) / d_embed)) |
|
weight[:, 0::2] = torch.sin(position * div_term) |
|
weight[:, 1::2] = torch.cos(position * div_term) |
|
weight = weight.unsqueeze(0) |
|
self.register_buffer('weight', weight) |
|
|
|
def forward(self, x): |
|
seq_len = x.size(-2) |
|
return x + self.weight[:, :seq_len] |
|
|
|
|
|
def binary_tensor(x, bits): |
|
mask = 2**torch.arange(bits).to(x.device, x.dtype) |
|
return x.unsqueeze(-1).bitwise_and(mask).ne(0).byte() |
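
# Illustrative: the codes are least-significant-bit first, e.g.
#
#   binary_tensor(torch.tensor([5]), bits=4)  ->  tensor([[1, 0, 1, 0]], dtype=torch.uint8)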
|
|
|
def hadamard_walsh_matrix(k: int):
    # Build the 2^k x 2^k Hadamard-Walsh matrix as the k-fold Kronecker product of
    # the 2 x 2 base matrix H1.
    assert k > 0

    h1 = torch.tensor([[1, 1], [1, -1]], dtype=torch.float)

    w = h1
    for _ in range(k-1):
        w = torch.kron(h1, w)
|
|
|
return w |
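
# Illustrative: for k = 2 the recursion yields the 4 x 4 matrix
#
#   hadamard_walsh_matrix(2) -> [[ 1,  1,  1,  1],
#                                [ 1, -1,  1, -1],
#                                [ 1,  1, -1, -1],
#                                [ 1, -1, -1,  1]]
#
# whose rows are mutually orthogonal. RSWalshPositionalEncoder below adds a
# subset of these rows (selected by the bits of the position index) to the input.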
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RSWalshPositionalEncoder(nn.Module): |
|
def __init__(self, d_embed, max_seq, gain=0.333): |
|
super().__init__() |
|
self.max_seq = max_seq |
|
self.d_embed = d_embed |
|
|
|
|
|
        # Size of the Hadamard-Walsh matrix: the smallest power of two >= d_embed.
        k = math.ceil(math.log2(d_embed))

        # Number of binary digits needed to encode every position in [0, max_seq).
        bits = math.ceil(math.log2(max_seq))

        # Scale applied to the Walsh rows before they are added to the embeddings.
        self.gain = gain

        assert bits <= d_embed, "max_seq exceeds n-bits available for d_embed"

        # Binary code of each absolute position, shape (max_seq, bits), LSB first.
        binary_code = binary_tensor(torch.arange(0, max_seq, 1), bits)
        self.register_buffer('binary_code', binary_code, persistent=False)

        # One Walsh row per bit, truncated to d_embed columns; a position's encoding
        # is the sum of the rows selected by its set bits.
        walsh = hadamard_walsh_matrix(k)[:bits, :d_embed] * self.gain
        self.register_buffer('walsh', walsh, persistent=False)
|
|
|
    def forward(self, x):
        seq_len = x.size(-2)

        # During training, start the position codes at a random offset so the model is
        # exposed to every absolute position regardless of the sampled sequence length.
        if self.training:
            shift = torch.randint(self.max_seq - seq_len + 1, (1,)).item()
            seq = self.binary_code[shift:seq_len + shift, :]
        # At inference time, always start at position zero.
        else:
            seq = self.binary_code[:seq_len, :]

        # Keep the Walsh buffer in the input's dtype (e.g. under mixed precision).
        self.walsh = self.walsh.to(dtype=x.dtype)

        # (seq_len, bits) @ (bits, d_embed) -> (seq_len, d_embed), broadcast over the batch.
        return x + (seq.to(dtype=x.dtype) @ self.walsh)
|
|
|
|
|
class TransformerLayerStack(nn.Module): |
|
def __init__(self, layers): |
|
super().__init__() |
|
self.layers = layers |
|
|
|
def forward(self, x, need_weights, gradient_checkpointing_func=None): |
|
attentions = [] |
|
for layer in self.layers: |
|
if gradient_checkpointing_func is not None: |
|
x, attention_weights = gradient_checkpointing_func( |
|
layer.__call__, |
|
x, |
|
need_weights, |
|
use_reentrant=False |
|
) |
|
else: |
|
x, attention_weights = layer(x, need_weights=need_weights) |
|
if need_weights: |
|
attentions.append(attention_weights) |
|
|
|
return x, attentions |
|
|
|
|
|
|
|
class DeepnetLayer(nn.Module): |
|
def __init__( |
|
self, |
|
d_model, |
|
attention, |
|
feedforward, |
|
norm1, |
|
norm2, |
|
dropout, |
|
alpha=1.0, |
|
): |
|
super().__init__() |
|
self.d_model = d_model |
|
self.attention = attention |
|
self.feedforward = feedforward |
|
self.norm1 = norm1 |
|
self.norm2 = norm2 |
|
self.dropout = dropout |
|
|
|
self.alpha = alpha |
|
|
|
    def forward(self, x, need_weights=False):
        # DeepNet-style post-layer-norm block: the skip connection is scaled by alpha
        # before being added back (alpha=1.0 reduces to a standard post-LN layer).
        residual = x * self.alpha

        x, attention_weights = self.attention(x, need_weights)

        x = self.norm1(residual + self.dropout(x))

        residual = x * self.alpha

        x = self.feedforward(x)

        x = self.norm2(residual + self.dropout(x))

        return x, attention_weights
|
|
|
|
|
class FeedforwardLayer(nn.Module): |
|
def __init__( |
|
self, |
|
d_model: int, |
|
feedforward_dim: int, |
|
dropout, |
|
activation=nn.ReLU(), |
|
beta=1.0, |
|
bias=True, |
|
): |
|
super().__init__() |
|
self.d_model = d_model |
|
self.beta = beta |
|
self.activation = activation |
|
self.linear1 = nn.Linear(d_model, feedforward_dim, bias=bias) |
|
self.linear2 = nn.Linear(feedforward_dim, d_model, bias=bias) |
|
self.dropout = nn.Dropout(dropout) |
|
self.reset_parameters() |
|
|
|
def forward(self, x): |
|
return self.linear2(self.dropout(self.activation(self.linear1(x)))) |
|
|
|
    def reset_parameters(self):
        init.xavier_uniform_(self.linear1.weight, gain=self.beta)
        init.xavier_uniform_(self.linear2.weight, gain=self.beta)
        if self.linear1.bias is not None:
            init.constant_(self.linear1.bias, 0.)
        if self.linear2.bias is not None:
            init.constant_(self.linear2.bias, 0.)
|
|
|
|
|
|
|
class SwiGLUFeedforwardLayer(nn.Module): |
|
def __init__( |
|
self, |
|
d_model, |
|
d_feedforward, |
|
beta=1.0, |
|
dropout=0.1 |
|
): |
|
super().__init__() |
|
self.d_model = d_model |
|
self.d_feedforward = d_feedforward |
|
        self.beta = beta
|
|
|
self.linear1 = nn.Linear(self.d_model, self.d_feedforward * 2, bias=False) |
|
self.linear2 = nn.Linear(self.d_feedforward, self.d_model, bias=False) |
|
self.dropout = nn.Dropout(dropout) |
|
self.reset_parameters() |
|
|
|
    def forward(self, x):
        # SwiGLU: split the up-projection into a value and a gate, apply SiLU to the
        # gate, and modulate the value with it before projecting back down.
        x, gate = self.linear1(x).chunk(2, dim=-1)
        x = x * F.silu(gate)
        x = self.dropout(x)
        x = self.linear2(x)
        return x
|
|
|
def reset_parameters(self): |
|
|
|
|
|
w, g = self.linear1.weight.chunk(2, dim=0) |
|
init.xavier_uniform_(w, gain=self.beta) |
|
init.xavier_uniform_(g, gain=self.beta) |
|
init.xavier_uniform_(self.linear2.weight, gain=self.beta) |
|
|
|
class CausalSelfAttention(nn.Module): |
|
def __init__( |
|
self, |
|
d_model, |
|
num_heads, |
|
|
|
|
|
|
|
|
|
attn_type, |
|
beta=1.0, |
|
dropout=0.1, |
|
): |
|
super().__init__() |
|
self.d_model = d_model |
|
self.num_heads = num_heads |
|
self.beta = beta |
|
self.attn_type = attn_type |
|
|
|
assert d_model % num_heads == 0, "d_model must be evenly divisible by num_heads" |
|
|
|
|
|
self.d_head = d_model // num_heads |
|
|
|
|
|
|
|
self.dot_product_scale = 1.0 / math.sqrt(self.d_head) |
|
|
|
self.in_proj = nn.Linear(self.d_model, 3 * self.d_model, bias=True) |
|
self.output_linear = nn.Linear(self.d_model, self.d_model, bias=True) |
|
|
|
self.dropout = nn.Dropout(dropout) |
|
self.reset_parameters() |
|
|
|
def extra_repr(self) -> str: |
|
return f'd_model={self.d_model}, num_heads={self.num_heads}, beta={self.beta}, attn_type={self.attn_type}, dropout={self.dropout}' |
|
|
|
    def reset_parameters(self):
        # Query and key projections use standard Xavier init; the value and output
        # projections are scaled by beta (DeepNet-style initialization).
        q, k, v = self.in_proj.weight.chunk(3)
|
init.xavier_uniform_(q, gain=1.0) |
|
init.xavier_uniform_(k, gain=1.0) |
|
init.xavier_uniform_(v, gain=self.beta) |
|
init.xavier_uniform_(self.output_linear.weight, gain=self.beta) |
|
init.constant_(self.in_proj.bias, 0.) |
|
init.constant_(self.output_linear.bias, 0.) |
|
|
|
def project_input(self, qkv): |
|
proj = self.in_proj(qkv) |
|
return proj.chunk(chunks=3, dim=-1) |
|
|
|
def forward(self, qkv, need_weights): |
|
if self.attn_type == "flash2": |
|
return self.flash2_forward(qkv) |
|
|
|
|
|
batch_size, seq_len, d_embed = qkv.shape |
|
|
|
|
|
query, key, value = self.project_input(qkv) |
|
|
|
|
|
query = query.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2) |
|
key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2) |
|
value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2) |
|
|
|
|
|
attention_weights = None |
|
|
|
if self.attn_type == "torch": |
|
|
|
|
|
attended_values = F.scaled_dot_product_attention( |
|
query, |
|
key, |
|
value, |
|
attn_mask=None, |
|
dropout_p=self.dropout.p if self.training else 0.0, |
|
is_causal=True, |
|
scale=self.dot_product_scale |
|
) |
|
|
|
else: |
|
|
|
            # Explicit ("native") implementation: raw scores, causal mask, softmax.
            scores = torch.matmul(query, key.transpose(-2, -1)) * self.dot_product_scale

            # Mask out future positions (the strict upper triangle) before the softmax.
            scores.masked_fill_(
                torch.tril(
                    torch.ones(seq_len, seq_len, dtype=torch.bool, device=qkv.device),
                    diagonal=0,
                ).logical_not(),
                float('-inf'),
            )

            # Clamp avoids exact zeros in the attention weights for numerical stability.
            attention_weights = self.dropout(torch.softmax(scores, dim=-1).clamp(min=1e-10))
|
del scores |
|
|
|
|
|
attended_values = torch.matmul(attention_weights, value) |
|
if not need_weights: |
|
del attention_weights |
|
attention_weights = None |
|
|
|
|
|
attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, d_embed) |
|
|
|
|
|
attended_values = self.output_linear(attended_values) |
|
return attended_values, attention_weights |
|
|
|
def flash2_forward(self, qkv): |
|
batch_size, seq_len, d_embed = qkv.shape |
|
|
|
|
|
|
|
|
|
qkv = self.in_proj(qkv).unflatten( |
|
-1, |
|
(3, self.num_heads, self.d_head) |
|
) |
|
|
|
        attended_values = flash_attn_qkvpacked_func(
            qkv.bfloat16(),
            dropout_p=self.dropout.p if self.training else 0.0,
            softmax_scale=self.dot_product_scale,
            causal=True,
        ).to(dtype=qkv.dtype)
|
|
|
|
|
|
|
attended_values = attended_values.view(batch_size, seq_len, d_embed) |
|
|
|
|
|
attended_values = self.output_linear(attended_values) |
|
return attended_values, None |
|
|
|
|
|
|
|
|
|
def alibi_biases(query_len, key_len, device='cpu'): |
|
x = torch.arange(key_len, device=device)[None, :] |
|
y = torch.arange(query_len, device=device)[:, None] |
|
return x - y |
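
# Illustrative: alibi_biases(3, 3) is the key-minus-query offset matrix
#
#   [[ 0,  1,  2],
#    [-1,  0,  1],
#    [-2, -1,  0]]
#
# Scaled by a per-head slope, entries below the diagonal penalize attention to
# distant past tokens; entries above the diagonal are removed by the causal mask
# in CausalAlibiAttention.forward.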
|
|
|
class CausalAlibiAttention(nn.Module): |
|
def __init__( |
|
self, |
|
d_model, |
|
num_heads, |
|
beta=1.0, |
|
dropout=0.1, |
|
|
|
|
|
|
|
|
|
|
|
window_size=None, |
|
attn_type="native", |
|
freeze_alibi=True, |
|
): |
|
super().__init__() |
|
self.d_model = d_model |
|
self.num_heads = num_heads |
|
self.beta = beta |
|
self.attn_type = attn_type |
|
|
|
assert d_model % num_heads == 0, "d_model must be evenly divisible by num_heads" |
|
|
|
|
|
self.d_head = d_model // num_heads |
|
|
|
|
|
|
|
self.dot_product_scale = 1.0 / math.sqrt(self.d_head) |
|
|
|
self.in_proj = nn.Parameter(torch.empty(3 * self.d_model, self.d_model)) |
|
self.output_linear = nn.Linear(self.d_model, self.d_model, bias=False) |
|
|
|
if window_size is not None: |
|
self.window_size=(window_size, -1) |
|
else: |
|
self.window_size = (-1, -1) |
|
|
|
self.dropout = nn.Dropout(dropout) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # ALiBi slopes: a geometric sequence from 2^0 down to 2^-7, one slope per head,
        # in the spirit of Press et al. ("Train Short, Test Long").
        alibi_slopes = 1.0 / torch.logspace(0, 7, self.num_heads, base=2, dtype=torch.float)

        # Stored as a parameter so the slopes can optionally be learned; frozen by default.
        self.alibi_slopes = nn.Parameter(alibi_slopes)

        self.alibi_slopes.requires_grad = (not freeze_alibi)
|
self.reset_parameters() |
|
|
|
def extra_repr(self) -> str: |
|
return f'd_model={self.d_model}, num_heads={self.num_heads}, beta={self.beta}, attn_type={self.attn_type}, window_size={self.window_size}, dropout={self.dropout}' |
|
|
|
def reset_parameters(self): |
|
|
|
|
|
|
|
q, k, v = self.in_proj.chunk(3) |
|
init.xavier_uniform_(q, gain=1.0) |
|
init.xavier_uniform_(k, gain=1.0) |
|
init.xavier_uniform_(v, gain=self.beta) |
|
init.xavier_uniform_(self.output_linear.weight, gain=self.beta) |
|
|
|
def project_input(self, qkv): |
|
proj = F.linear(qkv, self.in_proj) |
|
return proj.chunk(chunks=3, dim=-1) |
|
|
|
def forward(self, qkv, need_weights): |
|
if self.attn_type == "flash2": |
|
return self.flash2_forward(qkv) |
|
|
|
|
|
batch_size, seq_len, d_embed = qkv.shape |
|
|
|
|
|
query, key, value = self.project_input(qkv) |
|
|
|
|
|
query = query.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2) |
|
key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2) |
|
value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2) |
|
|
|
|
|
attn_bias = alibi_biases(seq_len, seq_len, device=query.device) * self.alibi_slopes.view(-1, 1, 1) |
|
|
|
|
|
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=qkv.device), diagonal=0) |
|
attn_bias.masked_fill_(causal_mask.logical_not(), float('-inf')) |
|
del causal_mask |
|
|
|
|
|
attention_weights = None |
|
|
|
if self.attn_type == "torch": |
|
|
|
|
|
attended_values = F.scaled_dot_product_attention( |
|
query, |
|
key, |
|
value, |
|
attn_mask=attn_bias.to(dtype=query.dtype), |
|
dropout_p=self.dropout.p if self.training else 0.0, |
|
is_causal=False, |
|
scale=self.dot_product_scale |
|
) |
|
|
|
else: |
|
|
|
scores = torch.matmul(query, key.transpose(-2, -1)) * self.dot_product_scale |
|
|
|
|
|
scores += attn_bias |
|
|
|
|
|
attention_weights = self.dropout(torch.softmax(scores, dim=-1).clamp(min=1e-10)) |
|
|
|
|
|
attended_values = torch.matmul(attention_weights, value) |
|
if not need_weights: |
|
attention_weights = None |
|
|
|
|
|
attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, d_embed) |
|
|
|
|
|
attended_values = self.output_linear(attended_values) |
|
return attended_values, attention_weights |
|
|
|
def flash2_forward(self, qkv): |
|
batch_size, seq_len, d_embed = qkv.shape |
|
|
|
|
|
|
|
|
|
qkv = F.linear( |
|
qkv, |
|
self.in_proj, |
|
).unflatten( |
|
-1, |
|
(3, self.num_heads, self.d_head) |
|
) |
|
|
|
attended_values = flash_attn_qkvpacked_func( |
|
qkv.bfloat16(), |
|
dropout_p=self.dropout.p if self.training else 0.0, |
|
softmax_scale=self.dot_product_scale, |
|
causal=True, |
|
window_size=self.window_size, |
|
alibi_slopes=self.alibi_slopes.float(), |
|
).to(dtype=qkv.dtype) |
|
|
|
|
|
|
|
attended_values = attended_values.view(batch_size, seq_len, d_embed) |
|
|
|
|
|
attended_values = self.output_linear(attended_values) |
|
return attended_values, None |
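
# Minimal smoke test (illustrative only, not part of the library): builds a small
# model from the defaults above and runs one forward pass on CPU. It assumes the
# default component paths resolve to the classes defined in this file.
if __name__ == "__main__":
    config = Config(
        vocab_size=100,
        hidden_size=64,
        num_attention_heads=4,
        num_hidden_layers=2,
        dim_feedforward=128,
        max_sequence_length=128,
        positional_encoder_args={'d_embed': 64, 'max_seq': 128},
        norm_args={'normalized_shape': 64},
    )
    model = HFCausalModel(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    output = model(input_ids=input_ids, labels=input_ids)
    print(output.logits.shape, output.loss.item())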