import os
import logging
from typing import Literal

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from encoder import encode, decode
from self_attention import Block
from dataset import Batcher

logger = logging.getLogger('bad_gpt').getChild(__name__)


class BadGPTModel(nn.Module):
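    """Decoder-only transformer language model: token and position embeddings
    feed a stack of self-attention Blocks, and a linear head maps the result
    back to vocabulary logits."""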
|
    def __init__(
        self,
        device: Literal['cuda', 'cpu'],
        block_size: int,
        vocab_size: int,
        n_embd: int,
        n_head: int = 4,
        n_layers: int = 3,
        dropout: float = 0.2
    ):
        super().__init__()
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.device = device

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        # Loss of an untrained model that predicts a uniform distribution over
        # the vocabulary: -ln(1/vocab_size), i.e. ln(vocab_size).
        self.expected_loss: np.float64 = np.log(1 / vocab_size) * -1
        self.blocks = nn.Sequential(
            *[
                Block(n_embd, block_size, n_head, dropout)
                for _ in range(n_layers)
            ],
            nn.LayerNorm(n_embd)
        )

    def forward(self, idx: torch.Tensor, targets: torch.Tensor = None):
        # Returns logits of shape (B, T, vocab_size); `targets` is accepted so
        # callers can pass (context, answer) batches, but the loss itself is
        # computed by the callers (BadGPTTrainer._train and estimate_loss).
        B, T = idx.shape
        tok_emb: torch.Tensor = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(
            torch.arange(T, device=self.device))
        x: torch.Tensor = tok_emb + pos_emb
        x = self.blocks(x)
        logits: torch.Tensor = self.lm_head(x)
        return logits

    # Sampling does not need gradient tracking.
    @torch.no_grad()
    def generate(self, ctx: torch.Tensor, max_new_tokens: int):
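        """Sample max_new_tokens tokens autoregressively, appending each new
        token to ctx and returning the extended context."""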
|
        for index in range(max_new_tokens):
            if index % 16 == 0:
                logger.debug(f'Iteration {index} of {max_new_tokens}')

            # Crop the running context to the last block_size tokens, the
            # maximum length the position embedding table supports.
            cropped_ctx = ctx[:, -self.block_size:]
            logits = self(cropped_ctx)

            # Only the logits at the final position are needed to predict
            # the next token.
            logits = logits[:, -1, :]

            probabilities = F.softmax(logits, dim=-1)
            ctx_next = torch.multinomial(probabilities, num_samples=1)

            ctx = torch.cat((ctx, ctx_next), dim=1)
        return ctx


@torch.no_grad()
def estimate_loss(gpt: BadGPTModel, batcher: Batcher, eval_interval: int,
                  device: Literal['cuda', 'cpu'] = 'cuda'):
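    """Average the cross-entropy loss over eval_interval batches for the train
    and val splits, with the model temporarily switched to eval mode."""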
|
    out = {}
    gpt.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_interval)
        for epoch in range(eval_interval):
            context, answer = batcher.get_batch(split=split)
            logits = gpt(context.to(device))

            # Flatten (batch, block, vocab) logits and (batch, block) targets
            # so cross_entropy sees one prediction per token.
            batch, block, vocab = logits.shape
            logits = logits.view(batch * block, vocab)
            answer = answer.view(batch * block).to(device)

            loss = F.cross_entropy(logits, answer).item()
            losses[epoch] = loss
        out[split] = losses.mean()
    gpt.train()
    return out


class BadGPTTrainer:
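    """Runs the AdamW training loop for a BadGPTModel, loading an existing
    checkpoint from model.pth when one is present and saving one otherwise."""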
|
    def __init__(self, model: BadGPTModel, batcher: Batcher,
                 eval_interval: int, iterations: int, learning_rate: float):
        self.model = model
        self.batcher = batcher
        self.eval_interval = eval_interval
        self.iterations = iterations
        self.learning_rate = learning_rate
        self.device = self.model.device
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=self.learning_rate)

    def train(self):
        if os.path.exists('model.pth'):
            logger.debug("Loading model from file...")
            checkpoint = torch.load('model.pth', map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            logger.debug("Model loaded!")
        else:
            logger.debug("Training model...")
            self._train()
            torch.save({
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict()
            }, 'model.pth')
            logger.debug("Training complete!")

    def _train(self):
        for i in range(self.iterations):
            # Periodically report the mean train/val loss.
            if i % self.eval_interval == 0:
                losses = estimate_loss(
                    self.model, self.batcher, self.eval_interval, self.device)
                logger.debug(
                    f"step {i}: train loss {losses['train']:.4f}, "
                    f"val loss {losses['val']:.4f}")
            context_stack, answer_stack = self.batcher.get_batch(split='train')
            logits = self.model(
                context_stack.to(self.device), answer_stack.to(self.device))

            # Flatten the (batch, block) dimensions so cross_entropy sees one
            # prediction per token.
            batch, block, vocab = logits.shape
            logits = logits.view(batch * block, vocab).to(self.device)
            answer_stack = answer_stack.view(batch * block).to(self.device)

            loss = F.cross_entropy(logits, answer_stack)
            self.optimizer.zero_grad(set_to_none=True)
            loss.backward()
            self.optimizer.step()


class BadGPT:
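    """Convenience wrapper that wires together a Batcher, BadGPTModel and
    BadGPTTrainer, trains (or loads) the model, and exposes text generation
    from a string prompt."""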
|
    def __init__(
        self,
        device: Literal['cuda', 'cpu'],
        block_size: int,
        batch_size: int,
        n_embd: int,
        n_head: int,
        n_layers: int,
        dropout: float,
        eval_interval: int,
        iterations: int,
        lr: float
    ):
        self.device = device
        self._batcher = Batcher(
            device=device,
            batch_size=batch_size,
            block_size=block_size
        )
        self._model = BadGPTModel(
            device=device,
            block_size=block_size,
            vocab_size=len(self._batcher.vocab),
            n_embd=n_embd,
            n_head=n_head,
            n_layers=n_layers,
            dropout=dropout
        ).to(device)
        self._trainer = BadGPTTrainer(
            model=self._model,
            batcher=self._batcher,
            eval_interval=eval_interval,
            iterations=iterations,
            learning_rate=lr
        )
        self._trainer.train()

        self._model.eval()

    def generate(self, prompt: str, response_size: int):
        start_ids = encode(prompt)
        context = torch.tensor(start_ids, dtype=torch.long, device=self.device)

        # Add a batch dimension: (T,) -> (1, T).
        context = context[None, ...]
        encoded = self._model.generate(
            ctx=context, max_new_tokens=response_size)[0]
        return decode(encoded.tolist())
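
# Minimal usage sketch (illustrative only): builds a small BadGPT, trains or
# loads it, and samples a short completion. The hyperparameter values below
# are arbitrary assumptions, not the project's recommended settings.
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    gpt = BadGPT(
        device='cuda' if torch.cuda.is_available() else 'cpu',
        block_size=64,
        batch_size=32,
        n_embd=128,
        n_head=4,
        n_layers=3,
        dropout=0.2,
        eval_interval=100,
        iterations=1000,
        lr=1e-3
    )
    print(gpt.generate(prompt='Once upon a time', response_size=100))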