ToyTransformer / modeling.py
larryvrh's picture
Update for fix.
73b1ea5
raw
history blame
No virus
12 kB
import torch
from torch import nn
from dataclasses import dataclass
from enum import Enum
from typing import *
from math import ceil
class AttentionBackend(Enum):
    """Selectable implementations of the attention computation."""

    # Plain matmul + softmax attention built from regular tensor ops.
    Naive = 0
    # Attention delegated to `flash_attn_func`.
    FlashAttentionCuda = 1
    # Attention delegated to `flash_attn_func_triton`.
    FlashAttentionTriton = 2
# Module-wide settings looked up on every forward pass; assign a different
# AttentionBackend member to 'attn_backend' to switch the attention kernel.
global_config = {'attn_backend': AttentionBackend.Naive}
@dataclass
class TransformerConfig:
    """Hyper-parameters and placement settings shared by the model's sub-modules.

    The integer size fields default to the sentinel ``-1``. They are plain
    ints: the previous trailing commas after each default silently turned
    them into 1-tuples ``(-1,)``.
    """

    # Number of rows of the token embedding / width of the output logits.
    vocab_size: int = -1
    # Number of stacked decoder layers.
    num_layers: int = -1
    # Number of attention heads per layer.
    num_heads: int = -1
    # Width of the residual stream.
    hidden_size: int = -1
    # Number of positions covered by the rotary-embedding cache.
    max_seq_len: int = -1
    # Back-reference to the owning model; attention heads read its
    # `rope_cache` through this field.
    root_model: Optional['ToyTransformer'] = None
    device: torch.device = torch.device('cpu')
    dtype: torch.dtype = torch.float32
def expand_attn_mask(custom_attn_mask: torch.Tensor):
    """Expand a per-token id mask into a pairwise attention-visibility matrix.

    Args:
        custom_attn_mask: 2-D tensor of shape (batch, seq_len). Tokens whose
            entries compare equal are treated as belonging together; entries
            that are not greater than zero are never visible.

    Returns:
        Boolean tensor of shape (batch, seq_len, seq_len). Element
        ``[b, i, j]`` is True iff ``j <= i``, ``mask[b, i] == mask[b, j]``
        and ``mask[b, j] > 0``.
    """
    batch_size, seq_len = custom_attn_mask.shape
    key_ids = custom_attn_mask.unsqueeze(1)    # (batch, 1, seq_len)
    query_ids = custom_attn_mask.unsqueeze(2)  # (batch, seq_len, 1)

    # Broadcasting compares every query position with every key position.
    same_group = key_ids == query_ids

    # Lower triangle keeps only keys at or before the query; "> 0" drops
    # keys whose id is zero (or negative).
    tiled_keys = key_ids.repeat((1, seq_len, 1))
    causal_and_valid = torch.tril(tiled_keys) > 0

    return same_group & causal_and_valid
# expand attn mask to cu_seqlens for flash attn
def expand_attn_mask_to_seq_lengths(attn_mask: torch.Tensor):
    """Convert a per-token id mask into cumulative boundary offsets.

    The mask is flattened row by row; an offset is emitted at the start of
    every row and wherever two horizontally adjacent entries differ. The
    total number of tokens is appended as the final offset.

    Args:
        attn_mask: 2-D tensor; it is copied to the CPU before processing.

    Returns:
        int32 CPU tensor of shape (num_boundaries + 1, 1).
    """
    cpu_mask = attn_mask.to('cpu')
    num_rows = cpu_mask.shape[0]
    total_tokens = num_rows * cpu_mask.shape[1]

    # The first token of each row always opens a new run.
    row_starts = torch.tensor([[True]] * num_rows)
    value_changes = cpu_mask[:, 1:] != cpu_mask[:, :-1]
    is_boundary = torch.cat([row_starts, value_changes], dim=1)

    boundary_offsets = torch.nonzero(is_boundary.view((-1,)))
    closing_offset = torch.tensor([[total_tokens]])
    return torch.cat([boundary_offsets, closing_offset]).to(dtype=torch.int32)
