|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Transformer model, with streaming support, xformers attention support
|
and easy causal attention with a potentially finite receptive field. |
|
|
|
See `StreamingTransformer` for more information. |
|
|
|
Unlike regular PyTorch Transformer, we make the hard choice that batches are first. |
|
""" |
|
|
|
import typing as tp |
|
|
|
from einops import rearrange |
|
import torch |
|
import torch.nn as nn |
|
from torch.nn import functional as F |
|
from torch.utils.checkpoint import checkpoint as torch_checkpoint |
|
from xformers import ops |
|
|
|
from .rope import RotaryEmbedding |
|
from .streaming import StreamingModule |
|
|
|
|
|
def _is_profiled() -> bool:
    """Tell whether an xformers profiler is currently running.

    Returns:
        bool: False when `xformers.profiler` cannot be imported, otherwise whether
            the xformers profiler singleton (`_CURRENT_PROFILER`) is set.
    """
    try:
        from xformers.profiler import profiler as xf_profiler
    except ImportError:
        return False
    active_profiler = xf_profiler._Profiler._CURRENT_PROFILER
    return active_profiler is not None
|
|
|
|
|
def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
    """Build the normalization module used by the transformer layers.

    Args:
        norm_type (str): Name of the normalization; only 'layer_norm' is supported.
        dim (int): Size of the normalized (last) dimension.
        **kwargs (dict): Extra keyword arguments forwarded to the normalization
            constructor (e.g. `device`, `dtype`).
    Returns:
        nn.Module: An `nn.LayerNorm` with `eps=1e-5`.
    Raises:
        ValueError: If `norm_type` is not supported.
    """
    if norm_type != 'layer_norm':
        raise ValueError(f"Unknown norm type: {norm_type}")
    return nn.LayerNorm(dim, eps=1e-5, **kwargs)
|
|
|
|
|
def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Create sinusoidal positional embedding, with shape `[B, T, C]`.

    Args:
        positions (torch.Tensor): LongTensor of positions. It is broadcast against a
            `[1, 1, dim // 2]` frequency tensor, so a `[B, T, 1]` shape is expected.
        dim (int): Dimension of the embedding, must be even.
        max_period (float): Maximum period of the cosine/sine functions.
        dtype (torch.dtype): dtype to use to generate the embedding.
    Returns:
        torch.Tensor: Sinusoidal positional embedding; the first `dim // 2` channels hold
            the cosines and the last `dim // 2` channels the sines.
    """
    # We aim for BTC format
    assert dim % 2 == 0, f"dim must be even, got {dim}"
    half_dim = dim // 2
    positions = positions.to(dtype)
    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)
    # Frequency exponents span [0, 1] so periods range from 1 to `max_period`.
    # With dim == 2 there is a single frequency and `half_dim - 1` would be 0, turning the
    # exponent into 0 / 0 = NaN; clamp the divisor so that frequency gets period 1.
    denom = max(half_dim - 1, 1)
    phase = positions / (max_period_tensor ** (adim / denom))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
|
|
|
|
|
def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each head `n_rep` times along dim 2, i.e. `torch.repeat_interleave(x, n_rep, dim=2)`.

    The repetition is done with `expand` + `reshape` (as in xlformers).

    Args:
        x (torch.Tensor): 4D tensor `[B, T, H_kv, D]`.
        n_rep (int): Number of repetitions of every head.
    Returns:
        torch.Tensor: `[B, T, H_kv * n_rep, D]` tensor, or `x` itself when `n_rep == 1`.
    """
    # Unpacking first also validates that `x` is 4D, even when no repetition is needed.
    batch, length, kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    widened = x.unsqueeze(3).expand(batch, length, kv_heads, n_rep, head_dim)
    return widened.reshape(batch, length, kv_heads * n_rep, head_dim)
|
|
|
|
|
class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).

    Multiplies its input by a learnt per-channel (diagonal) scale, initialized to a small
    value so that residual branches start close to 0.

    Args:
        channels (int): Number of channels.
        init (float): Initial value of every scale entry.
        channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
        device (torch.device or None): Device on which to initialize the module.
        dtype (torch.dtype or None): dtype to use to initialize the module.
    """
    def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
                 device=None, dtype=None):
        super().__init__()
        self.channel_last = channel_last
        initial_scale = torch.full((channels,), init,
                                   requires_grad=True, device=device, dtype=dtype)
        self.scale = nn.Parameter(initial_scale)

    def forward(self, x: torch.Tensor):
        # For channel-first inputs, add a trailing axis so the scale broadcasts over time.
        scale = self.scale if self.channel_last else self.scale[:, None]
        return scale * x
