from typing import List, Tuple

import numpy as np
import pytorch_lightning as pl
import sentencepiece as spm
import torch
import torch.nn as nn
import wandb  # required: wandb.Table is logged in _shared_epoch_end
#from sacrebleu.metrics.bleu import BLEU, _get_tokenizer
from torch import optim
from torch.nn.init import xavier_uniform_
from transformers.models.bert.tokenization_bert import BertTokenizer

from dataset import Batched, DecodedBatch
#from models.scheduler import WarmupDecayLR
from transformer import Decoder, Encoder
#from kobe.utils import helpers


class KobeModel(pl.LightningModule):
    """Transformer encoder-decoder (KOBE) for description generation."""

    def __init__(self, args):
        super().__init__()
        self.encoder = Encoder(
            vocab_size=args.text_vocab_size + args.cond_vocab_size,
            max_seq_len=args.max_seq_len,
            d_model=args.d_model,
            nhead=args.nhead,
            num_layers=args.num_encoder_layers,
            dropout=args.dropout,
            mode=args.mode,
        )
        self.decoder = Decoder(
            vocab_size=args.text_vocab_size,
            max_seq_len=args.max_seq_len,
            d_model=args.d_model,
            nhead=args.nhead,
            num_layers=args.num_decoder_layers,
            dropout=args.dropout,
        )
        self.lr = args.lr
        self.d_model = args.d_model
        self.loss = nn.CrossEntropyLoss(
            reduction="mean", ignore_index=0, label_smoothing=0.1
        )
        self._reset_parameters()
        self.decoding_strategy = args.decoding_strategy
        self.vocab = BertTokenizer.from_pretrained(args.text_vocab_path)
        #self.bleu = BLEU(tokenize=args.tokenize)
        #self.sacre_tokenizer = _get_tokenizer(args.tokenize)()
        #self.bert_scorer = BERTScorer(lang=args.tokenize, rescale_with_baseline=True)

    def _reset_parameters(self):
        # Xavier-initialize all weight matrices; 1-D parameters (biases, norms) keep their defaults.
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)

    def _tokenwise_loss_acc(
        self, logits: torch.Tensor, batch: Batched
    ) -> Tuple[torch.Tensor, float]:
        # Token-level accuracy over non-padding positions
        # (inline replacement for the commented-out helpers.accuracy).
        unmask = ~batch.description_token_ids_mask.T[1:]
        unmasked_logits = logits[unmask]
        unmasked_targets = batch.description_token_ids[1:][unmask]
        acc = (unmasked_logits.argmax(dim=-1) == unmasked_targets).float().mean().item()
        loss = self.loss(logits.transpose(1, 2), batch.description_token_ids[1:])
        return loss, acc

    def training_step(self, batch: Batched, batch_idx: int):
        encoded = self.encoder.forward(batch)
        logits = self.decoder.forward(batch, encoded)
        loss, acc = self._tokenwise_loss_acc(logits, batch)
        # Step the warmup scheduler once per batch (see configure_optimizers).
        self.lr_schedulers().step()
        self.log("train/loss", loss.item())
        self.log("train/acc", acc)
        return loss

    def _shared_eval_step(self, batch: Batched, batch_idx: int) -> DecodedBatch:
        encoded = self.encoder.forward(batch)
        logits = self.decoder.forward(batch, encoded)
        loss, acc = self._tokenwise_loss_acc(logits, batch)
        # Decode with the configured strategy and detokenize the predictions.
        preds = self.decoder.predict(
            encoded_batch=encoded, decoding_strategy=self.decoding_strategy
        )
        generated = self.vocab.batch_decode(preds.T.tolist(), skip_special_tokens=True)
        return DecodedBatch(
            loss=loss.item(),
            acc=acc,
            generated=generated,
            descriptions=batch.descriptions,
            titles=batch.titles,
        )

    def validation_step(self, batch, batch_idx):
        return self._shared_eval_step(batch, batch_idx)

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        return self._shared_eval_step(batch, batch_idx)

    def _shared_epoch_end(self, outputs: List[DecodedBatch], prefix):
        loss = np.mean([o.loss for o in outputs])
        acc = np.mean([o.acc for o in outputs])
        self.log(f"{prefix}/loss", loss)
        self.log(f"{prefix}/acc", acc)
        generated = [g for o in outputs for g in o.generated]
        references = [r for o in outputs for r in o.descriptions]
        titles = [t for o in outputs for t in o.titles]
        # Log a small sample of generated/reference pairs to Weights & Biases.
        columns = ["Generated", "Reference"]
        data = list(zip(generated[:256:16], references[:256:16]))
        table = wandb.Table(data=data, columns=columns)
        self.logger.experiment.log({f"examples/{prefix}": table})

    def validation_epoch_end(self, outputs):
        self._shared_epoch_end(outputs, "val")

    def test_epoch_end(self, outputs):
        self._shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.98))
        # Noam-style inverse-sqrt warmup; stands in for the commented-out WarmupDecayLR.
        def noam(step: int, warmup_steps: int = 10000) -> float:
            step = max(step, 1)
            return self.d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, noam)
        return [optimizer], [scheduler]
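

# A minimal usage sketch (not part of the original module), assuming hypothetical
# hyperparameters: the field names mirror what KobeModel and Encoder/Decoder read
# above, but the concrete numbers, the mode/decoding_strategy strings, the vocab
# path, and the W&B project name are placeholders. DataLoaders from dataset.py
# are omitted, so the fit call is left commented out.
if __name__ == "__main__":
    from argparse import Namespace

    from pytorch_lightning.loggers import WandbLogger

    args = Namespace(
        text_vocab_size=21128,       # placeholder: size of the tokenizer vocab
        cond_vocab_size=32,          # placeholder: number of condition tokens
        max_seq_len=256,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dropout=0.1,
        mode="baseline",             # placeholder: whichever modes Encoder supports
        lr=1.0,
        decoding_strategy="greedy",  # placeholder: whichever strategies Decoder.predict supports
        text_vocab_path="bert-base-chinese",  # placeholder tokenizer name/path
    )
    model = KobeModel(args)
    # _shared_epoch_end logs a wandb.Table, so a WandbLogger is expected.
    trainer = pl.Trainer(max_epochs=10, logger=WandbLogger(project="kobe"))
    # trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)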