# walsh-1-7b / modelling_walsh.py
# See: https://huggingface.co/docs/transformers/custom_models
from typing import Optional, Tuple, Union
import math
import copy
import sys
from importlib import import_module
import torch
from torch import nn, Tensor
import torch.nn.init as init
from torch.nn import functional as F
from transformers.modeling_outputs import CausalLMOutput
from transformers import (
PreTrainedModel,
PretrainedConfig,
AutoConfig,
AutoModel,
AutoModelForCausalLM,
)
from transformers.utils import (
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
)
if is_flash_attn_2_available():
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
# The model type string to bind.
model_type = "walsh-causal-v1"
class Config(PretrainedConfig):
model_type = model_type
attribute_map = {
"hidden_size": "d_embed",
}
def __init__(
# All of these MUST have defaults, even if unused.
self,
vocab_size=16000,
pad_index=None,
hidden_size=1024,
num_attention_heads=8,
num_hidden_layers=6,
max_sequence_length=2048,
        dim_feedforward=4096,
        dropout=0.1,
        loss_function="causal_loss",
# Default class to use for each of these components.
positional_encoder_cls='.PositionalEncoder',
attention_cls='.CausalSelfAttention',
activation_cls='torch.nn.ReLU',
feedforward_cls='.FeedforwardLayer',
layer_stack_cls='.TransformerLayerStack',
layer_cls='.PostLayerNorm',
transformer_cls='.Transformer',
norm_cls='torch.nn.LayerNorm',
embdding_cls='torch.nn.Embedding',
output_proj_cls='torch.nn.Linear',
        positional_encoder_args={
            'd_embed': 1024,
            'max_seq': 2048,
        },
# Arg groups, passed to factory classes above.
transformer_args=dict(),
attention_args=dict(),
feedforward_args=dict(),
activation_args=dict(),
norm_args={
'normalized_shape': 1024,
},
layer_stack_args=dict(),
layer_args=dict(),
embedding_args=dict(),
output_proj_args=dict(),
**kwargs,
):
self.vocab_size = vocab_size
self.pad_index = pad_index
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_hidden_layers
self.max_sequence_length = max_sequence_length
self.loss_function = loss_function
self.dim_feedforward = dim_feedforward
self.dropout = dropout
self.positional_encoder_cls = positional_encoder_cls
self.attention_cls = attention_cls
self.activation_cls = activation_cls
self.feedforward_cls = feedforward_cls
self.layer_stack_cls = layer_stack_cls
self.layer_cls = layer_cls
self.transformer_cls = transformer_cls
self.norm_cls = norm_cls
self.embdding_cls = embdding_cls
self.output_proj_cls = output_proj_cls
self.positional_encoder_args = positional_encoder_args
self.transformer_args = transformer_args
self.attention_args = attention_args
self.feedforward_args = feedforward_args
self.activation_args = activation_args
self.norm_args = norm_args
self.layer_stack_args = layer_stack_args
self.layer_args = layer_args
self.embedding_args = embedding_args
self.output_proj_args = output_proj_args
super().__init__(**kwargs)
def causal_loss(logits: Tensor, labels: Tensor, input_ids: Tensor, ignore_index=-100) -> Tensor:
"""
Compute and return the loss using logits and labels.
"""
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = torch.nn.functional.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=ignore_index,
reduction='mean',
)
return loss.nan_to_num()
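# Illustrative sketch (not used by the model): a minimal check of the shift-by-one
# convention above. With input tokens [t0, t1, t2, t3], the logits at positions
# 0..2 are scored against labels t1..t3, and the final position has no target.
# The helper name `_causal_loss_example` is hypothetical.
def _causal_loss_example():
    vocab_size, seq_len = 11, 4
    logits = torch.randn(1, seq_len, vocab_size)
    ids = torch.randint(vocab_size, (1, seq_len))
    # Labels are the unshifted input ids; the shift happens inside causal_loss().
    return causal_loss(logits=logits, labels=ids, input_ids=ids)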
# Learning to Break the Loop: Analyzing and Mitigating Repetitions for Neural Text Generation
# https://arxiv.org/abs/2206.02369
def ditto_loss(logits: Tensor, labels: Tensor, input_ids: Tensor) -> Tensor:
batch_size, seq_len, vocab_size = logits.shape
rep_reduce_gamma = 0.5
ditto_weight = 1.0e5
probs = torch.softmax(logits, dim=-1)
total_loss = None
for i in range(batch_size):
context_len = labels[i, 0].item()
sentence_len = labels[i, 1].item()
n_repeats = labels[i, 2].item()
# For readability
context_end = context_len
sentence_start = context_len
sentence_end = sentence_start + sentence_len
target_start = sentence_end
# Get causal loss for context tokens
causal_ids = input_ids[i:i+1, :context_end]
c_loss = causal_loss(
logits=logits[i:i+1, :context_end],
labels=causal_ids,
input_ids=causal_ids
)
# Slice out target probabilities
target_probs = probs[i , target_start:, :]
        # Slice out the first instance of the repeated sentence, detach it (prevents back-prop),
        # repeat it N times, and trim to the length of target_probs.
baseline_probs = probs[i, sentence_start:sentence_end, :].detach().repeat(n_repeats, 1)[:target_probs.size(0), :]
# Compute DITTO loss.
one_minus_probs = torch.clamp((1.0 - torch.abs((target_probs - baseline_probs * rep_reduce_gamma))), min=1e-20)
r_loss = -torch.log(one_minus_probs).mean() * ditto_weight
        # Combine repetition and causal loss
loss = c_loss + r_loss
# Add this to the total
if total_loss is None:
total_loss = loss
else:
total_loss += loss
return total_loss / batch_size
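# Illustrative sketch (an assumption about how a ditto_loss batch might be built;
# the original data pipeline is not part of this file). Each labels row packs
# [context_len, sentence_len, n_repeats], and input_ids holds the context followed
# by the sentence repeated, with the first instance serving as the baseline and the
# remainder as the penalized target region. The helper name is hypothetical.
def _make_ditto_example(context_ids: Tensor, sentence_ids: Tensor, n_repeats: int):
    input_ids = torch.cat([context_ids, sentence_ids.repeat(n_repeats + 1)]).unsqueeze(0)
    labels = torch.tensor([[context_ids.numel(), sentence_ids.numel(), n_repeats]])
    return input_ids, labels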
# Dynamically look up a class by its dotted name and return it for use as a factory.
def get_dynamic_class(name):
try:
module_path, class_name = name.rsplit('.', 1)
if module_path == "":
return getattr(sys.modules[__name__], class_name)
module = import_module(module_path)
return getattr(module, class_name)
except (ImportError, AttributeError) as e:
        raise ImportError(name) from e
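# Illustrative usage (a sketch, not called by the model): names with an empty module
# path (a leading '.') resolve inside this module, while fully qualified names are
# resolved with a normal import.
def _dynamic_class_examples():
    local_cls = get_dynamic_class('.CausalSelfAttention')   # defined later in this module
    external_cls = get_dynamic_class('torch.nn.ReLU')       # imported from torch
    return local_cls, external_cls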
# An easily extensible dynamic transformer class
# Many variations can be specified entirely in the configuration, without touching this code.
class HFCausalModel(PreTrainedModel):
config_class = Config
model_type = 'Transformer'
supports_gradient_checkpointing = True
# Presently needs to be manually set to match transformer layer class...
_no_split_modules = ["DeepNetLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
def __init__(self, config):
super().__init__(config)
self.d_model = config.hidden_size
self.transformer_head = self._make_transformer(config)
self.loss_function = get_dynamic_class(config.loss_function)
self.gradient_checkpointing = False
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
    ) -> CausalLMOutput:
if self.gradient_checkpointing and self.training:
gradient_checkpointing_func = self._gradient_checkpointing_func
else:
gradient_checkpointing_func = None
logits, attentions = self.transformer_head(
input_ids=input_ids,
need_weights=output_attentions,
gradient_checkpointing_func=gradient_checkpointing_func,
)
# Compute loss.
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, input_ids=input_ids)
else:
loss = None
return CausalLMOutput(loss=loss, logits=logits, attentions=attentions)
# Needed for generate() method.
def prepare_inputs_for_generation(self, input_ids, **kwargs):
attention_mask = kwargs.get("attention_mask", None)
model_inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return model_inputs
def _make_embedding(self, config):
embedding_cls = get_dynamic_class(config.embdding_cls)
return embedding_cls(config.vocab_size, self.d_model, config.pad_index, **config.embedding_args)
def _make_pos_encoder(self, config):
pos_enc_cls = get_dynamic_class(config.positional_encoder_cls)
return pos_enc_cls(**config.positional_encoder_args)
def _make_output_projection(self, config):
output_proj_cls = get_dynamic_class(config.output_proj_cls)
return output_proj_cls(self.d_model, config.vocab_size, **config.output_proj_args)
def _make_dropout(self, config):
return nn.Dropout(config.dropout)
def _make_activation(self, config):
activation_cls = get_dynamic_class(config.activation_cls)
return activation_cls(**config.activation_args)
def _make_norm(self, config):
norm_cls = get_dynamic_class(config.norm_cls)
return norm_cls(self.d_model)
def _make_self_attention(self, config):
attention_cls = get_dynamic_class(config.attention_cls)
# Map HF _attn_implementation to attn_type
match config._attn_implementation:
case "flash_attention_2":
if is_flash_attn_2_available():
if not is_flash_attn_greater_or_equal_2_10():
raise Exception("flash_attn_2 >= 2.10 is required")
attn_type = "flash2"
else:
attn_type = "torch"
case "sdpa":
attn_type = "torch"
case "eager":
attn_type = "native"
case _:
raise Exception(f"Unimplemented attention type '{config._attn_implementation}'")
return attention_cls(
d_model=self.d_model,
num_heads=config.num_attention_heads,
attn_type=attn_type,
**config.attention_args,
)
def _make_feedforward(self, config):
feedforward_cls = get_dynamic_class(config.feedforward_cls)
return feedforward_cls(
d_model=self.d_model,
feedforward_dim=config.dim_feedforward,
dropout=config.dropout,
activation=self._make_activation(config),
**config.feedforward_args,
)
def _make_layer(self, config):
layer_cls = get_dynamic_class(config.layer_cls)
return layer_cls(
d_model=self.d_model,
dropout=self._make_dropout(config),
attention=self._make_self_attention(config),
feedforward=self._make_feedforward(config),
norm1=self._make_norm(config),
norm2=self._make_norm(config),
**config.layer_args,
)
def _make_layer_stack(self, config):
layer_stack_cls = get_dynamic_class(config.layer_stack_cls)
return layer_stack_cls(
layers=nn.ModuleList([
self._make_layer(config) for _ in range(config.num_hidden_layers)
]),
**config.layer_stack_args,
)
def _make_transformer(self, config):
transformer_cls = get_dynamic_class(config.transformer_cls)
return transformer_cls(
d_model=self.d_model,
embedding=self._make_embedding(config),
positional_encoder=self._make_pos_encoder(config),
layer_stack=self._make_layer_stack(config),
output_projection=self._make_output_projection(config),
**config.transformer_args,
)
    # Weight initialization is handled by each submodule's reset_parameters(),
    # so this Hugging Face hook is intentionally a no-op.
    @torch.no_grad()
    def _init_weights(self, module):
        pass
# Register model type and configuration
AutoConfig.register(model_type, Config)
AutoModelForCausalLM.register(Config, HFCausalModel)
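# Illustrative usage (a sketch; "path/to/checkpoint" is a placeholder, not a real
# repository): with the registrations above, a checkpoint saved with this code can
# be loaded through the Auto* factories. Loading directly from the Hub additionally
# requires trust_remote_code=True.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint", trust_remote_code=True)
#   tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint", trust_remote_code=True)
#   inputs = tokenizer("Hello", return_tensors="pt")
#   output_ids = model.generate(**inputs, max_new_tokens=20)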
# A generic container class for standard transformer components.
class Transformer(nn.Module):
def __init__(self, d_model, embedding, positional_encoder, layer_stack, output_projection, **kwargs):
super().__init__()
self.embedding = embedding
self.positional_encoder = positional_encoder
self.layer_stack = layer_stack
self.output_projection = output_projection
self.d_model = d_model
self.sqrt_d_model = d_model**0.5
self.reset_parameters()
def forward(self, input_ids, need_weights, gradient_checkpointing_func):
x = self.positional_encoder(self.embedding(input_ids) * self.sqrt_d_model)
x, attentions = self.layer_stack(
x,
need_weights,
gradient_checkpointing_func,
)
        # Translate output embeddings to logits.
logits = self.output_projection(x)
return logits, attentions
def reset_parameters(self):
init.xavier_uniform_(self.output_projection.weight)
init.constant_(self.output_projection.bias, 0.)
init.normal_(self.embedding.weight, std=self.d_model**-0.5)
# A vanilla positional encoder
class PositionalEncoder(nn.Module):
def __init__(self, d_embed, max_seq):
super().__init__()
self.d_embed = d_embed
self.max_seq = max_seq
weight = torch.zeros(max_seq, d_embed)
position = torch.arange(0, max_seq, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_embed, 2).float() * (-math.log(10000.0) / d_embed))
weight[:, 0::2] = torch.sin(position * div_term)
weight[:, 1::2] = torch.cos(position * div_term)
weight = weight.unsqueeze(0)
self.register_buffer('weight', weight)
def forward(self, x):
seq_len = x.size(-2)
return x + self.weight[:, :seq_len]
# Converts a tensor of integers into their equivalent binary codes (least-significant bit first).
def binary_tensor(x, bits):
mask = 2**torch.arange(bits).to(x.device, x.dtype)
return x.unsqueeze(-1).bitwise_and(mask).ne(0).byte()
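# Illustrative example (not executed here): the first four positions encoded with
# 3 bits, least-significant bit first.
#   binary_tensor(torch.arange(4), 3)
#   -> tensor([[0, 0, 0],
#              [1, 0, 0],
#              [0, 1, 0],
#              [1, 1, 0]], dtype=torch.uint8)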
def hadamard_walsh_matrix(k: int):
# k: The dimension of the matrix is 2^k
assert k > 0
# Start with Hadamard H2^1 matrix.
h1 = torch.tensor([[1, 1], [1, -1]], dtype=torch.float)
    # The series of matrices can be computed by recursively applying the Kronecker product,
    # starting with h1.
    #
    # This produces the series of Hadamard-Walsh matrices in natural order.
w = h1
for _ in range(k-1):
w = torch.kron(h1, w)
return w
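# Illustrative example (not executed here): k=2 yields the 4x4 Hadamard matrix in
# natural order, with mutually orthogonal rows.
#   hadamard_walsh_matrix(2)
#   -> tensor([[ 1.,  1.,  1.,  1.],
#              [ 1., -1.,  1., -1.],
#              [ 1.,  1., -1., -1.],
#              [ 1., -1., -1.,  1.]])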
# This positional encoder adds absolute binary positions to the embedding, encoded via
# Hadamard-Walsh matrix.
# See: https://en.wikipedia.org/wiki/Hadamard_code
# Each bit in the binary code word is encoded via a row of the Hadamard-Walsh matrix, with a
# 1 encoded by the presence of the row and a 0 by its absence. While training, the base
# sequence offset is randomly selected, which appears to allow the model to generalize to
# sequences longer than it was trained on. This is similar to what is described here:
# https://arxiv.org/pdf/2305.16843.pdf
# I have tried that approach and found that this one works better for generalization.
#
# Note: Without random shifting, the early performance of this encoder is exceptionally good.
# The drawback is that the model can't generalize to longer sequences than it was trained on
# and can't easily accommodate additional bits later in the training process.
class RSWalshPositionalEncoder(nn.Module):
def __init__(self, d_embed, max_seq, gain=0.333):
super().__init__()
self.max_seq = max_seq
self.d_embed = d_embed
# Hadamard-Walsh k, where the dimension of the matrix is 2^k
k = math.ceil(math.log2(d_embed))
# The number of bits required to encode max_seq
bits = math.ceil(math.log2(max_seq))
# Gain controls the weight given to the encodings.
# When a trainable parameter, the value appears to settle at around 0.333.
self.gain = gain
assert bits <= d_embed, "max_seq exceeds n-bits available for d_embed"
# Generate sequential binary codes for absolute positionals.
        # The implementation originally used Gray codes, in which successive symbols
        # differ by only one bit. See: https://en.wikipedia.org/wiki/Gray_code
        # This, along with a few other coding schemes, was tested, with a simple
        # binary code having the best performance.
binary_code = binary_tensor(torch.arange(0, max_seq, 1), bits)
self.register_buffer('binary_code', binary_code, persistent=False)
# Each bit is encoded via a row of a Hadamard-Walsh matrix.
# We slice off the unused rows and columns -- ideally, d_embed should be
# the same dimension as the matrix.
walsh = hadamard_walsh_matrix(k)[:bits,:d_embed] * self.gain
# This alternative appears superior to the original.
        # If starting from scratch, use this.
# walsh = (hadamard_walsh_matrix(k)[:bits,:d_embed] -0.5) * self.gain
self.register_buffer('walsh', walsh, persistent=False)
def forward(self, x):
seq_len = x.size(-2)
# Get sequence of binary codes...
# We use a random base offset when training.
# This results in slower initial gains, but appears to allow the model to generalize to
# the value of max_seq, even if never trained with sequences of this length. I also have
# a suspicion that this has a regularizing effect on training, similar to dropout. Models with
# random base offset shifting, despite slower initial improvement, appear to perform better in the long-run.
# TODO: Setup a controlled experiment to test this hypothesis.
if self.training:
shift = torch.randint(self.max_seq - seq_len + 1, (1,)).item()
seq = self.binary_code[shift:seq_len + shift,:]
# Disable shifting when not training. This does not appear to change the evaluation loss, but
        # it does make predictions easier to analyze when the attention weights are not shifting with each step.
else:
seq = self.binary_code[:seq_len,:]
# For reasons I have yet to identify, when the model is running in Textgenwebui, the matrix appears
# to evade conversion to bfloat16, despite everything else having been converted.
        # This is a work-around for that issue.
self.walsh = self.walsh.to(dtype=x.dtype)
# Encode binary sequence with Hadamard-Walsh codes and apply to embeddings.
# If nothing else, the Walsh encodings make the positional information exceptionally
# robust with respect to dropout and other adversities. They can still be easily detected
# at the final layer.
return x + (seq.to(dtype=x.dtype) @ self.walsh)
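# Illustrative usage (a sketch, not called by the model): the encoder adds a
# position-dependent Walsh code to each embedding without changing its shape.
def _rswalsh_example():
    encoder = RSWalshPositionalEncoder(d_embed=64, max_seq=256)
    x = torch.zeros(2, 16, 64)   # (batch, seq_len, d_embed)
    return encoder(x).shape      # torch.Size([2, 16, 64])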
# A generic stack of transformer layers.
class TransformerLayerStack(nn.Module):
def __init__(self, layers):
super().__init__()
self.layers = layers
def forward(self, x, need_weights, gradient_checkpointing_func=None):
attentions = []
for layer in self.layers:
if gradient_checkpointing_func is not None:
x, attention_weights = gradient_checkpointing_func(
layer.__call__,
x,
need_weights,
use_reentrant=False
)
else:
x, attention_weights = layer(x, need_weights=need_weights)
if need_weights:
attentions.append(attention_weights)
return x, attentions
# DeepNet: Scaling Transformers to 1,000 Layers
# https://arxiv.org/abs/2203.00555
class DeepnetLayer(nn.Module):
def __init__(
self,
d_model,
attention,
feedforward,
norm1,
norm2,
dropout,
alpha=1.0,
):
super().__init__()
self.d_model = d_model
self.attention = attention
self.feedforward = feedforward
self.norm1 = norm1
self.norm2 = norm2
self.dropout = dropout
# Deepnet alpha
self.alpha = alpha
def forward(self, x, need_weights=False):
        # Scale the input as the residual branch (DeepNet residual scaling).
residual = x * self.alpha
# Compute attention
x, attention_weights = self.attention(x, need_weights)
# Add attention with residual and normalize.
x = self.norm1(residual + self.dropout(x))
        # Scale the output as the next residual branch.
residual = x * self.alpha
# Pass through feedforward network.
x = self.feedforward(x)
# Combine residual and ff output, then normalize again.
x = self.norm2(residual + self.dropout(x))
return x, attention_weights
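# Illustrative helper (a sketch, not called by the model): for a decoder-only model
# with N layers, the DeepNet paper prescribes alpha = (2N)^(1/4) for the residual
# scaling above and beta = (8N)^(-1/4) for the initializer gain used by the attention
# and feedforward layers. The helper name is hypothetical.
def _deepnet_alpha_beta(num_layers: int):
    alpha = (2 * num_layers) ** 0.25
    beta = (8 * num_layers) ** -0.25
    return alpha, beta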
# A vanilla transformer MLP (feedforward) layer.
class FeedforwardLayer(nn.Module):
def __init__(
self,
d_model: int,
feedforward_dim: int,
dropout,
activation=nn.ReLU(),
beta=1.0,
bias=True,
):
super().__init__()
self.d_model = d_model
self.beta = beta
self.activation = activation
self.linear1 = nn.Linear(d_model, feedforward_dim, bias=bias)
self.linear2 = nn.Linear(feedforward_dim, d_model, bias=bias)
self.dropout = nn.Dropout(dropout)
self.reset_parameters()
def forward(self, x):
return self.linear2(self.dropout(self.activation(self.linear1(x))))
def reset_parameters(self):
init.xavier_uniform_(self.linear1.weight, gain=self.beta)
init.xavier_uniform_(self.linear2.weight, gain=self.beta)
init.constant_(self.linear1.bias, 0.)
init.constant_(self.linear2.bias, 0.)
# GLU Variants Improve Transformer
# https://arxiv.org/pdf/2002.05202v1.pdf
class SwiGLUFeedforwardLayer(nn.Module):
def __init__(
self,
d_model,
d_feedforward,
beta=1.0,
dropout=0.1
):
super().__init__()
self.d_model = d_model
self.d_feedforward = d_feedforward
        self.beta = beta
self.linear1 = nn.Linear(self.d_model, self.d_feedforward * 2, bias=False)
self.linear2 = nn.Linear(self.d_feedforward, self.d_model, bias=False)
self.dropout = nn.Dropout(dropout)
self.reset_parameters()
def forward(self, x):
x, gate = self.linear1(x).chunk(2, dim=-1)
x = x * F.silu(gate)
x = self.dropout(x)
x = self.linear2(x)
return x
def reset_parameters(self):
# Deepnet initialization
# https://arxiv.org/pdf/2203.00555.pdf
w, g = self.linear1.weight.chunk(2, dim=0)
init.xavier_uniform_(w, gain=self.beta)
init.xavier_uniform_(g, gain=self.beta)
init.xavier_uniform_(self.linear2.weight, gain=self.beta)
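# Illustrative reference (a sketch, not called by the model): an equivalent, decomposed
# form of the fused forward pass above, showing that the two halves of linear1 act as
# the value and gate projections of SwiGLU(x) = W2(dropout((W1a x) * silu(W1b x))).
def _swiglu_reference(layer: "SwiGLUFeedforwardLayer", x: Tensor) -> Tensor:
    w_value, w_gate = layer.linear1.weight.chunk(2, dim=0)
    value = F.linear(x, w_value)
    gate = F.linear(x, w_gate)
    return layer.linear2(layer.dropout(value * F.silu(gate)))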
class CausalSelfAttention(nn.Module):
def __init__(
self,
d_model,
num_heads,
# values:
        # native: Use local implementation; slowest option; good for debugging; useful when experimenting with non-standard stuff.
        # torch: Use pytorch "scaled_dot_product_attention()"; faster; generally good compatibility; does not support returning attn weights.
        # flash2: Use Flash-Attention2 implementation; fastest; limited to fp16 and bfloat16 dtypes; least memory usage.
attn_type,
beta=1.0,
dropout=0.1,
):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.beta = beta
self.attn_type = attn_type
assert d_model % num_heads == 0, "d_model must be evenly divisible by num_heads"
# The dimension of each head.
self.d_head = d_model // num_heads
        # We scale the attention scores by the inverse square root of the head dimension;
        # this shifts the temperature of the softmax.
self.dot_product_scale = 1.0 / math.sqrt(self.d_head)
self.in_proj = nn.Linear(self.d_model, 3 * self.d_model, bias=True)
self.output_linear = nn.Linear(self.d_model, self.d_model, bias=True)
self.dropout = nn.Dropout(dropout)
self.reset_parameters()
def extra_repr(self) -> str:
return f'd_model={self.d_model}, num_heads={self.num_heads}, beta={self.beta}, attn_type={self.attn_type}, dropout={self.dropout}'
def reset_parameters(self):
# Deepnet initialization
# https://arxiv.org/pdf/2203.00555.pdf
q, k, v = self.in_proj.weight.chunk(3)
init.xavier_uniform_(q, gain=1.0)
init.xavier_uniform_(k, gain=1.0)
init.xavier_uniform_(v, gain=self.beta)
init.xavier_uniform_(self.output_linear.weight, gain=self.beta)
init.constant_(self.in_proj.bias, 0.)
init.constant_(self.output_linear.bias, 0.)
def project_input(self, qkv):
proj = self.in_proj(qkv)
return proj.chunk(chunks=3, dim=-1)
def forward(self, qkv, need_weights):
if self.attn_type == "flash2":
return self.flash2_forward(qkv)
# qkv: (batch_size, seq_len, d_embed)
batch_size, seq_len, d_embed = qkv.shape
# Feed the inputs through the K, Q, V matrices.
query, key, value = self.project_input(qkv)
# Split projections into multiple heads and swap position of sequence / heads dimension
query = query.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
# Default to returning empty attention weights.
attention_weights = None
if self.attn_type == "torch":
# This context manager can be used to force which implementation to use.
#with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
attended_values = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=None,
dropout_p=self.dropout.p if self.training else 0.0,
is_causal=True,
scale=self.dot_product_scale
)
# "native" scaled-dot-product attention implementation.
else:
# Compute attention scores
scores = torch.matmul(query, key.transpose(-2, -1)) * self.dot_product_scale
# Mask future positions from the past
scores.masked_fill_(
torch.tril(
torch.ones(seq_len, seq_len, dtype=torch.bool, device=qkv.device),
diagonal=0,
).logical_not(),
float('-inf'),
)
# Calculate the attention weights; avoid NANs that might emerge from zeros in softmax's denominator
attention_weights = self.dropout(torch.softmax(scores, dim=-1).clamp(min=1e-10))
del scores
# Use the attention weights to get a weighted combination of value vectors
attended_values = torch.matmul(attention_weights, value)
if not need_weights:
del attention_weights
attention_weights = None
# Concatenate attention heads and project to original embedding size using the output linear layer
attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, d_embed)
# Project the concatenated output through the output matrix.
attended_values = self.output_linear(attended_values)
return attended_values, attention_weights
def flash2_forward(self, qkv):
batch_size, seq_len, d_embed = qkv.shape
        # Feed the input through the combined Q, K, V projection.
        # qkv (input): (batch_size, seq_len, d_model)
        # qkv (projected): (batch_size, seq_len, 3, num_heads, d_head)
qkv = self.in_proj(qkv).unflatten(
-1,
(3, self.num_heads, self.d_head)
)
        # Flash attention requires fp16/bf16 inputs; cast in, then cast the result
        # back to the input dtype (as done in CausalAlibiAttention below).
        attended_values = flash_attn_qkvpacked_func(
            qkv.bfloat16(),
            dropout_p=self.dropout.p if self.training else 0.0,
            softmax_scale=self.dot_product_scale,
            causal=True,
        ).to(dtype=qkv.dtype)
# attended_values: (batch_size, seqlen, nheads, headdim)
        # Concatenate heads back into d_embed
attended_values = attended_values.view(batch_size, seq_len, d_embed)
# Project the concatenated output through the output matrix.
attended_values = self.output_linear(attended_values)
return attended_values, None
# Attention layer with ALiBi relative positional encoding
# TRAIN SHORT, TEST LONG: ATTENTION WITH LINEAR BIASES ENABLES INPUT LENGTH EXTRAPOLATION
# https://arxiv.org/pdf/2108.12409.pdf
def alibi_biases(query_len, key_len, device='cpu'):
x = torch.arange(key_len, device=device)[None, :]
y = torch.arange(query_len, device=device)[:, None]
return x - y
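# Illustrative example (not executed here): for a 3-token window, the raw biases are
# the signed key-minus-query distances; entries above the diagonal are masked to -inf
# in forward(), and the negative entries below it penalize attention to distant keys
# once multiplied by a positive slope.
#   alibi_biases(3, 3)
#   -> tensor([[ 0,  1,  2],
#              [-1,  0,  1],
#              [-2, -1,  0]])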
class CausalAlibiAttention(nn.Module):
def __init__(
self,
d_model,
num_heads,
beta=1.0,
dropout=0.1,
# values:
        # native: Use local implementation; slowest option; good for debugging; useful when experimenting with non-standard stuff.
        # torch: Use pytorch "scaled_dot_product_attention()"; faster; generally good compatibility; does not support returning attn weights.
        # flash2: Use Flash-Attention2 implementation; fastest; limited to fp16 and bfloat16 dtypes; can't train ALiBi slopes; least memory usage.
        # Note: You can perform initial training with "torch", then switch to "flash2" after the ALiBi slopes have settled.
window_size=None,
attn_type="native",
freeze_alibi=True,
):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.beta = beta
self.attn_type = attn_type
assert d_model % num_heads == 0, "d_model must be evenly divisible by num_heads"
# The dimension of each head.
self.d_head = d_model // num_heads
        # We scale the attention scores by the inverse square root of the head dimension;
        # this shifts the temperature of the softmax.
self.dot_product_scale = 1.0 / math.sqrt(self.d_head)
self.in_proj = nn.Parameter(torch.empty(3 * self.d_model, self.d_model))
self.output_linear = nn.Linear(self.d_model, self.d_model, bias=False)
if window_size is not None:
self.window_size=(window_size, -1)
else:
self.window_size = (-1, -1)
self.dropout = nn.Dropout(dropout)
# This generates the original slope distribution from the paper.
# Observations with trainable slopes suggest that the high half of the slopes shift
# towards / past 1.0 and the low half approach zero or even go slightly negative.
# alibi_slopes = 1.0 / torch.logspace(1, 8, self.num_heads, base=2, dtype=torch.float)
# These appear to work better, as initial values, in practice.
alibi_slopes = 1.0 / torch.logspace(0, 7, self.num_heads, base=2, dtype=torch.float)
# If not trainable, it can improve performance somewhat if the low half are set to zero. Apparently
# making roughly half of the slopes position-agnostic is somehow closer to optimal?
# alibi_slopes.masked_fill_(torch.where(torch.arange(0, self.num_heads) >= (self.num_heads / 2), True, False), 0)
self.alibi_slopes = nn.Parameter(alibi_slopes)
# Optionally, allow/disallow training of ALiBi slopes.
self.alibi_slopes.requires_grad = (not freeze_alibi)
self.reset_parameters()
def extra_repr(self) -> str:
return f'd_model={self.d_model}, num_heads={self.num_heads}, beta={self.beta}, attn_type={self.attn_type}, window_size={self.window_size}, dropout={self.dropout}'
def reset_parameters(self):
# Deepnet initialization
# https://arxiv.org/pdf/2203.00555.pdf
q, k, v = self.in_proj.chunk(3)
init.xavier_uniform_(q, gain=1.0)
init.xavier_uniform_(k, gain=1.0)
init.xavier_uniform_(v, gain=self.beta)
init.xavier_uniform_(self.output_linear.weight, gain=self.beta)
def project_input(self, qkv):
proj = F.linear(qkv, self.in_proj)
return proj.chunk(chunks=3, dim=-1)
def forward(self, qkv, need_weights):
if self.attn_type == "flash2":
return self.flash2_forward(qkv)
# qkv: (batch_size, seq_len, d_embed)
batch_size, seq_len, d_embed = qkv.shape
# Feed the inputs through the K, Q, V matrices.
query, key, value = self.project_input(qkv)
# Split projections into multiple heads and swap position of sequence / heads dimension
query = query.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
# Apply Alibi relative positional biases.
attn_bias = alibi_biases(seq_len, seq_len, device=query.device) * self.alibi_slopes.view(-1, 1, 1)
# Mask future positions from the past
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=qkv.device), diagonal=0)
attn_bias.masked_fill_(causal_mask.logical_not(), float('-inf'))
del causal_mask
# Default to returning empty attention weights.
attention_weights = None
if self.attn_type == "torch":
# This context manager can be used to force which implementation to use.
#with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
attended_values = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attn_bias.to(dtype=query.dtype),
dropout_p=self.dropout.p if self.training else 0.0,
is_causal=False,
scale=self.dot_product_scale
)
# "native" scaled-dot-product attention implementation.
else:
# Compute attention scores
scores = torch.matmul(query, key.transpose(-2, -1)) * self.dot_product_scale
# Adjust scores with attn_mask
scores += attn_bias
# Calculate the attention weights; avoid NANs that might emerge from zeros in softmax's denominator
attention_weights = self.dropout(torch.softmax(scores, dim=-1).clamp(min=1e-10))
# Use the attention weights to get a weighted combination of value vectors
attended_values = torch.matmul(attention_weights, value)
if not need_weights:
attention_weights = None
# Concatenate attention heads and project to original embedding size using the output linear layer
attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, d_embed)
# Project the concatenated output through the output matrix.
attended_values = self.output_linear(attended_values)
return attended_values, attention_weights
def flash2_forward(self, qkv):
batch_size, seq_len, d_embed = qkv.shape
        # Feed the input through the combined Q, K, V projection.
        # qkv (input): (batch_size, seq_len, d_model)
        # qkv (projected): (batch_size, seq_len, 3, num_heads, d_head)
qkv = F.linear(
qkv,
self.in_proj,
).unflatten(
-1,
(3, self.num_heads, self.d_head)
)
attended_values = flash_attn_qkvpacked_func(
qkv.bfloat16(),
dropout_p=self.dropout.p if self.training else 0.0,
softmax_scale=self.dot_product_scale,
causal=True,
window_size=self.window_size,
alibi_slopes=self.alibi_slopes.float(),
).to(dtype=qkv.dtype)
# attended_values: (batch_size, seqlen, nheads, headdim)
        # Concatenate heads back into d_embed
attended_values = attended_values.view(batch_size, seq_len, d_embed)
# Project the concatenated output through the output matrix.
attended_values = self.output_linear(attended_values)
return attended_values, None