# bad-gpt / bad_gpt.py
import os
from typing import Literal, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import logging
from encoder import encode, decode
from self_attention import Block
from dataset import Batcher
logger = logging.getLogger('bad_gpt').getChild(__name__)
class BadGPTModel(nn.Module):
def __init__(
self,
device: Literal['cuda', 'cpu'],
block_size: int,
vocab_size: int,
n_embd: int,
n_head: int = 4,
n_layers: int = 3,
dropout: float = 0.2
):
super().__init__()
self.block_size = block_size
self.vocab_size = vocab_size
self.n_embd = n_embd
self.device = device
# Create a table to embed both token and position
self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
self.position_embedding_table = nn.Embedding(block_size, n_embd)
self.lm_head = nn.Linear(n_embd, vocab_size)
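        # A uniform guess over the vocabulary gives cross-entropy
        # -ln(1/vocab_size) = ln(vocab_size); useful as an initial-loss sanity check.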
self.expected_loss: np.float64 = np.log(1/vocab_size) * -1
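        # Transformer blocks followed by a final LayerNorm before the LM head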
self.blocks = nn.Sequential(
*[
Block(n_embd, block_size, n_head, dropout)
for _ in range(n_layers)
],
nn.LayerNorm(n_embd)
)
    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        # Predict logits for the next token at every position.
        # `targets` is accepted but unused here; the loss is computed by the
        # callers (estimate_loss and BadGPTTrainer._train).
B, T = idx.shape
tok_emb: torch.Tensor = self.token_embedding_table(idx)
pos_emb = self.position_embedding_table(
torch.arange(T, device=self.device))
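        # tok_emb is (B, T, n_embd); pos_emb is (T, n_embd) and broadcasts across the batch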
x: torch.Tensor = tok_emb + pos_emb
x = self.blocks(x)
logits: torch.Tensor = self.lm_head(x)
return logits
    # Given a (batch, time) tensor of token indices, sample max_new_tokens
    # new tokens autoregressively and append them to the context.
def generate(self, ctx: torch.Tensor, max_new_tokens: int):
for index in range(max_new_tokens):
# Log progress so I don't go insane
if index % 16 == 0:
logger.debug(f'Iteration {index} of {max_new_tokens}')
            # Crop the context to the last block_size tokens
cropped_ctx = ctx[:, -self.block_size:]
logits = self(cropped_ctx)
            # Logits has shape (batch, time, vocab)
            # Only the last time step is needed to predict the next token
logits = logits[:, -1, :]
# Get possible next tokens and select one
probabilities = F.softmax(logits, dim=-1)
ctx_next = torch.multinomial(probabilities, num_samples=1)
# Add the new token to the end of the tensor
ctx = torch.cat((ctx, ctx_next), dim=1)
return ctx
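# Average cross-entropy over eval_interval batches for each split, with gradients
# disabled and the model temporarily switched to eval mode.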
@torch.no_grad()
def estimate_loss(gpt: BadGPTModel, batcher: Batcher, eval_interval: int, device: Literal['cuda', 'cpu'] = 'cuda'):
out = {}
gpt.eval()
for split in ['train', 'val']:
losses = torch.zeros(eval_interval)
for epoch in range(eval_interval):
            context, answer = batcher.get_batch(split=split)
            logits = gpt(context)
            # Reformat prediction and answer so each entry can be compared
batch, block, vocab = logits.shape
logits = logits.view(batch * block, vocab)
answer = answer.view(batch * block)
            # Cross-entropy between the predicted distribution and the actual next token
loss = F.cross_entropy(logits, answer).item()
losses[epoch] = loss
out[split] = losses.mean()
gpt.train()
return out
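# Wraps the AdamW optimization loop; train() loads a cached checkpoint from
# model.pth if one exists, otherwise trains from scratch and saves one.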
class BadGPTTrainer():
def __init__(self, model: BadGPTModel, batcher: Batcher, eval_interval: int, iterations: int, learning_rate: float):
self.model = model
self.batcher = batcher
self.eval_interval = eval_interval
self.iterations = iterations
self.learning_rate = learning_rate
self.device = self.model.device
self.optimizer = torch.optim.AdamW(
self.model.parameters(), lr=self.learning_rate)
def train(self):
if os.path.exists('model.pth'):
logger.debug("Loading model from file...")
checkpoint = torch.load('model.pth', map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
logger.debug("Model loaded!")
else:
logger.debug("Training model...")
self._train()
torch.save({
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict()
}, 'model.pth')
logger.debug("Training complete!")
def _train(self):
for i in range(self.iterations):
if i % self.eval_interval == 0:
losses = estimate_loss(
self.model, self.batcher, self.eval_interval, self.device)
logger.debug(
f"step {i}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
context_stack, answer_stack = self.batcher.get_batch(split='train')
logits = self.model(context_stack.to(
self.device), answer_stack.to(self.device))
batch, block, vocab = logits.shape
            # Flatten logits and answers so each token prediction can be compared
logits = logits.view(batch * block, vocab).to(self.device)
answer_stack = answer_stack.view(batch * block).to(self.device)
# Compare predicted tokens to actual
loss = F.cross_entropy(logits, answer_stack)
self.optimizer.zero_grad(set_to_none=True)
loss.backward()
self.optimizer.step()
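# Convenience wrapper: builds the Batcher, model, and trainer, trains (or loads)
# the model, then exposes generate() for prompting.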
class BadGPT():
def __init__(
self,
device: Literal['cuda', 'cpu'],
block_size: int,
batch_size: int,
n_embd: int,
n_head: int,
n_layers: int,
dropout: float,
eval_interval: int,
iterations: int,
lr: float
):
self.device = device
self._batcher = Batcher(
device=device,
batch_size=batch_size,
block_size=block_size
)
self._model = BadGPTModel(
device=device,
block_size=block_size,
vocab_size=len(self._batcher.vocab),
n_embd=n_embd,
n_head=n_head,
n_layers=n_layers,
dropout=dropout
).to(device)
self._trainer = BadGPTTrainer(
model=self._model,
batcher=self._batcher,
eval_interval=eval_interval,
iterations=iterations,
learning_rate=lr
)
self._trainer.train()
# set to eval phase since we're only taking user input from here on
self._model.eval()
def generate(self, prompt: str, response_size: int):
start_ids = encode(prompt)
context = torch.tensor(start_ids, dtype=torch.long, device=self.device)
        # Add a batch dimension; generate() expects a (batch, time) tensor even for a single prompt
context = context[None, ...]
encoded = self._model.generate(
ctx=context, max_new_tokens=response_size)[0]
return decode(encoded.tolist())
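# Minimal usage sketch. The hyperparameter values below are illustrative
# assumptions, and Batcher/encode/decode must be backed by whatever dataset
# and tokenizer this repo is configured with.
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    gpt = BadGPT(
        device=device,
        block_size=64,
        batch_size=32,
        n_embd=128,
        n_head=4,
        n_layers=3,
        dropout=0.2,
        eval_interval=200,
        iterations=2000,
        lr=3e-4
    )
    print(gpt.generate(prompt='Once upon a time', response_size=128))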