import math
from typing import Tuple

import torch
import torch.nn as nn
from torch.nn.modules.transformer import (
    TransformerDecoder,
    TransformerDecoderLayer,
    TransformerEncoder,
    TransformerEncoderLayer,
)

from dataset import Batched, EncodedBatch
from vocab import BOS_ID, EOS_ID, PAD_ID
import helpers


class PositionalEncoding(nn.Module):
    def __init__(self, dropout, dim, max_len=5000):
        """
        Sinusoidal positional encoding, precomputed for up to max_len positions.
        :param dropout: dropout probability
        :param dim: hidden size
        :param max_len: maximum supported sequence length
        """
        super(PositionalEncoding, self).__init__()
        # precompute the positional encoding table [max_len, dim]
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        # frequency term for the sinusoids
        div_term = torch.exp(
            torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim)
        )
        # interleave sine (even indices) and cosine (odd indices)
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(1)  # [max_len, 1, dim] so it broadcasts over the batch
        self.register_buffer("pe", pe)
        self.dropout = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, emb):
        """
        Add positional encodings to a batch of word embeddings.
        :param emb: word embeddings of shape [len, batch, dim]
        :return: scaled embeddings plus positional encodings, after dropout
        """
        emb = emb * math.sqrt(self.dim)
        emb = emb + self.pe[: emb.size(0)]  # [len, batch, dim]
        emb = self.dropout(emb)
        return emb
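

# Illustrative usage sketch (not invoked by the model code): shows the
# [len, batch, dim] convention PositionalEncoding expects. The sizes below are
# arbitrary assumptions chosen for the example, not values from the project.
def _positional_encoding_demo() -> torch.Tensor:
    pos_enc = PositionalEncoding(dropout=0.1, dim=16)
    dummy = torch.zeros(7, 2, 16)  # [len=7, batch=2, dim=16]
    out = pos_enc(dummy)  # same shape, with the sin/cos positional signal added
    assert out.shape == (7, 2, 16)
    return out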


class Encoder(nn.Module):
    @staticmethod
    def from_args(args) -> "Encoder":
        return Encoder(
            args.text_vocab_size + args.cond_vocab_size,
            args.max_seq_len,
            args.d_model,
            args.nhead,
            args.num_encoder_layers,
            args.dropout,
            args.mode,
        )

    def __init__(
        self,
        vocab_size: int,
        max_seq_len: int,
        d_model: int,
        nhead: int,
        num_layers: int,
        dropout: float,
        mode: str,
    ):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.input_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(dropout, d_model)
        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, d_model * 4, dropout, norm_first=True
        )
        self.encoder = TransformerEncoder(
            encoder_layer, num_layers, nn.LayerNorm(d_model)
        )
        self.mode = mode

    def device(self):
        return next(self.parameters()).device

    def forward(self, batched: Batched) -> EncodedBatch:
        # select the input fields that match the configured mode, then embed and encode
        src, src_key_padding_mask = Encoder._get_input(batched, self.mode)
        src = self.input_embedding(src)
        src = self.pos_encoder(src)
        token_encodings = self.encoder(
            src=src, src_key_padding_mask=src_key_padding_mask
        )
        return EncodedBatch(
            context_encodings=token_encodings,
            context_encodings_mask=src_key_padding_mask,
        )

    @staticmethod
    def _get_input(batched: Batched, mode: str) -> Tuple[torch.Tensor, torch.Tensor]:
        # each mode selects a different (token_ids, padding_mask) pair from the batch
        return {
            helpers.BASELINE: (batched.title_token_ids, batched.title_token_ids_mask),
            helpers.KOBE_ATTRIBUTE: (
                batched.cond_title_token_ids,
                batched.cond_title_token_ids_mask,
            ),
            helpers.KOBE_KNOWLEDGE: (
                batched.title_fact_token_ids,
                batched.title_fact_token_ids_mask,
            ),
            helpers.KOBE_FULL: (
                batched.cond_title_fact_token_ids,
                batched.cond_title_fact_token_ids_mask,
            ),
        }[mode]
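

# Illustrative sketch (not invoked by the model code): the kind of argparse-style
# namespace Encoder.from_args expects. The hyperparameter values here are
# arbitrary assumptions, not the project's real defaults.
def _encoder_from_args_demo() -> "Encoder":
    from argparse import Namespace

    args = Namespace(
        text_vocab_size=32000,
        cond_vocab_size=32,
        max_seq_len=256,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        dropout=0.1,
        mode=helpers.BASELINE,  # one of the modes handled in Encoder._get_input
    )
    return Encoder.from_args(args)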


class Decoder(nn.Module):
    @staticmethod
    def from_args(args) -> "Decoder":
        return Decoder(
            args.text_vocab_size,
            args.max_seq_len,
            args.d_model,
            args.nhead,
            args.num_encoder_layers,  # the decoder depth mirrors the encoder's layer count
            args.dropout,
        )

    def __init__(
        self,
        vocab_size: int,
        max_seq_len: int,
        d_model: int,
        nhead: int,
        num_layers: int,
        dropout: float,
    ):
        super(Decoder, self).__init__()
        self.max_seq_len = max_seq_len
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(dropout, d_model)
        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, 4 * d_model, dropout, norm_first=True
        )
        self.decoder = TransformerDecoder(
            decoder_layer, num_layers, nn.LayerNorm(d_model)
        )
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, batch: Batched, encoded_batch: EncodedBatch) -> torch.Tensor:
        # teacher forcing: feed the description shifted right (drop the last token)
        tgt = self.embedding(batch.description_token_ids[:-1])
        tgt = self.pos_encoder(tgt)
        tgt_mask = Decoder.generate_square_subsequent_mask(tgt.shape[0], tgt.device)
        outputs = self.decoder(
            tgt=tgt,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=batch.description_token_ids_mask[:, :-1],
            memory=encoded_batch.context_encodings,
            memory_key_padding_mask=encoded_batch.context_encodings_mask,
        )
        return self.output(outputs)
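
    # Note (annotation, not from the original source): with the shifted input above,
    # the usual training target would be batch.description_token_ids[1:], so that
    # the output at position t predicts token t + 1 of the description.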
    def predict(self, encoded_batch: EncodedBatch, decoding_strategy: str):
        batch_size = encoded_batch.context_encodings.shape[1]
        # start every sequence with the BOS token; tgt has shape [1, batch]
        tgt = torch.tensor(
            [BOS_ID] * batch_size, device=encoded_batch.context_encodings.device
        ).unsqueeze(dim=0)
        tgt_mask = Decoder.generate_square_subsequent_mask(self.max_seq_len, tgt.device)
        pred_all = []
        for idx in range(self.max_seq_len):
            tgt_emb = self.pos_encoder(self.embedding(tgt))
            outputs = self.decoder(
                tgt_emb,
                tgt_mask=tgt_mask[: idx + 1, : idx + 1],
                memory=encoded_batch.context_encodings,
                memory_key_padding_mask=encoded_batch.context_encodings_mask,
            )
            # only the last position is needed to predict the next token
            logits = self.output(outputs[-1])
            if decoding_strategy == "greedy":
                pred_step = logits.argmax(dim=1).tolist()
            elif decoding_strategy == "nucleus":
                pred_step = [
                    helpers.top_k_top_p_sampling(logits[i], top_p=0.95)
                    for i in range(batch_size)
                ]
            else:
                raise NotImplementedError
            # once a sequence has emitted EOS (or PAD), keep padding it
            for b in range(batch_size):
                if pred_all and pred_all[-1][b].item() in [EOS_ID, PAD_ID]:
                    pred_step[b] = PAD_ID
            # stop early when every sequence in the batch has finished
            if all(pred == PAD_ID for pred in pred_step):
                break
            pred_step = torch.tensor(pred_step, device=tgt.device)
            pred_all.append(pred_step)
            if idx < self.max_seq_len - 1:
                tgt_step = pred_step.unsqueeze(dim=0)
                tgt = torch.cat([tgt, tgt_step], dim=0)
        preds = torch.stack(pred_all)
        return preds
    @staticmethod
    def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
        r"""
        Generate a square causal mask for the sequence. Masked positions are filled
        with float('-inf'); unmasked positions are filled with float(0.0).
        """
        return torch.triu(
            torch.full((sz, sz), float("-inf"), device=device), diagonal=1
        )
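

# Illustrative sketch (not invoked by the model code): the causal mask for sz=4 is
# an upper-triangular matrix of -inf above the diagonal, so position t may only
# attend to positions <= t.
def _subsequent_mask_demo() -> torch.Tensor:
    mask = Decoder.generate_square_subsequent_mask(4, torch.device("cpu"))
    # tensor([[0., -inf, -inf, -inf],
    #         [0.,   0., -inf, -inf],
    #         [0.,   0.,   0., -inf],
    #         [0.,   0.,   0.,   0.]])
    return mask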