import sys

import torch
from torch import nn


class WordRepresentation(nn.Module):
    '''
    RNN for computing character-based word representations
    '''
    def __init__(self, num_chars, emb_size, rec_size, dropout_rate):
        super().__init__()

        # character embedding lookup table
        self.embeddings = nn.Embedding(num_chars, emb_size)

        # character-based LSTMs
        self.fwd_rnn = nn.LSTM(emb_size, rec_size)
        self.bwd_rnn = nn.LSTM(emb_size, rec_size)

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, fwd_charIDs, bwd_charIDs):
        # swap the 2 dimensions and look up the embeddings
        fwd_embs = self.embeddings(fwd_charIDs.t())
        bwd_embs = self.embeddings(bwd_charIDs.t())

        # run the BiLSTM over characters
        fwd_outputs, _ = self.fwd_rnn(fwd_embs)
        bwd_outputs, _ = self.bwd_rnn(bwd_embs)

        # concatenate the forward and backward final states to form
        # word representations
        word_reprs = torch.cat((fwd_outputs[-1], bwd_outputs[-1]), -1)
        return word_reprs
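
# Shape sketch for WordRepresentation (illustrative note, not in the original
# code): for character ID tensors of shape (num_words, max_word_len), the
# transposed embeddings are (max_word_len, num_words, emb_size), each LSTM
# output is (max_word_len, num_words, rec_size), and taking the last time step
# of both directions gives word representations of shape (num_words, 2*rec_size).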


class ResidualLSTM(nn.Module):
    ''' Deep BiRNN with residual connections '''
    def __init__(self, input_size, rec_size, num_rnns, dropout_rate):
        super().__init__()

        self.rnn = nn.LSTM(input_size, rec_size,
                           bidirectional=True, batch_first=True)
        self.deep_rnns = nn.ModuleList([
            nn.LSTM(2*rec_size, rec_size, bidirectional=True, batch_first=True)
            for _ in range(num_rnns-1)])

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, state):
        state, _ = self.rnn(state)
        for rnn in self.deep_rnns:
            hidden, _ = rnn(self.dropout(state))
            state = state + hidden  # residual connection
        return state
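
# Note (added for clarity, not in the original code): because the LSTMs above
# use batch_first=True, ResidualLSTM expects input of shape
# (batch, seq_len, input_size) and returns (batch, seq_len, 2*rec_size); each
# deeper LSTM preserves that shape, which is what makes the residual addition
# state + hidden valid.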


class RNNTagger(nn.Module):
    ''' main tagger module '''
    def __init__(self, num_chars, num_tags, char_emb_size, char_rec_size,
                 word_rec_size, word_rnn_depth, dropout_rate, word_emb_size):
        super().__init__()

        # character-based BiLSTMs
        self.word_representations = WordRepresentation(num_chars, char_emb_size,
                                                       char_rec_size, dropout_rate)
        # word-based BiLSTM
        self.word_rnn = ResidualLSTM(char_rec_size*2, word_rec_size, word_rnn_depth,
                                     dropout_rate)
        # output feed-forward layer
        self.output_layer = nn.Linear(2*word_rec_size, num_tags)

        # dropout layer
        self.dropout = nn.Dropout(dropout_rate)

        # word embedding projection layer for fine-tuning on word embeddings
        if word_emb_size > 0:
            self.projection_layer = nn.Linear(2*char_rec_size, word_emb_size)

    def forward(self, fwd_charIDs, bwd_charIDs, word_embedding_training=False):
        # compute the character-based word representations
        word_reprs = self.word_representations(fwd_charIDs, bwd_charIDs)

        if word_embedding_training:
            if not hasattr(self, 'projection_layer'):
                sys.exit("Error: The embedding projection layer is undefined!")
            # project the word representations to word embedding vectors
            # for fine-tuning on word embeddings as an auxiliary task
            word_embs = self.projection_layer(word_reprs)
            return word_embs

        # apply dropout
        word_reprs = self.dropout(word_reprs)

        # run the BiLSTM over words
        reprs = self.word_rnn(word_reprs.unsqueeze(0)).squeeze(0)
        reprs = self.dropout(reprs)  # and apply dropout

        # apply the output layer
        scores = self.output_layer(reprs)
        return scores
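

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The hyperparameter
# values, sentence length, and word length below are illustrative assumptions
# chosen only to show how the tagger is called and what shape it returns.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    num_chars, num_tags = 100, 20   # assumed character and tag inventory sizes

    tagger = RNNTagger(num_chars=num_chars, num_tags=num_tags,
                       char_emb_size=32, char_rec_size=64,
                       word_rec_size=128, word_rnn_depth=2,
                       dropout_rate=0.5, word_emb_size=0)
    tagger.eval()   # disable dropout for the demonstration

    # a "sentence" of 7 words, each padded/truncated to 10 character IDs;
    # the backward tensor would normally hold the reversed character sequences
    fwd_charIDs = torch.randint(0, num_chars, (7, 10))
    bwd_charIDs = torch.randint(0, num_chars, (7, 10))

    scores = tagger(fwd_charIDs, bwd_charIDs)
    print(scores.shape)   # expected: torch.Size([7, 20]), one tag score vector per word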