|
|
|
|
|
class StreamingMultiheadAttention(StreamingModule):
    """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.

    Args:
        embed_dim (int): Dimension to project to.
        num_heads (int): Number of heads.
        dropout (float): Dropout level.
        bias (bool): Use bias in projections.
        causal (bool): Causal mask applied automatically.
        past_context (int or None): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        rope (`RotaryEmbedding` or None): Rope embedding to use.
        cross_attention: Should be true when used as a cross attention.
            All keys and values must be available at once, streaming is only for the queries.
            Cannot be used with `causal` or `rope` (as it wouldn't make sense to
            interpret the time steps in the keys relative to those in the queries).
        safe_streaming (bool): Bug fix, will go away with xformers update.
        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
        kv_repeat (int): If > 1, keys and values use `num_heads // kv_repeat` heads and are
            repeated `kv_repeat` times before attention (must divide num_heads).
            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
        device (torch.device or None): Device on which to initialize.
        dtype (torch.dtype or None): dtype to use.
    """
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
                 causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False,
                 safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1,
                 device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        if past_context is not None:
            # A finite receptive field is only implemented through the causal mask.
            assert causal

        self.embed_dim = embed_dim
        self.causal = causal
        self.past_context = past_context
        self.memory_efficient = memory_efficient
        self.attention_as_float32 = attention_as_float32
        self.rope = rope
        self.cross_attention = cross_attention
        self.safe_streaming = safe_streaming
        self.num_heads = num_heads
        self.dropout = dropout
        self.kv_repeat = kv_repeat
        if cross_attention:
            assert not causal, "Causal cannot work with cross attention."
            assert rope is None, "Rope cannot work with cross attention."

        if memory_efficient:
            _verify_xformers_memory_efficient_compat()

        self.custom = _is_custom(custom, memory_efficient)
        if self.custom:
            # One fused projection producing q (embed_dim channels) followed by k and v
            # (kv_dim channels each; fewer heads than q when kv_repeat > 1).
            out_dim = embed_dim
            assert num_heads % kv_repeat == 0
            assert not cross_attention or kv_repeat == 1
            num_kv = num_heads // kv_repeat
            kv_dim = (embed_dim // num_heads) * num_kv
            out_dim += 2 * kv_dim
            in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
            # Same attribute names as `nn.MultiheadAttention` (in_proj_weight, in_proj_bias,
            # out_proj).
            self.in_proj_weight = in_proj.weight
            self.in_proj_bias = in_proj.bias
            if bias:
                self.in_proj_bias.data.zero_()
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
            if bias:
                self.out_proj.bias.data.zero_()
        else:
            assert not qk_layer_norm
            assert kv_repeat == 1
            self.mha = nn.MultiheadAttention(
                embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
                **factory_kwargs)
        self.qk_layer_norm = qk_layer_norm
        if qk_layer_norm:
            assert self.custom
            assert kv_repeat == 1
            ln_dim = embed_dim
            # Created on the requested device / dtype, like the projections above.
            self.q_layer_norm = nn.LayerNorm(ln_dim, **factory_kwargs)
            self.k_layer_norm = nn.LayerNorm(ln_dim, **factory_kwargs)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        if not self.custom:
            # Compatibility with state dicts that store the wrapped `nn.MultiheadAttention`
            # parameters directly under `prefix`: move them under `prefix + "mha."`.
            keys = [n for n, _ in self.mha.named_parameters()]
            for key in keys:
                if prefix + key in state_dict:
                    state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype):
        # Return a causal mask, accounting for the past keys stored in the streaming state.
        # Outside of the memory efficient path, the mask is an additive bias on the attention
        # scores (0 where attending is allowed, -inf elsewhere), of shape
        # [current_steps, past_steps + current_steps].
        if self.memory_efficient:
            from xformers.ops import LowerTriangularMask
            if current_steps == 1:
                # A single query step needs no mask.
                return None
            elif 'past_keys' in self._streaming_state:
                raise RuntimeError('Not supported at the moment')
            else:
                # No past keys: a plain lower triangular mask is enough.
                # NOTE(review): `past_context` is not enforced on this path.
                return LowerTriangularMask()
        if self._streaming_state:
            past_keys = self._streaming_state['past_keys']
            past_steps = past_keys.shape[1]
        else:
            past_steps = 0

        queries_pos = torch.arange(
            past_steps, current_steps + past_steps, device=device).view(-1, 1)
        keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
        delta = queries_pos - keys_pos
        valid = delta >= 0
        if self.past_context is not None:
            valid &= (delta <= self.past_context)
        return torch.where(
            valid,
            torch.zeros([], device=device, dtype=dtype),
            torch.full([], float('-inf'), device=device, dtype=dtype))

    def _complete_kv(self, k, v):
        # Prepend the keys / values cached in the streaming state and, when streaming,
        # cache the result (truncated to the last `past_context` steps if set).
        if self.cross_attention:
            # With cross attention, keys and values are given in full at every call
            # (streaming only applies to the queries), so nothing is cached.
            return k, v
        # Complete the key/value pair using the streaming state.
        if self._streaming_state:
            pk = self._streaming_state['past_keys']
            nk = torch.cat([pk, k], dim=1)
            if v is k:
                nv = nk
            else:
                pv = self._streaming_state['past_values']
                nv = torch.cat([pv, v], dim=1)
        else:
            nk = k
            nv = v

        assert nk.shape[1] == nv.shape[1]
        offset = 0
        if self.past_context is not None:
            offset = max(0, nk.shape[1] - self.past_context)
        if self._is_streaming:
            self._streaming_state['past_keys'] = nk[:, offset:]
            if v is not k:
                self._streaming_state['past_values'] = nv[:, offset:]
            # 'offset' accumulates the number of time steps dropped from the cache, so that
            # offset + len(past_keys) is the absolute position of the next step (see _apply_rope).
            if 'offset' in self._streaming_state:
                self._streaming_state['offset'] += offset
            else:
                # Steps may already be dropped on the very first call (first chunk longer
                # than `past_context`); they must be counted too, or rope positions shift.
                self._streaming_state['offset'] = torch.tensor(offset)
        return nk, nv

    def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
        # Rotate queries and keys starting at the absolute position of the current step:
        # number of steps dropped from the cache plus number of steps still cached.
        assert self.rope is not None
        if 'past_keys' in self._streaming_state:
            past_keys_offset = self._streaming_state['past_keys'].shape[1]
        else:
            past_keys_offset = 0
        if 'offset' in self._streaming_state:
            past_context_offset = int(self._streaming_state['offset'].item())
        else:
            past_context_offset = 0
        streaming_offset = past_context_offset + past_keys_offset
        return self.rope.rotate_qk(query, key, start=streaming_offset)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                key_padding_mask=None, need_weights=False, attn_mask=None,
                average_attn_weights=True, is_causal=False):
        """Attend from `query` to `key` / `value` (batch first, `[B, T, C]`).

        Returns:
            tuple: `(output, None)`; attention weights are never returned.
        """
        assert attn_mask is None
        assert not is_causal, ("new param added in torch 2.0.1 not supported, "
                               "use the causal args in the constructor.")

        dtype = query.dtype
        if self._is_streaming:
            assert self.causal or self.cross_attention, \
                "Streaming only available for causal or cross attention"

        if self.causal:
            # Only the self-attention case (same length everywhere) is supported.
            assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
            assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
            attn_mask = self._get_mask(query.shape[1], query.device, query.dtype)

        if self.custom:
            # Custom implementation.
            assert need_weights is False
            assert key_padding_mask is None
            if self.cross_attention:
                # Queries, keys and values come from different tensors, so the fused
                # projection weights are split manually before applying them.
                dim = self.in_proj_weight.shape[0] // 3
                if self.in_proj_bias is None:
                    bias_q, bias_k, bias_v = None, None, None
                else:
                    bias_q = self.in_proj_bias[:dim]
                    bias_k = self.in_proj_bias[dim: 2 * dim]
                    bias_v = self.in_proj_bias[2 * dim:]
                q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
                # NOTE: keys and values are re-projected at every call, even when streaming.
                k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
                v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
                if self.qk_layer_norm is True:
                    q = self.q_layer_norm(q)
                    k = self.k_layer_norm(k)
                q, k, v = [rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in [q, k, v]]
            else:
                if not _is_profiled():
                    # Only checked when no xformers profiler is active, as profiling
                    # seems to break this identity.
                    assert query is key, "specialized implementation"
                    assert value is key, "specialized implementation"
                projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
                if self.kv_repeat == 1:
                    packed = rearrange(projected, "b t (p h d) -> b t p h d", p=3, h=self.num_heads)
                    q, k, v = ops.unbind(packed, dim=2)
                else:
                    # q uses num_heads heads, k and v use num_heads // kv_repeat heads each.
                    embed_dim = self.embed_dim
                    per_head_dim = (embed_dim // self.num_heads)
                    kv_heads = self.num_heads // self.kv_repeat
                    q = projected[:, :, :embed_dim]
                    start = embed_dim
                    end = start + per_head_dim * kv_heads
                    k = projected[:, :, start: end]
                    v = projected[:, :, end:]
                    q = rearrange(q, "b t (h d) -> b t h d", h=self.num_heads)
                    k = rearrange(k, "b t (h d) -> b t h d", h=kv_heads)
                    v = rearrange(v, "b t (h d) -> b t h d", h=kv_heads)

                if self.qk_layer_norm is True:
                    assert self.kv_repeat == 1
                    q, k = [rearrange(x, "b t h d -> b t (h d)") for x in [q, k]]
                    q = self.q_layer_norm(q)
                    k = self.k_layer_norm(k)
                    q, k = [rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in [q, k]]
                if self.rope:
                    q, k = self._apply_rope(q, k)
                k, v = self._complete_kv(k, v)
                if self.kv_repeat > 1:
                    k = expand_repeated_kv(k, self.kv_repeat)
                    v = expand_repeated_kv(v, self.kv_repeat)
            if self.attention_as_float32:
                q, k, v = [x.float() for x in [q, k, v]]
            if self.memory_efficient:
                p = self.dropout if self.training else 0
                x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
            else:
                # Explicit scaled dot-product attention on [B, T, H, D] tensors. When streaming
                # on CUDA with `safe_streaming`, the q.k product runs under float32 autocast.
                q = q / q.shape[-1] ** 0.5
                if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
                    with torch.autocast(device_type=q.device.type, dtype=torch.float32):
                        pre_w = torch.einsum("bqhc,bkhc->bhqk", q, k)
                else:
                    pre_w = torch.einsum("bqhc,bkhc->bhqk", q, k)
                if attn_mask is not None:
                    pre_w = pre_w + attn_mask
                w = torch.softmax(pre_w, dim=-1)
                w = F.dropout(w, self.dropout, training=self.training).to(v)
                x = torch.einsum("bhqk,bkhc->bqhc", w, v)
            x = x.to(dtype)
            x = rearrange(x, "b t h d -> b t (h d)", h=self.num_heads)
            x = self.out_proj(x)
        else:
            key, value = self._complete_kv(key, value)
            if self.attention_as_float32:
                query, key, value = [x.float() for x in [query, key, value]]
            x, _ = self.mha(
                query, key, value, key_padding_mask,
                need_weights, attn_mask, average_attn_weights)
            x = x.to(dtype)

        return x, None
|
|
|
|
|
class StreamingTransformerLayer(nn.TransformerEncoderLayer):
    """TransformerLayer with Streaming / Causal support.
    This also integrates cross_attention, when passing `cross_attention=True`,
    rather than having two separate classes like in PyTorch.

    Args:
        d_model (int): Dimension of the data.
        num_heads (int): Number of heads.
        dim_feedforward (int): Intermediate dimension of FF module.
        dropout (float): Dropout both for MHA and FF.
        bias_ff (bool): Use bias for FF.
        bias_attn (bool): Use bias for MHA.
        causal (bool): Causal mask applied automatically.
        past_context (int or None): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
        qk_layer_norm_cross (bool): Same for the cross attention.
        cross_attention (bool): If True, expect to get secondary input for cross-attention.
            The cross attention is a `StreamingMultiheadAttention` sharing the dropout, bias,
            custom, memory_efficient and attention_as_float32 settings of the self-attention,
            without causal mask, rope or kv_repeat.
        layer_scale (float or None): If not None, LayerScale will be used with
            the given value as initial scale.
        rope (`RotaryEmbedding` or None): Rope embedding to use (self-attention only).
        attention_dropout (float or None): If not None, dropout used inside the attention
            modules instead of `dropout`.
        kv_repeat (int): If > 1, will repeat keys and values multiple times (need to divide num_heads).
            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
        norm (str): Normalization method used for every norm of the layer, see `create_norm_fn`.
        device (torch.device or None): Device on which to initialize.
        dtype (torch.dtype or None): dtype to use.
        **kwargs: See `nn.TransformerEncoderLayer`.
    """
    def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
                 past_context: tp.Optional[int] = None, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
                 rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None,
                 kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
        super().__init__(d_model, num_heads, dim_feedforward, dropout,
                         device=device, dtype=dtype, batch_first=True, **kwargs)
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Replace the parent's self attention with the streaming implementation.
        attn_kwargs: tp.Dict[str, tp.Any] = {
            'embed_dim': d_model,
            'num_heads': num_heads,
            'dropout': dropout if attention_dropout is None else attention_dropout,
            'bias': bias_attn,
            'custom': custom,
            'memory_efficient': memory_efficient,
            'attention_as_float32': attention_as_float32,
        }
        self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
            causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm,
            kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs)  # type: ignore
        # Redefine the feed-forward linears so that `bias_ff` is honored.
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)

        self.layer_scale_1: nn.Module
        self.layer_scale_2: nn.Module
        if layer_scale is None:
            self.layer_scale_1 = nn.Identity()
            self.layer_scale_2 = nn.Identity()
        else:
            self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs)
            self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs)

        self.cross_attention: tp.Optional[nn.Module] = None
        if cross_attention:
            self.cross_attention = StreamingMultiheadAttention(
                cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
                **attn_kwargs, **factory_kwargs)
            # Norm and dropout
            self.dropout_cross = nn.Dropout(dropout)
            # Same normalization factory as norm1 / norm2 (LayerNorm with eps=1e-5).
            self.norm_cross = create_norm_fn(norm, d_model, **factory_kwargs)
            self.layer_scale_cross: nn.Module
            if layer_scale is None:
                self.layer_scale_cross = nn.Identity()
            else:
                self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs)
        self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
        self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore

    def _cross_attention_block(self, src: torch.Tensor,
                               cross_attention_src: torch.Tensor) -> torch.Tensor:
        """Attend from `src` (queries) to `cross_attention_src` (keys and values), then apply dropout."""
        assert self.cross_attention is not None
        # queries are from src, keys and values from cross_attention_src.
        x = self.cross_attention(
            src, cross_attention_src, cross_attention_src, need_weights=False)[0]
        return self.dropout_cross(x)  # type: ignore

    def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None,  # type: ignore
                src_key_padding_mask: tp.Optional[torch.Tensor] = None,
                cross_attention_src: tp.Optional[torch.Tensor] = None):
        """Apply self-attention, optional cross attention and feed-forward sub-blocks.

        `cross_attention_src` must be given if and only if the layer was built with
        `cross_attention=True`.
        """
        if self.cross_attention is None:
            assert cross_attention_src is None
        else:
            assert cross_attention_src is not None
        x = src
        if self.norm_first:
            x = x + self.layer_scale_1(
                self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
            if cross_attention_src is not None:
                x = x + self.layer_scale_cross(
                    self._cross_attention_block(
                        self.norm_cross(x), cross_attention_src))
            x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
        else:
            x = self.norm1(x + self.layer_scale_1(
                self._sa_block(x, src_mask, src_key_padding_mask)))
            if cross_attention_src is not None:
                # Queries come from the self-attention output `x` (as in the norm_first branch),
                # not from the raw layer input `src`.
                # NOTE(review): changes outputs of post-norm checkpoints trained with the old code.
                x = self.norm_cross(
                    x + self.layer_scale_cross(
                        self._cross_attention_block(x, cross_attention_src)))
            x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
        return x
|
|
|
|
|
class StreamingTransformer(StreamingModule):
    """Transformer with Streaming / Causal support.

    Args:
        d_model (int): Dimension of the data.
        num_heads (int): Number of heads.
        num_layers (int): Number of layers.
        dim_feedforward (int): Intermediate dimension of FF module.
        dropout (float): Dropout both for MHA and FF.
        bias_ff (bool): Use bias for FF.
        bias_attn (bool): Use bias for MHA.
        causal (bool): Causal mask applied automatically.
        past_context (int or None): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        cross_attention (bool): If True, expect to get secondary input for cross-attention.
        layer_scale (float or None): If not None, LayerScale will be used
            with the given value as initial scale.
        positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
        max_period (float): Maximum period of the time embedding.
        positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
        xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
        lr (float or None): learning rate override through the `make_optim_group` API.
        weight_decay (float or None): Weight_decay override through the `make_optim_group` API.
        layer_class (subclass of `StreamingTransformerLayer`): class to use
            to initialize the layers, allowing further customization outside of Audiocraft.
        checkpointing (str): Checkpointing strategy to reduce memory usage.
            No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
            if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
            minimal memory usage, but maximal runtime). Finally, `xformers_default` provide
            a policy for opting-out some operations of the checkpointing like
            linear layers and attention, providing a middle ground between speed and memory,
            and `xformers_mm` only opts out the matrix multiplications.
        device (torch.device or None): Device on which to initialize.
        dtype (torch.dtype or None): dtype to use.
        **kwargs: Forwarded to `layer_class` (see `StreamingTransformerLayer`).
    """
    def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
                 dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
                 causal: bool = False, past_context: tp.Optional[int] = None,
                 custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
                 positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
                 xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None,
                 layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
                 checkpointing: str = 'none', device=None, dtype=None, **kwargs):
        super().__init__()
        assert d_model % num_heads == 0

        self.positional_embedding = positional_embedding
        self.max_period = max_period
        self.positional_scale = positional_scale
        self.weight_decay = weight_decay
        self.lr = lr

        assert positional_embedding in ['sin', 'rope', 'sin_rope']
        self.rope: tp.Optional[RotaryEmbedding] = None
        if self.positional_embedding in ['rope', 'sin_rope']:
            # Rope is only applied by the custom attention implementation.
            assert _is_custom(custom, memory_efficient)
            self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period,
                                        xpos=xpos, scale=positional_scale, device=device)

        self.checkpointing = checkpointing

        assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
        if self.checkpointing.startswith('xformers'):
            _verify_xformers_internal_compat()

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                layer_class(
                    d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
                    dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
                    causal=causal, past_context=past_context, custom=custom,
                    memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
                    cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope,
                    device=device, dtype=dtype, **kwargs))

        if self.checkpointing != 'none':
            for layer in self.layers:
                # Marker flag, presumably read by code outside this file that must know
                # the layer is checkpointed.
                layer._magma_checkpointed = True  # type: ignore
                # `layer_drop` is not set by `StreamingTransformerLayer` (nor by
                # `nn.TransformerEncoderLayer`); only custom layer classes may define it.
                assert getattr(layer, 'layer_drop', 0.) == 0., "Need further checking"

    def _apply_layer(self, layer, *args, **kwargs):
        """Run `layer` with the configured checkpointing strategy."""
        method = self.checkpointing
        if method == 'none':
            return layer(*args, **kwargs)
        elif method == 'torch':
            return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
        elif method.startswith('xformers'):
            from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
            if method == 'xformers_default':
                # Operations whose outputs are kept instead of recomputed:
                # attention kernels and matrix multiplications.
                allow_list = [
                    "xformers.efficient_attention_forward_cutlass.default",
                    "xformers_flash.flash_fwd.default",
                    "aten.addmm.default",
                    "aten.mm.default",
                ]
            elif method == 'xformers_mm':
                # Only the matrix multiplications are kept.
                allow_list = [
                    "aten.addmm.default",
                    "aten.mm.default",
                ]
            else:
                raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
            policy_fn = _get_default_policy(allow_list)
            return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
        else:
            raise ValueError(f"Checkpointing method {method} is unknown.")

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Apply positional embedding (for 'sin' variants) and all layers to `x` of shape `[B, T, C]`.

        Extra positional / keyword arguments are forwarded to every layer. When streaming,
        the per-batch-item time offsets are kept in the streaming state.
        """
        B, T, C = x.shape

        if 'offsets' in self._streaming_state:
            offsets = self._streaming_state['offsets']
        else:
            offsets = torch.zeros(B, dtype=torch.long, device=x.device)

        if self.positional_embedding in ['sin', 'sin_rope']:
            positions = torch.arange(T, device=x.device).view(1, -1, 1)
            positions = positions + offsets.view(-1, 1, 1)
            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
            x = x + self.positional_scale * pos_emb

        for layer in self.layers:
            x = self._apply_layer(layer, x, *args, **kwargs)

        if self._is_streaming:
            self._streaming_state['offsets'] = offsets + T

        return x

    def make_optim_group(self):
        """Return an optimizer param group with the optional `lr` / `weight_decay` overrides."""
        group = {"params": list(self.parameters())}
        if self.lr is not None:
            group["lr"] = self.lr
        if self.weight_decay is not None:
            group["weight_decay"] = self.weight_decay
        return group
|
|
|
|
|
|
|
|
|
def _verify_xformers_memory_efficient_compat():
    """Check that xformers memory efficient attention can be imported.

    Raises:
        ImportError: If `memory_efficient_attention` or `LowerTriangularMask` cannot be
            imported from `xformers.ops`; the original import error is chained as the cause.
    """
    try:
        from xformers.ops import memory_efficient_attention, LowerTriangularMask  # noqa: F401
    except ImportError as exc:
        raise ImportError(
            "xformers is not installed. Please install it and try again.\n"
            "To install on AWS and Azure, run \n"
            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
            "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
            "To install on FAIR Cluster, run \n"
            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
            "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n") from exc
|
|
|
|
|
def _verify_xformers_internal_compat():
    """Check that the fairinternal xformers checkpointing utilities can be imported.

    Raises:
        ImportError: If `checkpoint` or `_get_default_policy` cannot be imported from
            `xformers.checkpoint_fairinternal`; the original import error is chained as the cause.
    """
    try:
        from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy  # noqa: F401
    except ImportError as exc:
        raise ImportError(
            "Francisco's fairinternal xformers is not installed. Please install it and try again.\n"
            "To install on AWS and Azure, run \n"
            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
            "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
            "To install on FAIR Cluster, run \n"
            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
            "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n") from exc
|
|
|
|
|
def _is_custom(custom: bool, memory_efficient: bool):
    """Tell whether the custom attention implementation is used.

    The memory efficient (xformers) path lives inside the custom implementation, so
    requesting it implies `custom`. Returns the first truthy argument, else `memory_efficient`.
    """
    if custom:
        return custom
    return memory_efficient
|
|