import math
from typing import Tuple
import torch
import torch.nn as nn
from cached_property import cached_property
from torch.nn.modules.transformer import (
TransformerDecoder,
TransformerDecoderLayer,
TransformerEncoder,
TransformerEncoderLayer,
)
from dataset import Batched, EncodedBatch
from vocab import BOS_ID, EOS_ID, PAD_ID
import helpers
class PositionalEncoding(nn.Module):
def __init__(self, dropout, dim, max_len=5000):
"""
initialization of required variables and functions
:param dropout: dropout probability
:param dim: hidden size
:param max_len: maximum length
"""
        super().__init__()
# positional encoding initialization
pe = torch.zeros(max_len, dim)
position = torch.arange(0, max_len).unsqueeze(1)
        # inverse frequencies 1 / 10000^(2i / dim) used by the sin/cos terms
div_term = torch.exp(
(torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
)
# sinusoidal positional encoding
pe[:, 0::2] = torch.sin(position.float() * div_term)
pe[:, 1::2] = torch.cos(position.float() * div_term)
pe = pe.unsqueeze(1)
self.register_buffer("pe", pe)
self.dropout = nn.Dropout(p=dropout)
self.dim = dim
def forward(self, emb):
"""
create positional encoding
:param emb: word embedding
:param step: step for decoding in inference
:return: positional encoding representation
"""
emb *= math.sqrt(self.dim)
emb = emb + self.pe[: emb.size(0)] # [len, batch, size]
emb = self.dropout(emb)
return emb
class Encoder(nn.Module):
@staticmethod
def from_args(args) -> "Encoder":
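        # the encoder sees text tokens and condition (attribute) tokens in one sequence,
        # so its embedding table spans both vocabularies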
return Encoder(
args.text_vocab_size + args.cond_vocab_size,
args.max_seq_len,
args.d_model,
args.nhead,
args.num_encoder_layers,
args.dropout,
args.mode,
)
def __init__(
self,
vocab_size: int,
max_seq_len: int,
d_model: int,
nhead: int,
num_layers: int,
dropout: float,
mode: str,
):
super().__init__()
self.d_model = d_model
self.max_seq_len = max_seq_len
self.input_embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoder = PositionalEncoding(dropout, d_model)
encoder_layer = TransformerEncoderLayer(
d_model, nhead, d_model * 4, dropout, norm_first=True
)
self.encoder = TransformerEncoder(
encoder_layer, num_layers, nn.LayerNorm(d_model)
)
self.mode = mode
@cached_property
def device(self):
        return next(self.parameters()).device
def forward(self, batched: Batched) -> EncodedBatch:
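        # pick the (token ids, padding mask) pair that matches the configured input mode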
src, src_key_padding_mask = Encoder._get_input(batched, self.mode)
src = self.input_embedding(src)
src = self.pos_encoder(src)
token_encodings = self.encoder.forward(
src=src, src_key_padding_mask=src_key_padding_mask
)
return EncodedBatch(
context_encodings=token_encodings,
context_encodings_mask=src_key_padding_mask,
)
@staticmethod
def _get_input(batched: Batched, mode: str) -> Tuple[torch.Tensor, torch.Tensor]:
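        # each mode selects a different pre-concatenated input from the batch:
        # plain title, condition + title, title + facts, or condition + title + facts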
return {
helpers.BASELINE: (batched.title_token_ids, batched.title_token_ids_mask),
helpers.KOBE_ATTRIBUTE: (
batched.cond_title_token_ids,
batched.cond_title_token_ids_mask,
),
helpers.KOBE_KNOWLEDGE: (
batched.title_fact_token_ids,
batched.title_fact_token_ids_mask,
),
helpers.KOBE_FULL: (
batched.cond_title_fact_token_ids,
batched.cond_title_fact_token_ids_mask,
),
}[mode]
class Decoder(nn.Module):
@staticmethod
def from_args(args) -> "Decoder":
return Decoder(
args.text_vocab_size,
args.max_seq_len,
args.d_model,
args.nhead,
args.num_encoder_layers,
args.dropout,
)
def __init__(
self,
vocab_size: int,
max_seq_len: int,
d_model: int,
nhead: int,
num_layers: int,
dropout: float,
):
        super().__init__()
self.max_seq_len = max_seq_len
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoder = PositionalEncoding(dropout, d_model)
decoder_layer = TransformerDecoderLayer(
d_model, nhead, 4 * d_model, dropout, norm_first=True
)
self.decoder = TransformerDecoder(
decoder_layer, num_layers, nn.LayerNorm(d_model)
)
self.output = nn.Linear(d_model, vocab_size)
def forward(self, batch: Batched, encoded_batch: EncodedBatch) -> torch.Tensor:
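        # teacher forcing: feed the gold description shifted right (drop the last token)
        # and predict the next token at every position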
tgt = self.embedding(batch.description_token_ids[:-1])
tgt = self.pos_encoder(tgt)
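        # causal mask: each target position may only attend to itself and earlier positions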
tgt_mask = Decoder.generate_square_subsequent_mask(tgt.shape[0], tgt.device)
outputs = self.decoder(
tgt=tgt,
tgt_mask=tgt_mask,
tgt_key_padding_mask=batch.description_token_ids_mask[:, :-1],
memory=encoded_batch.context_encodings,
memory_key_padding_mask=encoded_batch.context_encodings_mask,
)
return self.output(outputs)
def predict(self, encoded_batch: EncodedBatch, decoding_strategy: str):
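        # autoregressive decoding: start every sequence with BOS, then repeatedly run the
        # decoder over the generated prefix and append one token per step (greedy argmax
        # or nucleus sampling), up to max_seq_len tokens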
batch_size = encoded_batch.context_encodings.shape[1]
tgt = torch.tensor(
[BOS_ID] * batch_size, device=encoded_batch.context_encodings.device
).unsqueeze(dim=0)
tgt_mask = Decoder.generate_square_subsequent_mask(self.max_seq_len, tgt.device)
pred_all = []
for idx in range(self.max_seq_len):
tgt_emb = self.pos_encoder(self.embedding(tgt))
outputs = self.decoder(
tgt_emb,
tgt_mask=tgt_mask[: idx + 1, : idx + 1],
memory=encoded_batch.context_encodings,
memory_key_padding_mask=encoded_batch.context_encodings_mask,
)
logits = self.output(outputs[-1])
if decoding_strategy == "greedy":
pred_step = logits.argmax(dim=1).tolist()
elif decoding_strategy == "nucleus":
pred_step = [
helpers.top_k_top_p_sampling(logits[i], top_p=0.95)
for i in range(batch_size)
]
else:
raise NotImplementedError
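            # once a sequence has produced EOS (or PAD), keep padding it; stop early when
            # every sequence in the batch is finished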
for b in range(batch_size):
if pred_all and pred_all[-1][b].item() in [EOS_ID, PAD_ID]:
pred_step[b] = PAD_ID
if all([pred == PAD_ID for pred in pred_step]):
break
pred_step = torch.tensor(pred_step, device=tgt.device)
pred_all.append(pred_step)
if idx < self.max_seq_len - 1:
tgt_step = pred_step.unsqueeze(dim=0)
tgt = torch.cat([tgt, tgt_step], dim=0)
preds = torch.stack(pred_all)
return preds
@staticmethod
def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
r"""
Generate a square mask for the sequence. The masked positions are filled with
float('-inf').
Unmasked positions are filled with float(0.0).
"""
return torch.triu(
torch.full((sz, sz), float("-inf"), device=device), diagonal=1
)
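

# Example usage (a minimal sketch; assumes an `args` namespace carrying the fields read in
# `from_args` and a `Batched` object produced by the dataset pipeline):
#
#   encoder = Encoder.from_args(args)
#   decoder = Decoder.from_args(args)
#   encoded = encoder(batch)                    # EncodedBatch of contextual token encodings
#   logits = decoder(batch, encoded)            # [tgt_len - 1, batch, text_vocab_size]
#   preds = decoder.predict(encoded, "greedy")  # generated token ids, [steps, batch]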