File size: 8,156 Bytes
b9f2647 5556482 b9f2647 5556482 b9f2647 5556482 b9f2647 5556482 b9f2647 5556482 b9f2647 5556482 b9f2647 5556482 b9f2647 5556482 b9f2647 5556482 b9f2647 5556482 b9f2647 5556482 b9f2647 5556482 b9f2647 5556482 b9f2647 5556482 b9f2647 5556482 |
|
"""
Standard Transformer baseline for comparison with DTAT
Based on NanoGPT architecture with optimizations for enwik8
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention layer.

    Uses the fused ``F.scaled_dot_product_attention`` kernel when the installed
    PyTorch provides it, and an explicit masked-softmax implementation otherwise.

    Args:
        config: Object exposing ``n_embd``, ``n_head``, ``dropout``, ``bias`` and
            ``block_size``. ``n_embd`` must be divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.head_size = config.n_embd // config.n_head

        # Single projection producing query, key and value side by side.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection applied after the heads are merged again.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Prefer the fused attention kernel when this PyTorch build has one.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: Flash Attention not available, using manual attention")
            # Lower-triangular mask for the manual path, shaped
            # (1, 1, block_size, block_size) so it broadcasts over batch and heads.
            self.register_buffer(
                "bias",
                torch.tril(torch.ones(config.block_size, config.block_size))
                .view(1, 1, config.block_size, config.block_size)
            )

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape (batch, seq_len, n_embd).

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality

        # Project once, then split into query / key / value and reshape each
        # from (B, T, C) to (B, n_head, T, head_size).
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)

        if self.flash:
            # Called directly: PyTorch selects the best available backend on its
            # own, and the former ``torch.backends.cuda.sdp_kernel`` wrapper is
            # deprecated (it left every backend enabled here anyway).
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True,
            )
        else:
            # Manual attention: scaled scores, causal mask, softmax, dropout.
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v

        # Merge the heads back: (B, n_head, T, head_size) -> (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))
class Block(nn.Module):
    """One Transformer layer: attention and MLP sub-layers, each normalized on
    its input and added back to the residual stream.

    Args:
        config: Object exposing ``n_embd`` and ``dropout``, plus whatever
            ``CausalSelfAttention`` reads from it.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width  # MLP expands to four times the model width

        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)

        # Feed-forward stack: expand, non-linearity, contract, dropout.
        feed_forward = [
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(config.dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attn_out = self.attn(self.ln_1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out
class BaselineTransformer(nn.Module):
    """Decoder-only Transformer language model used as the baseline.

    The model sums token and learned position embeddings, runs them through
    ``n_layer`` ``Block`` modules and a final LayerNorm, and projects to
    vocabulary logits with an untied linear head.

    Args:
        config: Object exposing ``vocab_size``, ``block_size``, ``n_embd``,
            ``n_layer`` and ``dropout``, plus whatever ``Block`` reads from it.
    """

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),
        ))

        # Language modeling head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)

        # Apply special scaled init to the residual projections, per GPT-2 paper.
        # NOTE(review): only parameters named ``c_proj.weight`` match, i.e. the
        # attention output projection. The MLP output projection is registered
        # as ``mlp.2.weight`` and keeps the plain 0.02 std - confirm this is
        # intended before changing it, since it affects baseline results.
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

        # Gradient checkpointing is off until explicitly enabled.
        self.gradient_checkpointing = False

        # Report number of parameters
        print(f"number of parameters: {self.get_num_params() / 1e6:.2f}M")

    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing for memory efficiency"""
        self.gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing"""
        self.gradient_checkpointing = False

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute logits and, when targets are given, the loss.

        Args:
            idx: Integer tensor of token indices with shape (batch, seq_len).
            targets: Optional integer tensor with the same shape as ``idx``.

        Returns:
            Tuple ``(logits, loss)``. ``logits`` has shape
            (batch, seq_len, vocab_size). ``loss`` is ``None`` without targets,
            otherwise the mean cross-entropy divided by ln(2), i.e. in bits.

        Raises:
            ValueError: If ``seq_len`` exceeds ``config.block_size``.
        """
        device = idx.device
        b, t = idx.size()
        if t > self.config.block_size:
            # Fail early with a clear message instead of an out-of-range lookup
            # in the position embedding table.
            raise ValueError(
                f"Cannot forward sequence of length {t}, "
                f"block size is only {self.config.block_size}"
            )

        # Token and position embeddings
        tok_emb = self.transformer.wte(idx)
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)
        pos_emb = self.transformer.wpe(pos)

        # Add embeddings and apply dropout
        x = self.transformer.drop(tok_emb + pos_emb)

        # Apply transformer blocks with optional gradient checkpointing
        if self.gradient_checkpointing and self.training:
            # Imported here because ``import torch`` alone does not guarantee
            # that the ``torch.utils.checkpoint`` submodule is loaded.
            from torch.utils.checkpoint import checkpoint

            for block in self.transformer.h:
                # ``use_reentrant`` is passed explicitly, as the PyTorch docs
                # ask; the non-reentrant variant is the recommended one.
                x = checkpoint(block, x, use_reentrant=False)
        else:
            for block in self.transformer.h:
                x = block(x)

        x = self.transformer.ln_f(x)

        # Language model head
        logits = self.lm_head(x)

        # Loss calculation (in BPC)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            loss = loss / math.log(2)  # nats -> bits

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.

        Args:
            idx: LongTensor of shape (b, t) holding the conditioning tokens.
            max_new_tokens: Number of tokens to sample and append.
            temperature: Positive divisor applied to the final-step logits.
            top_k: If given, sample only among the ``top_k`` highest logits.

        Returns:
            LongTensor of shape (b, t + max_new_tokens).

        Raises:
            ValueError: If ``temperature`` is not positive or ``top_k`` is below 1.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        if top_k is not None and top_k < 1:
            raise ValueError(f"top_k must be at least 1, got {top_k}")

        block_size = self.config.block_size
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                kth, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits = logits.masked_fill(logits < kth[:, [-1]], float('-inf'))
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

    def get_num_params(self, non_embedding=True):
        """Return the number of parameters in the model.

        Args:
            non_embedding: When True, the position-embedding table (``wpe``) is
                excluded from the count. Token embeddings are always counted.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params
|