# naive RoPE implementation following https://arxiv.org/pdf/2104.09864.pdf
def get_rope_cache_slow(seq_len: int, dim: int, theta: int, device: torch.device, dtype: torch.dtype):
    """Precompute the lookup tables consumed by `apply_rope_slow`.

    Args:
        seq_len: number of positions to tabulate.
        dim: feature width; must be even because features are rotated in pairs.
        theta: base of the geometric frequency progression.
        device: device the returned tensors are moved to.
        dtype: dtype of the two floating-point tables.

    Returns:
        A 3-tuple ``(cos_table, signed_sin_table, pair_swap)``:
        ``cos_table`` and ``signed_sin_table`` have shape (seq_len, dim), the
        sine table carrying alternating signs ``+, -, +, -, ...``;
        ``pair_swap`` is the index vector ``[1, 0, 3, 2, ...]`` that swaps the
        two members of every feature pair.
    """
    assert dim % 2 == 0
    half_dim = dim // 2

    # One frequency per feature pair, duplicated so both members share it.
    pair_freqs = theta ** (-2 * torch.arange(0, half_dim, 1.) / dim)
    freqs = torch.repeat_interleave(pair_freqs, 2)

    positions = torch.arange(seq_len, dtype=torch.float).view((seq_len, 1))
    angles = positions * freqs

    cos_table = torch.cos(angles)
    signed_sin_table = torch.sin(angles) * torch.tensor([1, -1] * half_dim)

    swapped = []
    for even in range(0, dim, 2):
        swapped.extend((even + 1, even))
    pair_swap = torch.tensor(swapped)

    return cos_table.to(device, dtype=dtype), signed_sin_table.to(device, dtype=dtype), pair_swap.to(device)
def apply_rope_slow(x, rope_cache, positions: Optional[torch.Tensor] = None):
    """Rotate feature pairs of `x` using the tables from `get_rope_cache_slow`.

    For a feature pair ``(a, b)`` at angle ``t`` the result is
    ``(a*cos t - b*sin t, b*cos t + a*sin t)``.

    Args:
        x: tensor whose last two axes are (seq_len, dim).
        rope_cache: 3-tuple ``(cos_table, signed_sin_table, pair_swap)``.
        positions: optional integer tensor of table rows to use, one per
            token of `x`. When omitted, rows ``0 .. seq_len - 1`` are used.

    Returns:
        Tensor with the same shape as `x`.
    """
    v1, v2, indices = rope_cache
    seq_len, dim = x.shape[1:]
    if positions is None:
        v1 = v1[:seq_len, :]
        v2 = v2[:seq_len, :]
    else:
        # Gather whole table rows. Pairing `positions` with a second
        # `torch.arange(dim)` index made the two index tensors broadcast
        # against each other, which only worked for a single position and
        # picked a diagonal when len(positions) happened to equal dim.
        v1 = v1[positions].reshape(-1, dim)
        v2 = v2[positions].reshape(-1, dim)
    # `indices` swaps the members of each pair so the sine term of one member
    # lands on the other.
    return x * v1 + (x * v2)[:, :, indices]
# Optimized RoPE implementation adapted from https://github.com/facebookresearch/llama/blob/main/llama/model.py
def get_rope_cache_fast(seq_len: int, dim: int, theta: int, device: torch.device, dtype: torch.dtype):
    """Precompute unit complex rotation factors for `apply_rope_fast`.

    Args:
        seq_len: number of positions to tabulate.
        dim: feature width; one rotation is produced per feature pair.
        theta: base of the geometric frequency progression.
        device: device the returned tensor is moved to.
        dtype: accepted for signature parity with `get_rope_cache_slow`;
            not used by this function.

    Returns:
        Complex tensor of shape (seq_len, dim // 2) whose entry ``[p, k]`` is
        ``exp(i * p * theta ** (-2k / dim))``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freqs = 1.0 / (theta ** exponents)
    positions = torch.arange(seq_len, device=inv_freqs.device)
    angles = torch.outer(positions, inv_freqs).float()
    # Magnitude 1, phase = angle.
    unit_magnitude = torch.ones_like(angles)
    rotations = torch.polar(unit_magnitude, angles)
    return rotations.to(device)
def apply_rope_fast(x, rope_cache, positions: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply rotary position embeddings via complex multiplication.

    Args:
        x: tensor with the sequence on axis 1 and an even-sized feature axis
            last; adjacent feature pairs are treated as (real, imag).
        rope_cache: complex table produced by `get_rope_cache_fast`.
        positions: optional integer tensor selecting table rows. When omitted
            the first ``x.shape[1]`` rows are used.

    Returns:
        Rotated tensor with the same shape and dtype as `x`.
    """
    # The computation runs in float32 and is cast back at the end.
    paired = x.float().reshape(*x.shape[:-1], -1, 2)
    x_complex = torch.view_as_complex(paired)

    if positions is not None:
        rotations = rope_cache[positions, :]
    elif x.shape[1] < rope_cache.shape[0]:
        rotations = rope_cache[:x.shape[1], :]
    else:
        rotations = rope_cache

    # Keep the sequence and pair axes, size 1 everywhere else, so the table
    # broadcasts across the remaining axes of x.
    last_axis = x_complex.ndim - 1
    broadcast_shape = [size if axis in (1, last_axis) else 1 for axis, size in enumerate(x_complex.shape)]
    rotations = rotations.view(broadcast_shape)

    rotated = torch.view_as_real(x_complex * rotations).flatten(2)
    return rotated.type_as(x)
