# iupac_gpt/language_modeling.py
"""Pytorch-lightning module for causal language modeling.
"""
__all__ = ("GPT2LitModel",)
import pytorch_lightning as pl
import torch
class GPT2LitModel(pl.LightningModule):
"""Lightning module for autoregressive (causal) transformer language modeling.
Successfully tested on HuggingFace `GPT2LMHeadModel`.
"""
def __init__(self, transformer, batch_size: int, learning_rate: float,
final_learning_rate: float, weight_decay: float, adam_eps: float,
adam_betas: tuple, scheduler_T_max: int,
save_model_every: int = 10_000, checkpoint: str = ""):
        super().__init__()
        # Exclude the model object and checkpointing options from the saved
        # hyperparameters; only the training hyperparameters are logged.
        self.save_hyperparameters(ignore=("transformer", "save_model_every",
                                          "checkpoint"))
        self.transformer = transformer
        self.save_model_every = save_model_every
        self.checkpoint = checkpoint or "./gpt2litmodel-logs"

    def forward(self, *args, **kwargs):
        return self.transformer(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        # Periodically snapshot the underlying HuggingFace model so it can be
        # reloaded later with `from_pretrained`.
        if self.save_model_every > 0 and batch_idx % self.save_model_every == 0:
            self.transformer.save_pretrained(self.checkpoint)
        return {'loss': outputs['loss']}

    def training_epoch_end(self, outputs):
        if self.save_model_every > 0:
            self.transformer.save_pretrained(self.checkpoint)
        # Epoch-level perplexity: exp of the mean training loss across steps.
        losses = [step_output["loss"] for step_output in outputs]
        mean_loss = torch.stack(losses).mean()
        ppl = torch.exp(mean_loss)
        self.log("ppl", ppl, on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        # `named_parameters()` returns a one-shot generator; materialize it so
        # both comprehensions below see the full parameter list.
        parameters = list(self.named_parameters())
        # Parameters whose names match these patterns are exempt from weight
        # decay. Note that GPT-2 registers its layer norms as `ln_*`, so only
        # the "bias" pattern applies to them here.
        no_decay = ["bias", "LayerNorm.weight"]
        grouped_parameters = [
            {"params": [p for n, p in parameters
                        if not any(nd in n for nd in no_decay)],
             "weight_decay": self.hparams.weight_decay},
            {"params": [p for n, p in parameters
                        if any(nd in n for nd in no_decay)],
             "weight_decay": 0.0}]
        # The per-group settings above determine weight decay; `lr`, `eps`,
        # and `betas` are shared across both groups.
        optimizer = torch.optim.Adam(
            grouped_parameters, lr=self.hparams.learning_rate,
            eps=self.hparams.adam_eps, betas=self.hparams.adam_betas)
        # Cosine-anneal the learning rate from `learning_rate` down to
        # `final_learning_rate` over `scheduler_T_max` steps; the scheduler is
        # stepped once per optimizer step.
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.hparams.scheduler_T_max,
            eta_min=self.hparams.final_learning_rate)
        return {'optimizer': optimizer,
                'lr_scheduler': {'scheduler': lr_scheduler,
                                 'interval': 'step', 'frequency': 1}}
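

# ---------------------------------------------------------------------------
# Usage sketch (not part of the library API): a minimal smoke test assuming
# HuggingFace `transformers` and pytorch-lightning 1.x (1.5+ for the trainer
# flags shown) are installed. The tiny config, random dummy batches, and
# hyperparameter values below are illustrative placeholders, not the settings
# used by iupacGPT itself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import GPT2Config, GPT2LMHeadModel

    config = GPT2Config(vocab_size=100, n_positions=32, n_embd=64,
                        n_layer=2, n_head=2)
    lit_model = GPT2LitModel(
        transformer=GPT2LMHeadModel(config),
        batch_size=8, learning_rate=5e-4, final_learning_rate=5e-6,
        weight_decay=0.01, adam_eps=1e-8, adam_betas=(0.9, 0.999),
        scheduler_T_max=100, save_model_every=0)  # 0 disables save_pretrained

    # Random token ids stand in for a tokenized IUPAC-name dataset; labels are
    # the inputs themselves, as usual for causal language modeling.
    ids = torch.randint(0, config.vocab_size, (32, 16))
    loader = DataLoader(
        [{"input_ids": x, "labels": x} for x in ids], batch_size=8,
        collate_fn=lambda items: {key: torch.stack([d[key] for d in items])
                                  for key in items[0]})

    trainer = pl.Trainer(max_steps=4, logger=False, enable_checkpointing=False)
    trainer.fit(lit_model, loader)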