# mhg-parsing / MHGTagger / RNNTagger.py
import sys
import torch
from torch import nn


class WordRepresentation(nn.Module):
    '''
    RNN for computing character-based word representations
    '''

    def __init__(self, num_chars, emb_size, rec_size, dropout_rate):
        super().__init__()
        # character embedding lookup table
        self.embeddings = nn.Embedding(num_chars, emb_size)
        # character-based LSTMs
        self.fwd_rnn = nn.LSTM(emb_size, rec_size)
        self.bwd_rnn = nn.LSTM(emb_size, rec_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, fwd_charIDs, bwd_charIDs):
        # swap the 2 dimensions and lookup the embeddings
        fwd_embs = self.embeddings(fwd_charIDs.t())
        bwd_embs = self.embeddings(bwd_charIDs.t())
        # run the biLSTM over characters
        fwd_outputs, _ = self.fwd_rnn(fwd_embs)
        bwd_outputs, _ = self.bwd_rnn(bwd_embs)
        # concatenate the forward and backward final states to form
        # word representations
        word_reprs = torch.cat((fwd_outputs[-1], bwd_outputs[-1]), -1)
        return word_reprs


class ResidualLSTM(nn.Module):
    ''' Deep BiRNN with residual connections '''

    def __init__(self, input_size, rec_size, num_rnns, dropout_rate):
        super().__init__()
        self.rnn = nn.LSTM(input_size, rec_size,
                           bidirectional=True, batch_first=True)
        self.deep_rnns = nn.ModuleList([
            nn.LSTM(2 * rec_size, rec_size, bidirectional=True, batch_first=True)
            for _ in range(num_rnns - 1)])
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, state):
        state, _ = self.rnn(state)
        for rnn in self.deep_rnns:
            hidden, _ = rnn(self.dropout(state))
            state = state + hidden  # residual connection
        return state


class RNNTagger(nn.Module):
    ''' main tagger module '''

    def __init__(self, num_chars, num_tags, char_emb_size, char_rec_size,
                 word_rec_size, word_rnn_depth, dropout_rate, word_emb_size):
        super().__init__()
        # character-based BiLSTMs
        self.word_representations = WordRepresentation(num_chars, char_emb_size,
                                                       char_rec_size, dropout_rate)
        # word-based BiLSTM
        self.word_rnn = ResidualLSTM(char_rec_size * 2, word_rec_size, word_rnn_depth,
                                     dropout_rate)
        # output feed-forward network
        self.output_layer = nn.Linear(2 * word_rec_size, num_tags)
        # dropout layers
        self.dropout = nn.Dropout(dropout_rate)
        # word embedding projection layer for finetuning on word embeddings
        if word_emb_size > 0:
            self.projection_layer = nn.Linear(2 * char_rec_size, word_emb_size)

    def forward(self, fwd_charIDs, bwd_charIDs, word_embedding_training=False):
        # compute the character-based word representations
        word_reprs = self.word_representations(fwd_charIDs, bwd_charIDs)

        if word_embedding_training:
            if not hasattr(self, 'projection_layer'):
                sys.exit("Error: The embedding projection layer is undefined!")
            # Project the word representations to word embedding vectors
            # for finetuning on word embeddings as an auxiliary task
            word_embs = self.projection_layer(word_reprs)
            return word_embs

        # apply dropout
        word_reprs = self.dropout(word_reprs)
        # run the BiLSTM over words
        reprs = self.word_rnn(word_reprs.unsqueeze(0)).squeeze(0)
        reprs = self.dropout(reprs)  # and apply dropout
        # apply the output layer
        scores = self.output_layer(reprs)
        return scores
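

# Minimal usage sketch (not part of the original file): the hyperparameter values,
# vocabulary sizes, and character ID tensors below are illustrative assumptions,
# chosen only to show the expected input and output shapes of the tagger.
if __name__ == "__main__":
    num_chars, num_tags = 100, 20   # assumed character and tag inventory sizes

    tagger = RNNTagger(num_chars=num_chars, num_tags=num_tags,
                       char_emb_size=32, char_rec_size=64,
                       word_rec_size=128, word_rnn_depth=2,
                       dropout_rate=0.1, word_emb_size=0)

    # a "sentence" of 5 words, each encoded as 12 character IDs; one tensor
    # reads the characters left-to-right, the other right-to-left
    fwd_charIDs = torch.randint(0, num_chars, (5, 12))
    bwd_charIDs = torch.flip(fwd_charIDs, dims=[1])

    scores = tagger(fwd_charIDs, bwd_charIDs)
    print(scores.shape)   # torch.Size([5, 20]): one tag score vector per word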