import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence


class CharEmbedding(nn.Module):
    """Character-level word embeddings: characters are embedded, run through a
    bidirectional GRU per word, and the final hidden states are projected to
    `output_size`."""

    def __init__(self, vocab_size: int, embedding_size: int, output_size: int):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_size, sparse=False)
        self.layer_norm = nn.LayerNorm(embedding_size)
        self.gru = nn.GRU(embedding_size, embedding_size, num_layers=1, bidirectional=True)
        self.out_linear = nn.Linear(2 * embedding_size, output_size)
        self.layer_norm_2 = nn.LayerNorm(output_size)

    def forward(self, words, sentence_lens, word_lens):
        # words: [batch, max_sentence_len, max_word_len] character indices,
        # sentence_lens: [batch] words per sentence,
        # word_lens: [batch, max_sentence_len] characters per word.
        n_words = words.size(1)
        sentence_lens = sentence_lens.cpu()

        # Pack away sentence padding, leaving one row of character indices per real word.
        # pack_padded_sequence defaults to enforce_sorted=True here, so sentences must
        # arrive sorted by decreasing length.
        sentence_packed = pack_padded_sequence(words, sentence_lens, batch_first=True)
        lens_packed = pack_padded_sequence(word_lens, sentence_lens, batch_first=True)
        # Pack away character padding inside each word.
        word_packed = pack_padded_sequence(sentence_packed.data, lens_packed.data.cpu(), batch_first=True, enforce_sorted=False)

        # Embed and normalize the flat character sequence.
        embedded = self.embedding(word_packed.data)
        embedded = self.layer_norm(embedded)

        # Rewrap the embedded characters with the original packing metadata and run the GRU.
        embedded_packed = PackedSequence(embedded, word_packed.batch_sizes, word_packed.sorted_indices, word_packed.unsorted_indices)
        _, embedded = self.gru(embedded_packed)

        # Concatenate the final forward and backward hidden states of each word and project them.
        embedded = embedded[-2:, :, :].transpose(0, 1).flatten(1, 2)
        embedded = F.relu(embedded)
        embedded = self.out_linear(embedded)
        embedded = self.layer_norm_2(embedded)

        # Unpack back to [batch, max_sentence_len, output_size].
        embedded, _ = pad_packed_sequence(
            PackedSequence(embedded, sentence_packed.batch_sizes, sentence_packed.sorted_indices, sentence_packed.unsorted_indices),
            batch_first=True, total_length=n_words,
        )

        return embedded
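

# A minimal usage sketch, not part of the original module. The tensor shapes, the
# zero-padding convention, and the requirement that sentences arrive sorted by
# decreasing length are assumptions inferred from how pack_padded_sequence is used above.
if __name__ == "__main__":
    import torch

    model = CharEmbedding(vocab_size=100, embedding_size=32, output_size=64)

    batch_size, max_sentence_len, max_word_len = 2, 4, 6
    # Character indices per word, padded along both the word and character dimensions.
    words = torch.randint(1, 100, (batch_size, max_sentence_len, max_word_len))
    sentence_lens = torch.tensor([4, 3])  # words per sentence, sorted descending
    word_lens = torch.randint(1, max_word_len + 1, (batch_size, max_sentence_len))

    out = model(words, sentence_lens, word_lens)
    print(out.shape)  # expected: torch.Size([2, 4, 64])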