# RMSNorm implementation following https://arxiv.org/pdf/1910.07467.pdf
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain."""

    def __init__(self, hidden_size, dtype, eps=1e-6):
        """Create the gain vector (initialised to ones) and store `eps`."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype))
        self.eps = eps

    def forward(self, x: torch.Tensor):
        """Return ``weight * x / sqrt(mean(x ** 2, last axis) + eps)``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalized = x * inv_rms
        return self.weight * normalized
class AttentionHead(nn.Module):
    """One self-attention head with rotary position embeddings.

    The head projects the full hidden state down to ``hidden_size //
    num_heads`` features for queries, keys and values, and dispatches the
    attention computation to the backend selected in `global_config`.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.head_size = config.hidden_size // config.num_heads
        self.dtype = config.dtype
        self.q_proj = nn.Linear(config.hidden_size, self.head_size, dtype=config.dtype)
        self.k_proj = nn.Linear(config.hidden_size, self.head_size, dtype=config.dtype)
        self.v_proj = nn.Linear(config.hidden_size, self.head_size, dtype=config.dtype)

    def forward(self, x: torch.Tensor, attn_masked_bias: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Attend over `x` (plus any cached keys/values).

        Args:
            x: hidden states with the sequence on axis 1 and `hidden_size`
                features on the last axis.
            attn_masked_bias: additive bias for the attention scores, or None.
                The naive backend falls back to a causal mask when it is None.
            kv_cache: ``[k, v]`` from earlier calls of this head, or None.

        Returns:
            ``(attn_result, [k, v])`` where ``k`` and ``v`` include the cached
            entries and can be passed back as `kv_cache` on the next call.
        """
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        if kv_cache is None:
            positions = None
        else:
            # New tokens continue right after the cached ones. For a single
            # new token this is exactly the cache length, as before; several
            # new tokens now get consecutive positions instead of all sharing
            # one.
            past_len = kv_cache[0].shape[1]
            positions = torch.arange(past_len, past_len + x.shape[1], device=q.device)

        rope_cache = self.config.root_model.rope_cache
        q = apply_rope_fast(q, rope_cache, positions)
        k = apply_rope_fast(k, rope_cache, positions)

        if kv_cache is not None:
            k = torch.concat([kv_cache[0], k], dim=1)
            v = torch.concat([kv_cache[1], v], dim=1)

        backend = global_config['attn_backend']
        # NOTE(review): flash_attn_func / flash_attn_func_triton are not
        # imported in the visible part of this file - confirm they are brought
        # into scope before selecting a flash backend.
        if backend == AttentionBackend.FlashAttentionCuda:
            # The flash kernels get an extra size-1 axis at position 2, which
            # is removed again from the result.
            attn_result = flash_attn_func(q.unsqueeze(2), k.unsqueeze(2), v.unsqueeze(2), causal=True).squeeze(2)
        elif backend == AttentionBackend.FlashAttentionTriton:
            triton_bias = attn_masked_bias.unsqueeze(1) if attn_masked_bias is not None else None
            # Causal masking is only requested when there is no cache.
            attn_result = flash_attn_func_triton(q.unsqueeze(2), k.unsqueeze(2), v.unsqueeze(2), triton_bias,
                                                 kv_cache is None).squeeze(2)
        else:
            attn_score = (q @ k.permute(0, 2, 1)) / (self.head_size ** 0.5)
            if attn_masked_bias is None:
                # No bias supplied: fall back to a causal mask instead of
                # failing on `tensor + None`. Queries are the last q_len
                # positions of the key sequence.
                q_len, k_len = q.shape[1], k.shape[1]
                query_pos = torch.arange(k_len - q_len, k_len, device=q.device).unsqueeze(1)
                key_pos = torch.arange(k_len, device=q.device)
                hidden_keys = key_pos > query_pos
                attn_score = attn_score.masked_fill(hidden_keys, torch.finfo(attn_score.dtype).min / 2)
            else:
                attn_score = attn_score + attn_masked_bias
            attn_result = torch.softmax(attn_score, dim=2) @ v

        return attn_result, [k, v]
class MultiHeadAttention(nn.Module):
    """Runs `num_heads` independent attention heads and mixes their outputs."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        heads = [AttentionHead(config) for _ in range(config.num_heads)]
        self.attn_heads = nn.ModuleList(heads)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, dtype=config.dtype)

    def forward(self, x: torch.Tensor, attn_masked_bias: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        """Apply every head to `x` and project the concatenated result.

        Args:
            x: hidden states passed unchanged to each head.
            attn_masked_bias: attention bias passed unchanged to each head.
            kv_cache: per-head caches indexed by head position, or None.

        Returns:
            ``(output, new_caches)`` where ``new_caches[i]`` is the cache
            returned by head ``i``.
        """
        head_results = []
        new_caches = []
        for idx, head in enumerate(self.attn_heads):
            head_cache = None if kv_cache is None else kv_cache[idx]
            result = head(x, attn_masked_bias, head_cache)
            head_results.append(result[0])
            new_caches.append(result[1])

        # Heads are concatenated along the feature axis before mixing.
        merged = torch.concat(head_results, dim=2)
        return self.o_proj(merged), new_caches
class DecoderLayer(nn.Module):
    """Pre-norm decoder block: attention and a 4x-wide GELU MLP, each with a residual."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        model_width = config.hidden_size
        ffn_width = config.hidden_size * 4
        # Sub-modules are created in this fixed order so that seeded
        # initialisation stays reproducible.
        self.mha = MultiHeadAttention(config)
        self.up_proj = nn.Linear(model_width, ffn_width, dtype=config.dtype)
        self.down_proj = nn.Linear(ffn_width, model_width, dtype=config.dtype)
        self.ln_mha = nn.LayerNorm(model_width, dtype=config.dtype)
        self.ln_ffn = nn.LayerNorm(model_width, dtype=config.dtype)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor, attn_masked_bias: Optional[torch.Tensor],
                kv_cache: Optional[List[torch.Tensor]]) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
        """Run the block on `x`.

        Args:
            x: hidden states entering the layer.
            attn_masked_bias: attention bias forwarded to the attention module.
            kv_cache: this layer's cache forwarded to the attention module, or None.

        Returns:
            ``(hidden_states, new_kv_cache)`` with the cache as returned by the
            attention module.
        """
        attn_out, new_kv_cache = self.mha(self.ln_mha(x), attn_masked_bias, kv_cache)
        after_attn = x + attn_out

        ffn_in = self.ln_ffn(after_attn)
        ffn_hidden = self.act(self.up_proj(ffn_in))
        ffn_out = self.down_proj(ffn_hidden)

        return after_attn + ffn_out, new_kv_cache
class ToyTransformer(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings pass through `num_layers` decoder layers and a linear
    head producing vocabulary logits. Positions are encoded with rotary
    embeddings read from `self.rope_cache`.
    """

    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, hidden_size: int, max_seq_len: int,
                 device: torch.device = torch.device('cpu'), dtype: torch.dtype = torch.float32):
        super().__init__()
        # The config carries a back-reference to this model so attention
        # heads can reach `rope_cache`.
        self.config = TransformerConfig(vocab_size, num_layers, num_heads, hidden_size, max_seq_len, self, device,
                                        dtype)
        self.sem_embed = nn.Embedding(vocab_size, hidden_size, dtype=dtype)
        # Plain attribute, not a parameter or buffer: it does not appear in
        # the state dict and is not moved by Module.to(); forward() keeps its
        # device in sync with the parameters.
        self.rope_cache = get_rope_cache_fast(max_seq_len, hidden_size // num_heads, 10000, device, dtype)
        self.decoder_layers = nn.ModuleList([DecoderLayer(self.config) for _ in range(num_layers)])
        self.lm_head = nn.Linear(hidden_size, vocab_size, dtype=dtype)
        self.to(device)

    def forward(self, seq: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                kv_cache: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, List[List[List[torch.Tensor]]]]:
        """Compute logits for `seq`.

        Args:
            seq: token ids of shape (batch, seq_len).
            attn_mask: optional (batch, seq_len) id mask expanded with
                `expand_attn_mask`. Cannot be combined with `kv_cache` and is
                rejected by the FlashAttentionCuda backend.
            kv_cache: cache returned by a previous call, indexed as
                ``kv_cache[layer][head] == [k, v]``. Requires batch size 1.

        Returns:
            ``(logits, new_kv_cache)`` where `new_kv_cache` has the same
            nesting as `kv_cache`.
        """
        # sanity checks
        assert attn_mask is None or kv_cache is None  # No support for attn_mask and kv_cache both enabled
        if kv_cache is not None:
            assert seq.shape[0] == 1, 'kv_cache is not supported for batch inference'

        backend = global_config['attn_backend']
        device = self.device

        # rope_cache is not a registered tensor, so model.to()/cuda() called
        # after construction leaves it behind; bring it along lazily here.
        if self.rope_cache.device != device:
            self.rope_cache = self.rope_cache.to(device)

        # handle flash-attn triton alignment requirement (actually only needed for backward)
        seq_length = seq.shape[1]
        padded_for_triton = (kv_cache is None
                             and backend == AttentionBackend.FlashAttentionTriton
                             and seq_length % 128 != 0)
        if padded_for_triton:
            if attn_mask is None:  # forcibly enable attn_mask due to padding
                attn_mask = torch.ones(seq.shape, device=device)
            pad_length = (ceil(seq_length / 128) * 128) - seq_length
            # Padded mask entries are 0, which expand_attn_mask never makes visible.
            seq = nn.functional.pad(seq, (0, pad_length))
            attn_mask = nn.functional.pad(attn_mask, (0, pad_length))

        # handle attn_bias
        if backend == AttentionBackend.FlashAttentionCuda:
            assert attn_mask is None, 'FlashAttn-Cuda does not support custom attn_mask'
            attn_masked_bias = None
        elif backend == AttentionBackend.FlashAttentionTriton and attn_mask is None:
            attn_masked_bias = None
        elif attn_mask is not None:
            attn_masked_bias = expand_attn_mask(attn_mask)
        elif kv_cache is None:
            # No mask and no cache: plain causal attention.
            attn_masked_bias = expand_attn_mask(torch.ones(seq.shape, device=device))
        else:
            # Decoding with a cache: nothing is masked.
            attn_masked_bias = torch.ones((1, seq.shape[1], seq.shape[1]), dtype=torch.bool, device=device)

        if attn_masked_bias is not None:
            # Turn the boolean visibility matrix into an additive bias:
            # 0 where visible, a large negative value where hidden.
            mask_zero = torch.tensor(0, dtype=self.config.dtype)
            mask_val = torch.tensor(torch.finfo(self.config.dtype).min / 2, dtype=self.config.dtype)
            attn_masked_bias = torch.where(attn_masked_bias, mask_zero, mask_val).to(device)

        hidden = self.sem_embed(seq)

        new_kv_cache = []
        for idx, decoder in enumerate(self.decoder_layers):
            layer_cache = kv_cache[idx] if kv_cache is not None else None
            hidden, layer_kv_cache = decoder(hidden, attn_masked_bias, layer_cache)
            new_kv_cache.append(layer_kv_cache)

        logits = self.lm_head(hidden)

        # remove padding for flash-attn triton
        if padded_for_triton:
            logits = logits[:, :seq_length, :]
            new_kv_cache = [[[cache[:, :seq_length, :] for cache in head] for head in layer] for layer in new_kv_cache]

        return logits, new_kv_cache

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device