import torch
import torch.nn as nn
from torch.nn import functional as F

# Nothing in the model definitions below uses tqdm, so a missing install must
# not prevent importing them; code that needs progress bars should check for
# ``tqdm is None``.
try:
    import tqdm
except ImportError:
    tqdm = None


class Head(nn.Module):
    """One head of causal self-attention.

    Args:
        n_embd: Size of the incoming feature dimension.
        head_size: Size of the key/query/value projections.
        block_size: Side length of the lower-triangular causal mask buffer.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, n_embd, head_size, block_size, dropout):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular ones: position i may attend to positions <= i.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, C) and return (B, T, head_size)."""
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # Scaled dot-product scores, scaled by 1/sqrt(head_size).
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        # Future positions get -inf so softmax assigns them zero weight.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out


class MultiHeadAttention(nn.Module):
    """Multiple heads of self-attention in parallel.

    The head outputs are concatenated on the last dimension, projected back to
    ``n_embd`` and passed through dropout.
    """

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        assert n_embd % n_head == 0, f"n_embd ({n_embd}) must be divisible by num_heads ({n_head})"
        self.n_embd = n_embd
        self.n_head = n_head
        self.head_size = n_embd // n_head
        self.heads = nn.ModuleList(
            [Head(n_embd, self.head_size, block_size, dropout) for _ in range(n_head)]
        )
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on ``x`` and merge the results."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> ReLU -> Linear -> Dropout (hidden width 4 * n_embd)."""

    def __init__(self, n_embd, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.net(x)


class Block(nn.Module):
    """Transformer block: communication followed by computation.

    Both sub-layers are pre-normalised with ``nn.LayerNorm`` and wrapped in a
    residual connection.
    """

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        self.sa = MultiHeadAttention(n_embd, n_head, block_size, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


class RoPE(nn.Module):
    """Rotary Positional Encoding (RoPE) layer.

    Feature ``i`` of the first half is paired with feature ``i`` of the second
    half and each pair is rotated by ``position * inv_freqs[i]``. Rotations
    preserve vector norms, and the dot product of two rotated vectors depends
    only on the difference of their positions.

    Args:
        embd_dim: Feature size of the tensors to rotate. When odd, the last
            feature is passed through unrotated.
        max_freq: Controls the spread of the geometric frequency schedule
            ``2 ** linspace(0, max_freq - 1, embd_dim // 2) * pi``.
    """

    def __init__(self, embd_dim, max_freq=10):
        super().__init__()
        self.embd_dim = embd_dim
        self.max_freq = max_freq
        freqs = 2 ** torch.linspace(0, max_freq - 1, embd_dim // 2) * torch.pi
        # Buffers follow the module across devices; non-persistent so the
        # state_dict keys stay exactly as they were (these were plain attributes).
        self.register_buffer('freqs', freqs, persistent=False)
        self.register_buffer('inv_freqs', 1. / freqs, persistent=False)

    def forward(self, x, offset=0):
        """Rotate ``x`` according to the position of each row.

        Args:
            x: Tensor of shape (..., T, embd_dim); dimension -2 is the sequence.
            offset: Absolute position of the first row of ``x``. Non-zero when
                ``x`` continues a sequence whose prefix was processed earlier
                (e.g. incremental decoding with a key/value cache).

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        seq_len = x.shape[-2]
        half = self.inv_freqs.shape[0]
        positions = torch.arange(
            offset, offset + seq_len,
            device=self.inv_freqs.device, dtype=self.inv_freqs.dtype,
        )
        angles = positions[:, None] * self.inv_freqs[None, :]  # (T, half)
        cos = angles.cos().to(x.dtype)
        sin = angles.sin().to(x.dtype)
        first = x[..., :half]
        second = x[..., half:2 * half]
        untouched = x[..., 2 * half:]  # empty unless embd_dim is odd
        return torch.cat(
            [first * cos - second * sin, first * sin + second * cos, untouched],
            dim=-1,
        )


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (RMSNorm).

    Divides each feature vector by its root mean square over the last
    dimension; unlike LayerNorm, the mean is not subtracted. The result is
    scaled by ``gamma``. ``beta`` is a zero-initialised learned shift that is
    kept so the parameter layout of this module is unchanged.

    Args:
        embd_dim: Size of the last (normalised) dimension.
        epsilon: Added under the square root for numerical stability.
    """

    def __init__(self, embd_dim, epsilon=1e-8):
        super().__init__()
        self.embd_dim = embd_dim
        self.epsilon = epsilon
        self.gamma = nn.Parameter(torch.ones(embd_dim))
        self.beta = nn.Parameter(torch.zeros(embd_dim))

    def forward(self, x: torch.Tensor):
        """Return ``x / sqrt(mean(x**2) + epsilon) * gamma + beta``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(mean_square + self.epsilon)
        return x * self.gamma + self.beta


class LlamaFFN(nn.Module):
    """Feed-forward network of the LLAMA model with SwiGLU activation.

    ``linear1`` produces ``4 * n_embd`` features that SwiGLU splits into a
    value half and a gate half, so the gated activation handed to ``linear2``
    has ``2 * n_embd`` features.
    """

    def __init__(self, n_embd, dropout):
        super().__init__()
        self.linear1 = nn.Linear(n_embd, 4 * n_embd)
        # Input width is half of linear1's output because SwiGLU consumes
        # one half of the features as the gate.
        self.linear2 = nn.Linear(2 * n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def swiglu(self, x):
        """Applies SwiGLU activation: first half of ``x`` gated by SiLU of the second half."""
        x1, x2 = torch.chunk(x, 2, dim=-1)
        return x1 * F.silu(x2)

    def forward(self, x):
        """Map (..., n_embd) to (..., n_embd)."""
        x = self.linear1(x)
        x = self.swiglu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x


class AttentionHeadWithKVCacheAndRoPE(nn.Module):
    """One head of causal self-attention with key and value cache and RoPE.

    Queries and keys are rotated with :class:`RoPE` at their absolute
    positions before the scores are computed. Keys are cached already
    rotated, so they never need to be re-encoded.

    Args:
        n_embd: Size of the incoming feature dimension.
        head_size: Size of the key/query/value projections.
        block_size: Maximum total sequence length (cached + new tokens).
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, n_embd, head_size, block_size, dropout):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)
        self.pe = RoPE(head_size)
        # NOTE(review): not used in forward(); retained only so the module's
        # parameter layout is unchanged. Consider removing it.
        self.ln = RMSNorm(n_embd)

    def forward(self, x, kv_cache=None):
        """Attend over the cached prefix plus the new tokens in ``x``.

        Args:
            x: New tokens, shape (B, T, n_embd).
            kv_cache: ``None`` to disable caching, or a dict owned by this
                head. If it holds ``'k'`` and ``'v'`` (shape (B, T_past,
                head_size)) they are used as the prefix. The dict is updated
                in place with the extended keys and values.

        Returns:
            Tensor of shape (B, T, head_size).

        Raises:
            ValueError: If cached plus new tokens exceed ``block_size``.
        """
        B, T, C = x.shape
        has_past = kv_cache is not None and 'k' in kv_cache and 'v' in kv_cache
        past_len = kv_cache['k'].shape[1] if has_past else 0
        total_len = past_len + T
        max_len = self.tril.shape[0]
        if total_len > max_len:
            raise ValueError(
                f"sequence length {total_len} (cached {past_len} + new {T}) "
                f"exceeds block_size {max_len}"
            )

        # New tokens sit at absolute positions past_len .. total_len - 1.
        q = self.pe(self.query(x), offset=past_len)
        k = self.pe(self.key(x), offset=past_len)
        v = self.value(x)
        if has_past:
            k = torch.cat([kv_cache['k'], k], dim=1)
            v = torch.cat([kv_cache['v'], v], dim=1)

        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, total_len)
        # Rows are the new queries, columns are all keys: slice the matching
        # window of the causal mask rather than its top-left corner.
        causal = self.tril[past_len:total_len, :total_len]
        wei = wei.masked_fill(causal == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        out = wei @ v

        if kv_cache is not None:
            kv_cache['k'] = k
            kv_cache['v'] = v
        return out


class MultiHeadAttentionWithKVCacheAndRoPE(nn.Module):
    """Multiple heads of cached, rotary self-attention in parallel.

    The head outputs are concatenated on the last dimension, projected back to
    ``n_embd`` and passed through dropout.
    """

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        assert n_embd % n_head == 0, f"n_embd ({n_embd}) must be divisible by num_heads ({n_head})"
        self.n_embd = n_embd
        self.n_head = n_head
        self.head_size = n_embd // n_head
        self.heads = nn.ModuleList(
            [
                AttentionHeadWithKVCacheAndRoPE(n_embd, self.head_size, block_size, dropout)
                for _ in range(n_head)
            ]
        )
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, kv_cache=None):
        """Run every head on ``x`` and merge the results.

        Args:
            x: New tokens, shape (B, T, n_embd).
            kv_cache: ``None`` to disable caching, or a dict updated in place.
                Each head stores its tensors in its own sub-dict keyed by
                ``id(head)``, so one dict can be shared by every head (and
                every layer) without heads overwriting each other. Pass a
                fresh dict for each new sequence.
        """
        if kv_cache is None:
            head_outputs = [h(x, None) for h in self.heads]
        else:
            head_outputs = [h(x, kv_cache.setdefault(id(h), {})) for h in self.heads]
        out = torch.cat(head_outputs, dim=-1)
        out = self.dropout(self.proj(out))
        return out


class LlamaBlock(nn.Module):
    """LLAMA block: communication followed by computation.

    Both sub-layers are pre-normalised with :class:`RMSNorm` and wrapped in a
    residual connection.
    """

    def __init__(self, n_embd, n_head, block_size, dropout):
        super().__init__()
        self.ln1 = RMSNorm(n_embd)
        self.sa = MultiHeadAttentionWithKVCacheAndRoPE(n_embd, n_head, block_size, dropout)
        self.ln2 = RMSNorm(n_embd)
        self.ffwd = LlamaFFN(n_embd, dropout)

    def forward(self, x, kv_cache=None):
        """Return ``x`` after the attention and feed-forward residual updates.

        Args:
            x: New tokens, shape (B, T, n_embd).
            kv_cache: ``None`` or a dict forwarded to the attention layer,
                which updates it in place.
        """
        x = x + self.sa(self.ln1(x), kv_cache)
        x = x + self.ffwd(self.ln2(x))
        return x