# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from torch import distributed as dist
from torch import nn
from torch.nn import functional as F
from xformers.ops import RMSNorm, fmha, rope_padded
from xformers.ops.fmha.attn_bias import (
    BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
)
@dataclass
class ModelArgs:
    model_parallel_size: int = 1
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    n_kv_heads: int | None = None
    vocab_size: int = -1
    ffn_dim_multiplier: float | None = None
    multiple_of: int = 256
    norm_eps: float = 1e-5
    rope_theta: float = 10000.0
    qk_normalization: bool = False
    swin_norm: bool = False


LayerCache = tuple[torch.Tensor, torch.Tensor]
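# Each LayerCache pairs one layer's key cache and value cache; make_cache()
# below allocates them with shape (1, cache_length, n_local_kv_heads, head_dim).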


class Attention(nn.Module):
    def __init__(
        self,
        model_parallel_size: int,
        dim: int,
        head_dim: int,
        n_heads: int,
        n_kv_heads: int,
        rope_theta: float,
        qk_normalization: bool = False,
    ):
        super().__init__()

        self.model_parallel_size = model_parallel_size
        self.head_dim = head_dim
        self.rope_theta = rope_theta

        self.n_local_heads = n_heads // model_parallel_size
        self.n_local_kv_heads = n_kv_heads // model_parallel_size

        self.wqkv = nn.Linear(
            dim,
            (self.n_local_heads + 2 * self.n_local_kv_heads) * head_dim,
            bias=False,
            dtype=torch.bfloat16,
        )
        self.wo = nn.Linear(
            self.n_local_heads * head_dim,
            dim,
            bias=False,
            dtype=torch.bfloat16,
        )

        self.qk_normalization = qk_normalization
        if qk_normalization:
            self.q_normalization = torch.nn.LayerNorm(head_dim)
            self.k_normalization = torch.nn.LayerNorm(head_dim)

        self._register_load_state_dict_pre_hook(self.load_hook)

    # This adapter makes sure we can load vanilla
    # Llama checkpoints where wq, wk, and wv are
    # not fused in a single parameter
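    # (e.g. "<prefix>wq.weight", "<prefix>wk.weight" and "<prefix>wv.weight"
    # are popped and concatenated row-wise into a single "<prefix>wqkv.weight").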
    def load_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: torch.Tensor,
        cache: LayerCache,
        attn_bias: AttnBias,
        group: dist.ProcessGroup | None = None,
    ) -> torch.Tensor:
        # x.shape is (sum(seq_lens), dim)
        #
        # Since we support heterogeneous sequence
        # lengths, the hidden states are all
        # concatenated together along the usual
        # sequence dimension. The attention below
        # finds out where sequences start & end
        # using the provided attention bias.
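        # For example, two prompts of lengths 3 and 5 arrive here as a single
        # (8, dim) tensor; attn_bias (typically built from the per-sequence
        # lengths in Transformer.forward) keeps each token attending only
        # within its own sequence.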
        xqkv = self.wqkv(x)
        xq = xqkv[:, : (self.n_local_heads * self.head_dim)]
        xkv = xqkv[:, (self.n_local_heads * self.head_dim) :]
        xk, xv = xkv.chunk(2, 1)

        if self.qk_normalization:
            xq = xq.view(-1, self.n_local_heads, self.head_dim)
            xq = self.q_normalization(xq)
            xq = xq.view(-1, self.n_local_heads * self.head_dim)

            xk = xk.view(-1, self.n_local_kv_heads, self.head_dim)
            xk = self.k_normalization(xk)
            xk = xk.view(-1, self.n_local_kv_heads * self.head_dim)

        output_shape = xq.shape
        xq = xq.view(1, xq.shape[0], self.n_local_heads, self.head_dim)
        xk = xk.view(1, xk.shape[0], self.n_local_kv_heads, self.head_dim)
        xv = xv.view(1, xv.shape[0], self.n_local_kv_heads, self.head_dim)
        cache_k, cache_v = cache

        xq = rope_padded(
            xq=xq,
            xk=xk,
            xv=xv,
            cache_k=cache_k,
            cache_v=cache_v,
            attn_bias=attn_bias,
            theta=self.rope_theta,
        )

        # Handle GQA
        # Q shape: [B, M, Hkv, Hq // Hkv, K]
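        # e.g. with 8 local query heads and 2 local KV heads, heads_per_group
        # is 4 and each KV head is shared by 4 query heads.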
        heads_per_group = self.n_local_heads // self.n_local_kv_heads
        cache_k = cache_k.unsqueeze(3).expand(-1, -1, -1, heads_per_group, -1)
        cache_v = cache_v.unsqueeze(3).expand(-1, -1, -1, heads_per_group, -1)
        xq = xq.reshape(
            [*xq.shape[:2], self.n_local_kv_heads, heads_per_group, xq.shape[-1]]
        )

        # rope_padded() updated the caches, so we
        # call attention directly
        output = fmha.memory_efficient_attention_forward(
            xq, cache_k, cache_v, attn_bias
        )

        output = self.wo(output.reshape(output_shape))
        if self.model_parallel_size > 1:
            dist.all_reduce(output, group=group)
        return output


class FeedForward(nn.Module):
    def __init__(
        self,
        model_parallel_size: int,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: float | None,
    ):
        super().__init__()
        self.model_parallel_size = model_parallel_size

        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
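        # e.g. TransformerBlock passes hidden_dim = 4 * dim, so with dim=512
        # this becomes int(2 * 2048 / 3) = 1365, rounded up to the next
        # multiple of 256, i.e. 1536.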
        assert hidden_dim % model_parallel_size == 0

        self.w13 = nn.Linear(
            dim,
            2 * hidden_dim // model_parallel_size,
            bias=False,
        )
        self.w2 = nn.Linear(
            hidden_dim // model_parallel_size,
            dim,
            bias=False,
        )
        self._register_load_state_dict_pre_hook(self.load_hook)

    # This adapter makes sure we can load vanilla
    # Llama checkpoints where w1 and w3 are not
    # fused in a single parameter
    def load_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        if prefix + "w1.weight" in state_dict:
            w1 = state_dict.pop(prefix + "w1.weight")
            w3 = state_dict.pop(prefix + "w3.weight")
            state_dict[prefix + "w13.weight"] = torch.cat([w1, w3])

    def forward(
        self, x: torch.Tensor, group: dist.ProcessGroup | None = None
    ) -> torch.Tensor:
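        # SwiGLU: w13 fuses w1 and w3, so this computes w2(silu(w1(x)) * w3(x)).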
        x13 = self.w13(x)
        x1, x3 = x13.chunk(2, -1)
        output = self.w2(F.silu(x1) * x3)
        if self.model_parallel_size > 1:
            dist.all_reduce(output, group=group)
        return output


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        assert args.dim % args.n_heads == 0
        head_dim = args.dim // args.n_heads
        if args.n_kv_heads is not None:
            n_kv_heads = args.n_kv_heads
        else:
            n_kv_heads = args.n_heads

        model_parallel_size = args.model_parallel_size
        assert args.n_heads % n_kv_heads == 0
        assert args.n_heads % model_parallel_size == 0
        assert n_kv_heads % model_parallel_size == 0

        self.attention = Attention(
            model_parallel_size=model_parallel_size,
            dim=args.dim,
            head_dim=head_dim,
            n_heads=args.n_heads,
            n_kv_heads=n_kv_heads,
            rope_theta=args.rope_theta,
            qk_normalization=args.qk_normalization,
        )
        self.feed_forward = FeedForward(
            model_parallel_size=model_parallel_size,
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.swin_norm = args.swin_norm

    def forward(
        self,
        x: torch.Tensor,
        cache: LayerCache,
        attn_bias: AttnBias,
        group: dist.ProcessGroup | None = None,
    ) -> torch.Tensor:
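        # With swin_norm the norm is applied to each sub-layer's output
        # (post-norm); otherwise it is applied to the sub-layer's input
        # (pre-norm, as in Llama).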
        if self.swin_norm:
            h = x + self.attention_norm(
                self.attention.forward(
                    x,
                    cache,
                    attn_bias,
                    group=group,
                )
            )
            out = h + self.ffn_norm(self.feed_forward(h, group=group))
        else:
            h = x + self.attention.forward(
                self.attention_norm(x),
                cache,
                attn_bias,
                group=group,
            )
            out = h + self.feed_forward(self.ffn_norm(h), group=group)
        return out


class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.model_parallel_size = args.model_parallel_size
        assert args.dim % self.model_parallel_size == 0
        assert args.vocab_size > 0
        assert args.vocab_size % self.model_parallel_size == 0

        self.tok_embeddings = nn.Embedding(
            num_embeddings=args.vocab_size,
            embedding_dim=args.dim // self.model_parallel_size,
        )

        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(TransformerBlock(args))

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(
            args.dim,
            args.vocab_size // self.model_parallel_size,
            bias=False,
        )

    def forward_with_attn_bias(
        self,
        token_values: torch.Tensor,
        attn_bias: AttnBias,
        cache: list[LayerCache],
        group: dist.ProcessGroup | None = None,
    ) -> torch.Tensor:
        h = self.tok_embeddings(token_values)
        if self.model_parallel_size > 1:
            gather = [torch.empty_like(h) for _ in range(self.model_parallel_size)]
            dist.all_gather(gather, h, group=group)
            h = torch.cat(gather, dim=-1)

        for i, layer in enumerate(self.layers):
            h = layer(h, cache[i], attn_bias, group=group)

        logits = self.output(self.norm(h))
        if self.model_parallel_size > 1:
            gather = [torch.empty_like(logits) for _ in range(self.model_parallel_size)]
            dist.all_gather(gather, logits, group=group)
            logits = torch.cat(gather, dim=-1)
        return logits.float()

    def forward(
        self,
        token_values: torch.Tensor,
        token_lengths: torch.Tensor,
        start_pos: torch.Tensor,
        cache: list[LayerCache],
        kv_padding: int,
        group: dist.ProcessGroup | None = None,
    ) -> torch.Tensor:
        attn_bias = AttnBias.from_seqlens(
            q_seqlen=token_lengths.tolist(),
            kv_seqlen=(start_pos + token_lengths).tolist(),
            kv_padding=kv_padding,
        )
        return self.forward_with_attn_bias(token_values, attn_bias, cache, group=group)


def make_cache(
    args: ModelArgs,
    length: int,
    device: str | torch.device | None = None,
    n_layers: int | None = None,
    dtype: torch.dtype | None = None,
) -> list[LayerCache]:
    """
    Allocate a cache to be used with the Transformer module.

    Args:
        args (ModelArgs): the model configuration.
        length (int): per layer cache size.
            It is usually budgeted as ``max_batch * max_seq``.
        device (torch.device, optional): the device on which
            the cache should be allocated.
        n_layers (int, optional): the number of layers to
            allocate a cache for (defaults to the model
            settings).
        dtype (torch.dtype, optional): the dtype to use for
            cache entries (defaults to the default dtype).

    Returns:
        The cache object to pass to ``Transformer.forward``.
    """
    head_dim = args.dim // args.n_heads
    n_kv_heads = args.n_kv_heads
    if n_kv_heads is None:
        n_kv_heads = args.n_heads
    n_local_kv_heads = n_kv_heads // args.model_parallel_size

    if n_layers is None:
        n_layers = args.n_layers

    shape = (1, length, n_local_kv_heads, head_dim)
    return [
        (
            torch.zeros(shape, device=device, dtype=dtype),
            torch.zeros(shape, device=device, dtype=dtype),
        )
        for _ in range(n_layers)
    ]


def cache_prefix(cache: list[LayerCache], length: int) -> list[LayerCache]:
    """
    Take a prefix view of a larger cache.

    The original cache object remains of identical size and valid
    after the shrunk alias has been used. This function is useful
    when a cache was allocated for a larger batch size than what is
    necessary.

    Args:
        cache: the cache to take a view in.
        length (int): the desired length.

    Returns:
        A view in the input cache object.
    """
    if len(cache) > 0:
        assert cache[0][0].shape[1] >= length
    return [(ck[:, :length], cv[:, :length]) for ck, cv in cache]
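

# A minimal end-to-end sketch of how the pieces above fit together. This is
# illustrative only: it assumes a CUDA device and a working xformers install
# (rope_padded and memory_efficient_attention_forward run on GPU), and the toy
# sizes below are arbitrary. Weights are random, so the logits are meaningless.
if __name__ == "__main__":
    args = ModelArgs(vocab_size=1024, dim=512, n_layers=2, n_heads=8)
    model = Transformer(args).to(device="cuda", dtype=torch.bfloat16)

    # Two prompts of different lengths, flattened into one 1-D token tensor.
    token_lengths = torch.tensor([3, 5], device="cuda")
    token_values = torch.randint(
        0, args.vocab_size, (int(token_lengths.sum()),), device="cuda"
    )
    start_pos = torch.zeros_like(token_lengths)

    # Each sequence gets kv_padding slots in the cache, so the per-layer cache
    # length must cover batch_size * kv_padding positions.
    kv_padding = 32
    cache = make_cache(
        args, length=2 * kv_padding, device="cuda", dtype=torch.bfloat16
    )

    logits = model(token_values, token_lengths, start_pos, cache, kv_padding)
    print(logits.shape)  # (8, vocab_size): one row of logits per input token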