"""Pytorch-lightning module for causal language modeling. """ __all__ = ("GPT2LitModel",) import pytorch_lightning as pl import torch class GPT2LitModel(pl.LightningModule): """Lightning module for autoregressive (causal) transformer language modeling. Successfully tested on HuggingFace `GPT2LMHeadModel`. """ def __init__(self, transformer, batch_size: int, learning_rate: float, final_learning_rate: float, weight_decay: float, adam_eps: float, adam_betas: tuple, scheduler_T_max: int, save_model_every: int = 10_000, checkpoint: str = ""): super().__init__() self.save_hyperparameters(ignore=("transformer", "save_model_every", "checkpoints")) self.transformer = transformer self.save_model_every = save_model_every self.checkpoint = checkpoint or "./gpt2litmodel-logs" def forward(self, *args, **kwargs): return self.transformer(*args, **kwargs) def training_step(self, batch, batch_idx): outputs = self(**batch) if self.save_model_every > 0 and batch_idx % self.save_model_every == 0: self.transformer.save_pretrained(self.checkpoint) return {'loss': outputs['loss']} def training_epoch_end(self, outputs): if self.save_model_every > 0: self.transformer.save_pretrained(self.checkpoint) losses = [step_output["loss"] for step_output in outputs] mean_loss = torch.tensor(losses).mean() ppl = torch.exp(mean_loss) self.log("ppl", ppl, on_step=False, on_epoch=True, prog_bar=True) def configure_optimizers(self): parameters = self.named_parameters() no_decay = ["bias", "LayerNorm.weight"] grouped_parameters = [ {"params": [p for n, p in parameters if not any(nd in n for nd in no_decay)], "weight_decay": self.hparams.weight_decay}, {"params": [p for n, p in parameters if any(nd in n for nd in no_decay)], "weight_decay": 0.0}] optimizer = torch.optim.Adam( grouped_parameters, lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay, eps=self.hparams.adam_eps, betas=self.hparams.adam_betas) lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, self.hparams.scheduler_T_max, eta_min=self.hparams.final_learning_rate) return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': lr_scheduler, 'interval': 'step', 'frequency': 1}}