from typing import List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import wandb
from torch import optim
from torch.nn.init import xavier_uniform_
from transformers.models.bert.tokenization_bert import BertTokenizer

# from sacrebleu.metrics.bleu import BLEU, _get_tokenizer
from dataset import Batched, DecodedBatch

# from models.scheduler import WarmupDecayLR
from transformer import Decoder, Encoder


class KobeModel(pl.LightningModule):
    def __init__(self, args):
        super(KobeModel, self).__init__()
        self.encoder = Encoder(
            vocab_size=args.text_vocab_size + args.cond_vocab_size,
            max_seq_len=args.max_seq_len,
            d_model=args.d_model,
            nhead=args.nhead,
            num_layers=args.num_encoder_layers,
            dropout=args.dropout,
            mode=args.mode,
        )
        self.decoder = Decoder(
            vocab_size=args.text_vocab_size,
            max_seq_len=args.max_seq_len,
            d_model=args.d_model,
            nhead=args.nhead,
            num_layers=args.num_decoder_layers,
            dropout=args.dropout,
        )
        self.lr = args.lr
        self.d_model = args.d_model
        # Token id 0 is the padding index; label smoothing regularizes the decoder.
        self.loss = nn.CrossEntropyLoss(
            reduction="mean", ignore_index=0, label_smoothing=0.1
        )
        self._reset_parameters()

        self.decoding_strategy = args.decoding_strategy
        self.vocab = BertTokenizer.from_pretrained(args.text_vocab_path)
        # self.bleu = BLEU(tokenize=args.tokenize)
        # self.sacre_tokenizer = _get_tokenizer(args.tokenize)()
        # self.bert_scorer = BERTScorer(lang=args.tokenize, rescale_with_baseline=True)

    def _reset_parameters(self):
        # Xavier-initialize all weight matrices; vectors (biases, etc.) keep their defaults.
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)

    def _tokenwise_loss_acc(
        self, logits: torch.Tensor, batch: Batched
    ) -> Tuple[torch.Tensor, float]:
        # Token-level accuracy is computed only on non-padded target positions.
        unmask = ~batch.description_token_ids_mask.T[1:]
        unmasked_logits = logits[unmask]
        unmasked_targets = batch.description_token_ids[1:][unmask]
        acc = (
            (unmasked_logits.argmax(dim=-1) == unmasked_targets).float().mean().item()
        )
        return self.loss(logits.transpose(1, 2), batch.description_token_ids[1:]), acc

    def training_step(self, batch: Batched, batch_idx: int):
        encoded = self.encoder.forward(batch)
        logits = self.decoder.forward(batch, encoded)
        loss, acc = self._tokenwise_loss_acc(logits, batch)
        # Step the LR scheduler manually per batch, if one is configured.
        scheduler = self.lr_schedulers()
        if scheduler is not None:
            scheduler.step()
        self.log("train/loss", loss.item())
        self.log("train/acc", acc)
        return loss

    def _shared_eval_step(self, batch: Batched, batch_idx: int) -> DecodedBatch:
        encoded = self.encoder.forward(batch)
        logits = self.decoder.forward(batch, encoded)
        loss, acc = self._tokenwise_loss_acc(logits, batch)
        preds = self.decoder.predict(
            encoded_batch=encoded, decoding_strategy=self.decoding_strategy
        )
        generated = self.vocab.batch_decode(preds.T.tolist(), skip_special_tokens=True)
        return DecodedBatch(
            loss=loss.item(),
            acc=acc,
            generated=generated,
            descriptions=batch.descriptions,
            titles=batch.titles,
        )

    def validation_step(self, batch, batch_idx):
        return self._shared_eval_step(batch, batch_idx)

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        return self._shared_eval_step(batch, batch_idx)

    def _shared_epoch_end(self, outputs: List[DecodedBatch], prefix):
        loss = np.mean([o.loss for o in outputs])
        acc = np.mean([o.acc for o in outputs])
        self.log(f"{prefix}/loss", loss)
        self.log(f"{prefix}/acc", acc)

        generated = [g for o in outputs for g in o.generated]
        references = [r for o in outputs for r in o.descriptions]
        titles = [r for o in outputs for r in o.titles]

        # Log a small sample of generated/reference pairs as a W&B table.
        columns = ["Generated", "Reference"]
        data = list(zip(generated[:256:16], references[:256:16]))
        table = wandb.Table(data=data, columns=columns)
        self.logger.experiment.log({f"examples/{prefix}": table})

    def validation_epoch_end(self, outputs):
        self._shared_epoch_end(outputs, "val")

    def test_epoch_end(self, outputs):
        self._shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.98))
        # scheduler = WarmupDecayLR(optimizer, warmup_steps=10000, d_model=self.d_model)
        return [optimizer]
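

# Minimal usage sketch, not part of the original module: it shows how this
# LightningModule would typically be wired into a pl.Trainer. The hyperparameter
# values below and the `KobeDataModule` name are assumptions for illustration
# only; the attribute names themselves match what __init__ reads from `args`.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        text_vocab_size=30522,          # assumed vocab sizes
        cond_vocab_size=32,
        max_seq_len=256,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dropout=0.1,
        mode="baseline",                # assumed encoder mode name
        lr=1e-4,
        decoding_strategy="greedy",     # assumed decoding strategy name
        text_vocab_path="bert-base-chinese",  # assumed tokenizer path/name
    )
    model = KobeModel(args)
    trainer = pl.Trainer(max_epochs=10, accelerator="auto")
    # trainer.fit(model, datamodule=KobeDataModule(...))  # hypothetical datamodule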