tobiasc's picture
Initial commit
ad16788
raw
history blame contribute delete
No virus
9.55 kB
#!/usr/bin/env python3
# Copyright 2018 Mitsubishi Electric Research Laboratories (Takaaki Hori)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from espnet.lm.lm_utils import make_lexical_tree
from espnet.nets.pytorch_backend.nets_utils import to_device
# Definition of a multi-level (subword/word) language model
class MultiLevelLM(nn.Module):
logzero = -10000000000.0
zero = 1.0e-10
def __init__(
self,
wordlm,
subwordlm,
word_dict,
subword_dict,
subwordlm_weight=0.8,
oov_penalty=1.0,
open_vocab=True,
):
super(MultiLevelLM, self).__init__()
self.wordlm = wordlm
self.subwordlm = subwordlm
self.word_eos = word_dict["<eos>"]
self.word_unk = word_dict["<unk>"]
self.var_word_eos = torch.LongTensor([self.word_eos])
self.var_word_unk = torch.LongTensor([self.word_unk])
self.space = subword_dict["<space>"]
self.eos = subword_dict["<eos>"]
self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
self.log_oov_penalty = math.log(oov_penalty)
self.open_vocab = open_vocab
self.subword_dict_size = len(subword_dict)
self.subwordlm_weight = subwordlm_weight
self.normalized = True
def forward(self, state, x):
# update state with input label x
if state is None: # make initial states and log-prob vectors
self.var_word_eos = to_device(x, self.var_word_eos)
self.var_word_unk = to_device(x, self.var_word_eos)
wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
wlm_logprobs = F.log_softmax(z_wlm, dim=1)
clm_state, z_clm = self.subwordlm(None, x)
log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight
new_node = self.lexroot
clm_logprob = 0.0
xi = self.space
else:
clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
xi = int(x)
if xi == self.space: # inter-word transition
if node is not None and node[1] >= 0: # check if the node is word end
w = to_device(x, torch.LongTensor([node[1]]))
else: # this node is not a word end, which means <unk>
w = self.var_word_unk
# update wordlm state and log-prob vector
wlm_state, z_wlm = self.wordlm(wlm_state, w)
wlm_logprobs = F.log_softmax(z_wlm, dim=1)
new_node = self.lexroot # move to the tree root
clm_logprob = 0.0
elif node is not None and xi in node[0]: # intra-word transition
new_node = node[0][xi]
clm_logprob += log_y[0, xi]
elif self.open_vocab: # if no path in the tree, enter open-vocabulary mode
new_node = None
clm_logprob += log_y[0, xi]
else: # if open_vocab flag is disabled, return 0 probabilities
log_y = to_device(
x, torch.full((1, self.subword_dict_size), self.logzero)
)
return (clm_state, wlm_state, wlm_logprobs, None, log_y, 0.0), log_y
clm_state, z_clm = self.subwordlm(clm_state, x)
log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight
# apply word-level probabilies for <space> and <eos> labels
if xi != self.space:
if new_node is not None and new_node[1] >= 0: # if new node is word end
wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob
else:
wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty
log_y[:, self.space] = wlm_logprob
log_y[:, self.eos] = wlm_logprob
else:
log_y[:, self.space] = self.logzero
log_y[:, self.eos] = self.logzero
return (
(clm_state, wlm_state, wlm_logprobs, new_node, log_y, float(clm_logprob)),
log_y,
)
def final(self, state):
clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state
if node is not None and node[1] >= 0: # check if the node is word end
w = to_device(wlm_logprobs, torch.LongTensor([node[1]]))
else: # this node is not a word end, which means <unk>
w = self.var_word_unk
wlm_state, z_wlm = self.wordlm(wlm_state, w)
return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos])
# Definition of a look-ahead word language model
class LookAheadWordLM(nn.Module):
logzero = -10000000000.0
zero = 1.0e-10
def __init__(
self, wordlm, word_dict, subword_dict, oov_penalty=0.0001, open_vocab=True
):
super(LookAheadWordLM, self).__init__()
self.wordlm = wordlm
self.word_eos = word_dict["<eos>"]
self.word_unk = word_dict["<unk>"]
self.var_word_eos = torch.LongTensor([self.word_eos])
self.var_word_unk = torch.LongTensor([self.word_unk])
self.space = subword_dict["<space>"]
self.eos = subword_dict["<eos>"]
self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk)
self.oov_penalty = oov_penalty
self.open_vocab = open_vocab
self.subword_dict_size = len(subword_dict)
self.zero_tensor = torch.FloatTensor([self.zero])
self.normalized = True
def forward(self, state, x):
# update state with input label x
if state is None: # make initial states and cumlative probability vector
self.var_word_eos = to_device(x, self.var_word_eos)
self.var_word_unk = to_device(x, self.var_word_eos)
self.zero_tensor = to_device(x, self.zero_tensor)
wlm_state, z_wlm = self.wordlm(None, self.var_word_eos)
cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
new_node = self.lexroot
xi = self.space
else:
wlm_state, cumsum_probs, node = state
xi = int(x)
if xi == self.space: # inter-word transition
if node is not None and node[1] >= 0: # check if the node is word end
w = to_device(x, torch.LongTensor([node[1]]))
else: # this node is not a word end, which means <unk>
w = self.var_word_unk
# update wordlm state and cumlative probability vector
wlm_state, z_wlm = self.wordlm(wlm_state, w)
cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1)
new_node = self.lexroot # move to the tree root
elif node is not None and xi in node[0]: # intra-word transition
new_node = node[0][xi]
elif self.open_vocab: # if no path in the tree, enter open-vocabulary mode
new_node = None
else: # if open_vocab flag is disabled, return 0 probabilities
log_y = to_device(
x, torch.full((1, self.subword_dict_size), self.logzero)
)
return (wlm_state, None, None), log_y
if new_node is not None:
succ, wid, wids = new_node
# compute parent node probability
sum_prob = (
(cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]])
if wids is not None
else 1.0
)
if sum_prob < self.zero:
log_y = to_device(
x, torch.full((1, self.subword_dict_size), self.logzero)
)
return (wlm_state, cumsum_probs, new_node), log_y
# set <unk> probability as a default value
unk_prob = (
cumsum_probs[:, self.word_unk] - cumsum_probs[:, self.word_unk - 1]
)
y = to_device(
x,
torch.full(
(1, self.subword_dict_size), float(unk_prob) * self.oov_penalty
),
)
# compute transition probabilities to child nodes
for cid, nd in succ.items():
y[:, cid] = (
cumsum_probs[:, nd[2][1]] - cumsum_probs[:, nd[2][0]]
) / sum_prob
# apply word-level probabilies for <space> and <eos> labels
if wid >= 0:
wlm_prob = (cumsum_probs[:, wid] - cumsum_probs[:, wid - 1]) / sum_prob
y[:, self.space] = wlm_prob
y[:, self.eos] = wlm_prob
elif xi == self.space:
y[:, self.space] = self.zero
y[:, self.eos] = self.zero
log_y = torch.log(torch.max(y, self.zero_tensor)) # clip to avoid log(0)
else: # if no path in the tree, transition probability is one
log_y = to_device(x, torch.zeros(1, self.subword_dict_size))
return (wlm_state, cumsum_probs, new_node), log_y
def final(self, state):
wlm_state, cumsum_probs, node = state
if node is not None and node[1] >= 0: # check if the node is word end
w = to_device(cumsum_probs, torch.LongTensor([node[1]]))
else: # this node is not a word end, which means <unk>
w = self.var_word_unk
wlm_state, z_wlm = self.wordlm(wlm_state, w)
return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos])