tobiasc's picture
Initial commit
ad16788
raw
history blame contribute delete
No virus
3.12 kB
"""N-gram language model scorers implemented with KenLM."""
from abc import ABC
import kenlm
import torch
from espnet.nets.scorer_interface import BatchScorerInterface
from espnet.nets.scorer_interface import PartialScorerInterface
class Ngrambase(ABC):
    """N-gram LM base class implemented through the ESPnet ScorerInterface.

    Wraps a KenLM model and maps ESPnet token ids to the word strings KenLM
    expects.
    """

    def __init__(self, ngram_model, token_list):
        """Initialize Ngrambase.

        Args:
            ngram_model: path to the n-gram model, loaded with
                ``kenlm.LanguageModel``.
            token_list: token list from dict or model.json; entry ``i`` is the
                string for token id ``i``.
        """
        # KenLM spells end-of-sentence "</s>", so rename the "<eos>" token.
        self.chardict = [x if x != "<eos>" else "</s>" for x in token_list]
        self.charlen = len(self.chardict)
        self.lm = kenlm.LanguageModel(ngram_model)
        # Scratch output state for per-candidate scoring; BaseScore requires an
        # output state, but its contents are never read back.
        self.tmpkenlmstate = kenlm.State()

    def init_state(self, x):
        """Return a fresh KenLM state with null (empty) context.

        Args:
            x: encoded feature (unused).
        """
        state = kenlm.State()
        self.lm.NullContextWrite(state)
        return state

    def score_partial_(self, y, next_token, state, x):
        """Score candidate next tokens; shared by full and partial scorers.

        ``state`` lags one token behind ``y``: it is first advanced by the last
        token of ``y`` (or by ``"<s>"`` when ``y`` holds a single element, i.e.
        only the start token), and every candidate in ``next_token`` is then
        scored from that advanced context.

        Args:
            y: 1-D tensor of previous token ids.
            next_token: 1-D tensor of candidate token ids to score.
            state: KenLM state from ``init_state`` or a previous call.
            x: encoded feature; only its dtype is used.

        Returns:
            tuple[torch.Tensor, kenlm.State]: one score per entry of
            ``next_token`` (KenLM ``BaseScore`` values, log10 according to the
            KenLM docs; dtype ``x.dtype``, on ``y.device``), and the advanced
            state, which becomes the state of the extended hypothesis.
        """
        out_state = kenlm.State()
        ys = self.chardict[y[-1]] if y.shape[0] > 1 else "<s>"
        self.lm.BaseScore(state, ys, out_state)
        # Convert ids to Python ints once: indexing a list with 0-dim tensors
        # and writing into a tensor element-by-element is slow in this loop,
        # which runs over the whole vocabulary for the full scorer.
        scores = [
            self.lm.BaseScore(out_state, self.chardict[j], self.tmpkenlmstate)
            for j in next_token.tolist()
        ]
        return torch.tensor(scores, dtype=x.dtype, device=y.device), out_state
class NgramFullScorer(Ngrambase, BatchScorerInterface):
    """Full scorer for the n-gram LM: scores every token in the vocabulary."""

    def score(self, y, state, x):
        """Score all ``self.charlen`` vocabulary tokens given the prefix ``y``.

        Args:
            y: 1-D tensor of previous token ids.
            state: KenLM state from ``init_state`` or the previous step.
            x: encoded feature, passed through to ``score_partial_``.

        Returns:
            tuple[torch.Tensor, kenlm.State]: scores of shape ``(n_vocab,)``
            and the updated KenLM state for the extended hypothesis.
        """
        # arange builds the id tensor directly, without materializing a Python
        # range first; the dtype (int64) is the same as torch.tensor(range(n)).
        return self.score_partial_(y, torch.arange(self.charlen), state, x)
class NgramPartScorer(Ngrambase, PartialScorerInterface):
    """Partial scorer for the n-gram LM: scores only the given candidates."""

    def score_partial(self, y, next_token, state, x):
        """Score only the candidate tokens in ``next_token``.

        Args:
            y: 1-D tensor of previous token ids.
            next_token: 1-D tensor of candidate token ids to score.
            state: KenLM state from ``init_state`` or the previous step.
            x: encoded feature, passed through to ``score_partial_``.

        Returns:
            tuple[torch.Tensor, kenlm.State]: one score per entry of
            ``next_token`` and the updated KenLM state.
        """
        return self.score_partial_(y, next_token, state, x)

    def select_state(self, state, i, new_id=None):
        """Return ``state`` unchanged.

        ``score_partial`` produces a single KenLM state shared by all scored
        candidates, so there is nothing to select per index.

        Args:
            state: KenLM state returned by ``score_partial``.
            i: index in the main beam search (unused).
            new_id: new label id (unused). Accepted so that callers passing the
                optional third argument of ``ScorerInterface.select_state``
                (e.g. batch beam search) do not raise TypeError.
                NOTE(review): confirm against the installed ESPnet version.
        """
        return state