import torch
from torch import Tensor, nn
from functools import partial
from .components import (
    ImgCnnBackbone,
    ImgLinearBackbone,
    ImgConvStemBackbone,
    Encoder,
    Decoder,
    PositionEmbedding,
    TokenEmbedding,
)


class EncoderDecoder(nn.Module):
    """Encoder-decoder architecture that takes a tabular image and generates a text sequence.

    The backbone serves as the image processor; there are three backbone
    variants: CNN, linear projection, and ConvStem.

    Args:
    ----
        backbone: tabular image processor
        encoder: transformer encoder
        decoder: transformer decoder
        vocab_size: size of the vocabulary
        d_model: feature dimension
        padding_idx: index of <pad> in the vocabulary
        max_seq_len: maximum sequence length of the generated text
        dropout: dropout rate
        norm_layer: normalization layer (e.g., nn.LayerNorm)
        init_std: std used in truncated-normal weight initialization
    """

    def __init__(
        self,
        backbone: nn.Module,
        encoder: nn.Module,
        decoder: nn.Module,
        vocab_size: int,
        d_model: int,
        padding_idx: int,
        max_seq_len: int,
        dropout: float,
        norm_layer: nn.Module,
        init_std: float = 0.02,
    ):
        super().__init__()

        self.backbone = backbone
        self.encoder = encoder
        self.decoder = decoder
        self.norm = norm_layer(d_model)
        self.token_embed = TokenEmbedding(
            vocab_size=vocab_size, d_model=d_model, padding_idx=padding_idx
        )
        self.pos_embed = PositionEmbedding(
            max_seq_len=max_seq_len, d_model=d_model, dropout=dropout
        )
        self.generator = nn.Linear(d_model, vocab_size)

        # Truncated-normal initializer bounded to [-init_std, init_std],
        # shared across all weight types in _init_weights.
        self.trunc_normal = partial(
            nn.init.trunc_normal_, std=init_std, a=-init_std, b=init_std
        )
        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module):
        if isinstance(m, nn.Linear):
            self.trunc_normal(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.Conv2d):
            self.trunc_normal(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, PositionEmbedding):
            self.trunc_normal(m.embedding.weight)
        elif isinstance(m, TokenEmbedding):
            self.trunc_normal(m.embedding.weight)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter groups to exclude from weight decay (embedding tables).
        return {"token_embed", "pos_embed"}

    def encode(self, src: Tensor) -> Tensor:
        # Image -> backbone features -> positional encoding -> encoder memory.
        src_feature = self.backbone(src)
        src_feature = self.pos_embed(src_feature)
        memory = self.encoder(src_feature)
        memory = self.norm(memory)
        return memory

    def decode(
        self, memory: Tensor, tgt: Tensor, tgt_mask: Tensor, tgt_padding_mask: Tensor
    ) -> Tensor:
        # Embed target tokens, add positions, then cross-attend to the memory.
        tgt_feature = self.pos_embed(self.token_embed(tgt))
        tgt = self.decoder(tgt_feature, memory, tgt_mask, tgt_padding_mask)
        return tgt

    def forward(
        self, src: Tensor, tgt: Tensor, tgt_mask: Tensor, tgt_padding_mask: Tensor
    ) -> Tensor:
        memory = self.encode(src)
        tgt = self.decode(memory, tgt, tgt_mask, tgt_padding_mask)
        tgt = self.generator(tgt)  # project decoder features to vocabulary logits
        return tgt
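

def greedy_decode(
    model: EncoderDecoder,
    src: Tensor,
    start_idx: int,
    eos_idx: int,
    max_seq_len: int,
) -> Tensor:
    """Greedy autoregressive decoding sketch (not part of the original module).

    Illustrates the inference pattern the class is built for: encode the image
    once, then repeatedly call decode() on the growing token sequence and take
    the argmax of the last step's logits. ``start_idx`` and ``eos_idx`` are
    hypothetical special-token indices, and the mask convention assumes the
    decoder accepts a float additive causal mask; the project's actual
    decoding loop may differ.
    """
    model.eval()
    with torch.no_grad():
        memory = model.encode(src)
        tokens = torch.full(
            (src.shape[0], 1), start_idx, dtype=torch.long, device=src.device
        )
        for _ in range(max_seq_len - 1):
            n = tokens.shape[1]
            # Causal mask: position i may only attend to positions <= i.
            causal_mask = torch.triu(
                torch.full((n, n), float("-inf"), device=src.device), diagonal=1
            )
            out = model.decode(memory, tokens, causal_mask, tgt_padding_mask=None)
            next_token = model.generator(out[:, -1]).argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=1)
            if (next_token == eos_idx).all():
                break
    return tokens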
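

if __name__ == "__main__":
    # Minimal smoke test (illustrative only). The stand-in modules below are
    # hypothetical: they only mirror the call signatures EncoderDecoder relies
    # on, while the real ImgLinearBackbone / Encoder / Decoder live in
    # .components and are configured elsewhere. PositionEmbedding and
    # TokenEmbedding are still assumed importable, so run this in the package.
    d_model, vocab_size, max_seq_len = 64, 100, 32

    class ToyBackbone(nn.Module):
        """Patchify a 3x32x32 image into 16 tokens of d_model features."""

        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(3 * 8 * 8, d_model)

        def forward(self, x: Tensor) -> Tensor:
            b = x.shape[0]
            x = x.unfold(2, 8, 8).unfold(3, 8, 8)  # [b, 3, 4, 4, 8, 8]
            x = x.permute(0, 2, 3, 1, 4, 5).reshape(b, 16, -1)
            return self.proj(x)  # [b, 16, d_model]

    class ToyEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
            self.inner = nn.TransformerEncoder(layer, num_layers=1)

        def forward(self, x: Tensor) -> Tensor:
            return self.inner(x)

    class ToyDecoder(nn.Module):
        """Accepts the (tgt, memory, tgt_mask, tgt_padding_mask) call in decode()."""

        def __init__(self):
            super().__init__()
            layer = nn.TransformerDecoderLayer(d_model, nhead=4, batch_first=True)
            self.inner = nn.TransformerDecoder(layer, num_layers=1)

        def forward(self, tgt, memory, tgt_mask, tgt_padding_mask):
            return self.inner(
                tgt, memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_padding_mask
            )

    model = EncoderDecoder(
        backbone=ToyBackbone(),
        encoder=ToyEncoder(),
        decoder=ToyDecoder(),
        vocab_size=vocab_size,
        d_model=d_model,
        padding_idx=0,
        max_seq_len=max_seq_len,
        dropout=0.1,
        norm_layer=nn.LayerNorm,
    )
    src = torch.randn(2, 3, 32, 32)
    tgt = torch.randint(1, vocab_size, (2, 10))
    # Boolean causal mask: True marks positions that may NOT be attended to.
    tgt_mask = torch.triu(torch.ones(10, 10, dtype=torch.bool), diagonal=1)
    logits = model(src, tgt, tgt_mask, tgt_padding_mask=(tgt == 0))
    print(logits.shape)  # expected: torch.Size([2, 10, 100])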