import torch
import torch.optim as optim
from torch.nn import functional as F
from tqdm import tqdm
import wandb
from transformers import get_linear_schedule_with_warmup

from utils.data_preprocessing import get_dataloader, load_tokenizer
from models.gem_model import GEM
from configs.config import MODEL_CONFIG, TRAINING_CONFIG
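
# Config keys this script expects (all referenced below; values live in configs/config.py):
#   MODEL_CONFIG: D_MODEL, N_HEADS, D_FF, N_LAYERS, DROPOUT, DEVICE, MAX_SEQ_LEN,
#                 BATCH_SIZE, LEARNING_RATE, ADAM_EPSILON, NUM_EPOCHS,
#                 GRADIENT_ACCUMULATION_STEPS, WARMUP_STEPS, MAX_GRAD_NORM
#   TRAINING_CONFIG: LOGGING_STEPS, EVAL_STEPS, CHECKPOINT_SAVE_STEPS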

def train():
    wandb.init(project="GEM_Project", config=MODEL_CONFIG, mode="offline")
    print("WandB initialized in offline mode.")

    tokenizer = load_tokenizer()
    print("Tokenizer loaded.")

    dataloader = get_dataloader('wikitext', 'wikitext-2-raw-v1', tokenizer,
                                MODEL_CONFIG['MAX_SEQ_LEN'], MODEL_CONFIG['BATCH_SIZE'])
    print("Dataloader created.")

    model = GEM(
        vocab_size=len(tokenizer),
        d_model=MODEL_CONFIG['D_MODEL'],
        n_heads=MODEL_CONFIG['N_HEADS'],
        d_ff=MODEL_CONFIG['D_FF'],
        n_layers=MODEL_CONFIG['N_LAYERS'],
        dropout=MODEL_CONFIG['DROPOUT']
    ).to(MODEL_CONFIG['DEVICE'])
    print("Model initialized.")

    optimizer = optim.AdamW(model.parameters(), lr=MODEL_CONFIG['LEARNING_RATE'],
                            eps=MODEL_CONFIG['ADAM_EPSILON'])
    total_steps = len(dataloader) * MODEL_CONFIG['NUM_EPOCHS'] // MODEL_CONFIG['GRADIENT_ACCUMULATION_STEPS']
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=MODEL_CONFIG['WARMUP_STEPS'],
        num_training_steps=total_steps
    )
    print("Optimizer and scheduler set up.")
    # GradScaler manages loss scaling for mixed-precision (fp16) training.
    scaler = torch.cuda.amp.GradScaler()

    model.train()
    print("Starting training loop.")
    for epoch in range(MODEL_CONFIG['NUM_EPOCHS']):
        print(f"Epoch {epoch + 1}/{MODEL_CONFIG['NUM_EPOCHS']} started.")
        for step, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch + 1}")):
            batch = batch.to(MODEL_CONFIG['DEVICE'])

            with torch.cuda.amp.autocast():
                outputs = model(batch)
                # Next-token objective: position t predicts token t + 1
                # (assumes GEM returns unshifted per-position logits).
                loss = F.cross_entropy(outputs[:, :-1, :].reshape(-1, outputs.size(-1)),
                                       batch[:, 1:].reshape(-1))

            # Average gradients over the accumulation window.
            loss = loss / MODEL_CONFIG['GRADIENT_ACCUMULATION_STEPS']
            scaler.scale(loss).backward()

            if (step + 1) % MODEL_CONFIG['GRADIENT_ACCUMULATION_STEPS'] == 0:
                # Unscale before clipping so the norm is computed on true gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), MODEL_CONFIG['MAX_GRAD_NORM'])
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()

            if step % TRAINING_CONFIG['LOGGING_STEPS'] == 0:
                # Undo the accumulation scaling so the logged value is the per-batch loss.
                wandb.log({"loss": loss.item() * MODEL_CONFIG['GRADIENT_ACCUMULATION_STEPS']})

            if step % TRAINING_CONFIG['EVAL_STEPS'] == 0:
                # Periodic eval; re-iterates the training dataloader (no separate val split).
                model.eval()
                with torch.no_grad():
                    val_loss = 0.0
                    for val_batch in dataloader:
                        val_batch = val_batch.to(MODEL_CONFIG['DEVICE'])
                        val_outputs = model(val_batch)
                        val_loss += F.cross_entropy(
                            val_outputs[:, :-1, :].reshape(-1, val_outputs.size(-1)),
                            val_batch[:, 1:].reshape(-1)).item()
                wandb.log({"val_loss": val_loss / len(dataloader)})
                model.train()

            if step % TRAINING_CONFIG['CHECKPOINT_SAVE_STEPS'] == 0:
                torch.save(model.state_dict(), f"checkpoint_{epoch}_{step}.pt")

    torch.save(model.state_dict(), "GEM_1o_Aug_15.pt")
    print("Training complete. Final model saved.")

if __name__ == "__main__":
    train()