# Plonk: models/networks/transformers.py
import math
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
from torch import Tensor
from einops import rearrange
from einops._torch_specific import allow_ops_in_compiled_graph  # requires einops>=0.6.1

from models.positional_embeddings import PositionalEmbedding, FourierEmbedding

# Make einops.rearrange traceable by torch.fx / usable inside compiled graphs.
torch.fx.wrap("rearrange")
allow_ops_in_compiled_graph()

class FusedMLP(nn.Sequential):
    """Feed-forward block: Linear -> activation -> Dropout -> Linear."""

    def __init__(
        self,
        dim_model: int,
        dropout: float,
        activation: Type[nn.Module],
        hidden_layer_multiplier: int = 4,
        bias: bool = True,
    ):
super().__init__(
nn.Linear(dim_model, dim_model * hidden_layer_multiplier, bias=bias),
activation(),
nn.Dropout(dropout),
nn.Linear(dim_model * hidden_layer_multiplier, dim_model, bias=bias),
)
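
# Illustrative usage (a sketch, not part of the original module; the dims are assumptions):
#   mlp = FusedMLP(dim_model=256, dropout=0.1, activation=nn.GELU)
#   y = mlp(torch.randn(2, 16, 256))  # -> (2, 16, 256)
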
def _cast_if_autocast_enabled(tensor):
    """Cast ``tensor`` to the ambient autocast dtype if autocast is enabled."""
    if torch.is_autocast_enabled():
        if tensor.device.type == "cuda":
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == "cpu":
            dtype = torch.get_autocast_cpu_dtype()
        else:
            raise NotImplementedError(
                f"Autocast is not supported on device type '{tensor.device.type}'"
            )
        return tensor.to(dtype=dtype)
    return tensor


class LayerNorm16Bits(torch.nn.LayerNorm):
    """
    16-bit friendly version of torch.nn.LayerNorm: under autocast, the normalization
    is computed in the autocast dtype (fp16/bf16) instead of being upcast to fp32.
    """
def __init__(
self,
normalized_shape,
eps=1e-06,
elementwise_affine=True,
device=None,
dtype=None,
):
super().__init__(
normalized_shape=normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
device=device,
dtype=dtype,
)
def forward(self, x):
module_device = x.device
downcast_x = _cast_if_autocast_enabled(x)
downcast_weight = (
_cast_if_autocast_enabled(self.weight)
if self.weight is not None
else self.weight
)
downcast_bias = (
_cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
)
with torch.autocast(enabled=False, device_type=module_device.type):
return nn.functional.layer_norm(
downcast_x,
self.normalized_shape,
downcast_weight,
downcast_bias,
self.eps,
)
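
# Illustrative usage (a sketch, not part of the original module; shapes/devices are assumptions):
#   ln = LayerNorm16Bits(256).cuda()
#   with torch.autocast(device_type="cuda", dtype=torch.float16):
#       y = ln(torch.randn(2, 16, 256, device="cuda"))  # layer norm computed in fp16
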
class StochatichDepth(nn.Module):
    """Stochastic depth (drop path): randomly drops whole samples of a residual branch."""

    def __init__(self, p: float):
        super().__init__()
        self.survival_prob = 1.0 - p

    def forward(self, x: Tensor) -> Tensor:
        if self.training and self.survival_prob < 1:
            # Per-sample Bernoulli keep mask, rescaled so the expectation matches inference.
            mask = (
                torch.empty(x.shape[0], 1, 1, device=x.device).uniform_()
                + self.survival_prob
            )
            mask = mask.floor()
            if self.survival_prob > 0:
                mask = mask / self.survival_prob
            return x * mask
        else:
            return x
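
# Illustrative behaviour (a sketch, not part of the original module):
#   drop = StochatichDepth(p=0.1)
#   drop.train()
#   y = drop(torch.randn(8, 16, 256))  # each sample zeroed w.p. 0.1, survivors scaled by 1/0.9
#   drop.eval()
#   y = drop(torch.randn(8, 16, 256))  # identity at inference
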
class CrossAttentionOp(nn.Module):
    """Multi-head attention supporting both self-attention and cross-attention projections."""

    def __init__(
        self, attention_dim, num_heads, dim_q, dim_kv, use_biases=True, is_sa=False
    ):
        super().__init__()
self.dim_q = dim_q
self.dim_kv = dim_kv
self.attention_dim = attention_dim
self.num_heads = num_heads
self.use_biases = use_biases
self.is_sa = is_sa
if self.is_sa:
self.qkv = nn.Linear(dim_q, attention_dim * 3, bias=use_biases)
else:
self.q = nn.Linear(dim_q, attention_dim, bias=use_biases)
self.kv = nn.Linear(dim_kv, attention_dim * 2, bias=use_biases)
self.out = nn.Linear(attention_dim, dim_q, bias=use_biases)
def forward(self, x_to, x_from=None, attention_mask=None, materialize_sdpa=False):
if x_from is None:
x_from = x_to
if self.is_sa:
q, k, v = self.qkv(x_to).chunk(3, dim=-1)
else:
q = self.q(x_to)
k, v = self.kv(x_from).chunk(2, dim=-1)
q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads)
v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads)
        if attention_mask is not None:
            # Broadcast the (batch, n_to, n_from) mask over the head dimension.
            attention_mask = attention_mask.unsqueeze(1)
if materialize_sdpa:
x = self.materialize_sdpa(q, k, v, attention_mask)
else:
x = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attention_mask
)
x = rearrange(x, "b h n d -> b n (h d)")
x = self.out(x)
return x
    def materialize_sdpa(self, q, k, v, attn_mask=None):
        # Explicitly materializes the attention matrix (used when attention scores
        # need to be retrieved), mirroring scaled_dot_product_attention semantics.
        scale = 1.0 / math.sqrt(q.shape[-1])
        attn_matrix = torch.einsum("b h i d, b h j d -> b h i j", q, k) * scale
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Boolean masks mark the positions that are allowed to attend.
                attn_matrix = attn_matrix.masked_fill(~attn_mask, float("-inf"))
            else:
                attn_matrix = attn_matrix + attn_mask
        attn_matrix = torch.nn.functional.softmax(attn_matrix, dim=-1)
        return torch.einsum("b h i j, b h j d -> b h i d", attn_matrix, v)
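
# Illustrative usage (a sketch, not part of the original module; the dims are assumptions):
#   ca = CrossAttentionOp(attention_dim=256, num_heads=8, dim_q=256, dim_kv=512)
#   out = ca(torch.randn(2, 16, 256), torch.randn(2, 32, 512))  # -> (2, 16, 256)
#   sa = CrossAttentionOp(attention_dim=256, num_heads=8, dim_q=256, dim_kv=256, is_sa=True)
#   out = sa(torch.randn(2, 16, 256))  # self-attention: keys/values come from the queries
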
class CrossAttentionBlock(nn.Module):
    """Pre-norm cross-attention block: cross-attention followed by a feed-forward network."""

    def __init__(
self,
dim_q: int,
dim_kv: int,
num_heads: int,
attention_dim: int = 0,
mlp_multiplier: int = 4,
dropout: float = 0.0,
stochastic_depth: float = 0.0,
use_biases: bool = True,
retrieve_attention_scores: bool = False,
use_16_bits_layer_norm: bool = False,
):
super().__init__()
if use_16_bits_layer_norm and not retrieve_attention_scores:
LayerNorm = LayerNorm16Bits
else:
LayerNorm = nn.LayerNorm
self.retrieve_attention_scores = retrieve_attention_scores
self.initial_to_ln = LayerNorm(dim_q, eps=1e-6)
attention_dim = min(dim_q, dim_kv) if attention_dim == 0 else attention_dim
self.ca = CrossAttentionOp(
attention_dim, num_heads, dim_q, dim_kv, is_sa=False, use_biases=use_biases
)
self.ca_stochastic_depth = StochatichDepth(stochastic_depth)
self.middle_ln = LayerNorm(dim_q, eps=1e-6)
self.ffn = FusedMLP(
dim_model=dim_q,
dropout=dropout,
activation=nn.GELU,
hidden_layer_multiplier=mlp_multiplier,
bias=use_biases,
)
self.ffn_stochastic_depth = StochatichDepth(stochastic_depth)
        # Constant all-True mask used to stand in for a missing token mask.
        self.register_parameter(
            "attention_mask_dummy",
            nn.Parameter(torch.ones(1, 1, dtype=torch.bool), requires_grad=False),
        )
def forward(
self,
to_tokens: Tensor,
from_tokens: Tensor,
to_token_mask: Optional[Tensor] = None,
from_token_mask: Optional[Tensor] = None,
) -> Tensor:
if to_token_mask is None and from_token_mask is None:
attention_mask = None
else:
if to_token_mask is None:
to_token_mask = self.attention_mask_dummy.expand(
to_tokens.shape[0],
to_tokens.shape[1],
)
if from_token_mask is None:
from_token_mask = self.attention_mask_dummy.expand(
from_tokens.shape[0],
from_tokens.shape[1],
)
            # attention_mask[b, i, j] = to_token_mask[b, i] & from_token_mask[b, j]
            attention_mask = from_token_mask.unsqueeze(1) * to_token_mask.unsqueeze(2)
if self.retrieve_attention_scores:
attention_output = self.ca(
self.initial_to_ln(to_tokens),
from_tokens,
attention_mask=attention_mask,
materialize_sdpa=True,
)
else:
attention_output = self.ca(
self.initial_to_ln(to_tokens),
from_tokens,
attention_mask=attention_mask,
)
to_tokens = to_tokens + self.ca_stochastic_depth(attention_output)
to_tokens = to_tokens + self.ffn_stochastic_depth(
self.ffn(self.middle_ln(to_tokens))
)
return to_tokens
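
# Illustrative usage (a sketch, not part of the original module; the dims are assumptions):
#   block = CrossAttentionBlock(dim_q=256, dim_kv=512, num_heads=8)
#   to_tokens = torch.randn(2, 16, 256)
#   from_tokens = torch.randn(2, 32, 512)
#   from_mask = torch.ones(2, 32, dtype=torch.bool)
#   out = block(to_tokens, from_tokens, from_token_mask=from_mask)  # -> (2, 16, 256)
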
class SelfAttentionBlock(nn.Module):
    """Pre-norm self-attention block with optional LayerScale and stochastic depth."""

    def __init__(
self,
dim_qkv: int,
num_heads: int,
attention_dim: int = 0,
mlp_multiplier: int = 4,
dropout: float = 0.0,
stochastic_depth: float = 0.0,
use_biases: bool = True,
use_layer_scale: bool = False,
layer_scale_value: float = 0.1,
retrieve_attention_scores: bool = False,
use_16_bits_layer_norm: bool = False,
):
super().__init__()
if use_16_bits_layer_norm and not retrieve_attention_scores:
LayerNorm = LayerNorm16Bits
else:
LayerNorm = nn.LayerNorm
self.retrieve_attention_scores = retrieve_attention_scores
self.initial_ln = LayerNorm(dim_qkv, eps=1e-6)
attention_dim = dim_qkv if attention_dim == 0 else attention_dim
self.sa = CrossAttentionOp(
attention_dim,
num_heads,
dim_qkv,
dim_qkv,
is_sa=True,
use_biases=use_biases,
)
self.sa_stochastic_depth = StochatichDepth(stochastic_depth)
self.middle_ln = LayerNorm(dim_qkv, eps=1e-6)
self.ffn = FusedMLP(
dim_model=dim_qkv,
dropout=dropout,
activation=nn.GELU,
hidden_layer_multiplier=mlp_multiplier,
bias=use_biases,
)
self.ffn_stochastic_depth = StochatichDepth(stochastic_depth)
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
)
self.layer_scale_2 = nn.Parameter(
torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
)
        # Constant all-True mask used to build the padding mask for the queries.
        self.register_parameter(
            "attention_mask_dummy",
            nn.Parameter(torch.ones(1, 1, dtype=torch.bool), requires_grad=False),
        )
def forward(
self,
tokens: torch.Tensor,
token_mask: Optional[torch.Tensor] = None,
):
if token_mask is None:
attention_mask = None
else:
            # attention_mask[b, i, j] = token_mask[b, j]: padded key/value positions are masked out.
            attention_mask = token_mask.unsqueeze(1) * self.attention_mask_dummy.expand(
                tokens.shape[0],
                tokens.shape[1],
            ).unsqueeze(2)
if self.retrieve_attention_scores:
attention_output = self.sa(
self.initial_ln(tokens),
attention_mask=attention_mask,
materialize_sdpa=True,
)
else:
attention_output = self.sa(
self.initial_ln(tokens),
attention_mask=attention_mask,
)
if self.use_layer_scale:
tokens = tokens + self.sa_stochastic_depth(
self.layer_scale_1 * attention_output
)
tokens = tokens + self.ffn_stochastic_depth(
self.layer_scale_2 * self.ffn(self.middle_ln(tokens))
)
else:
tokens = tokens + self.sa_stochastic_depth(attention_output)
tokens = tokens + self.ffn_stochastic_depth(
self.ffn(self.middle_ln(tokens))
)
return tokens
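
# Illustrative usage (a sketch, not part of the original module; the dims are assumptions):
#   block = SelfAttentionBlock(dim_qkv=256, num_heads=8, use_layer_scale=True)
#   tokens = torch.randn(2, 16, 256)
#   token_mask = torch.ones(2, 16, dtype=torch.bool)
#   out = block(tokens, token_mask)  # -> (2, 16, 256)
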