Spaces:
Runtime error
Runtime error
"""Length bonus module.""" | |
from typing import Any | |
from typing import List | |
from typing import Tuple | |
import torch | |
from espnet.nets.scorer_interface import BatchScorerInterface | |
class LengthBonus(BatchScorerInterface):
    """Length bonus scorer for beam search.

    Both scoring methods assign a constant score of 1.0 to every one of
    the ``n_vocab`` vocabulary entries and return ``None`` as the state.
    """

    def __init__(self, n_vocab: int):
        """Remember the vocabulary size.

        Args:
            n_vocab (int): The number of tokens in vocabulary for beam search

        """
        self.n = n_vocab

    def score(self, y, state, x):
        """Return the per-token bonus for a single hypothesis.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens (not read here).
            state: Scorer state for prefix tokens (not read here).
            x (torch.Tensor): 2D encoder feature that generates ys; only its
                device and dtype are used.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                a tensor of ones with shape `(n_vocab,)` on the device and
                in the dtype of `x`, and None

        """
        # Build a single-element tensor and broadcast it to the vocabulary
        # size instead of materialising n_vocab separate values.
        unit = torch.tensor([1.0], device=x.device, dtype=x.dtype)
        bonus = unit.expand(self.n)
        return bonus, None

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Return the per-token bonus for a batch of hypotheses.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen);
                only the size of the first dimension is used.
            states (List[Any]): Scorer states for prefix tokens (not read here).
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat);
                only its device and dtype are used.

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                a tensor of ones with shape `(n_batch, n_vocab)` on the device
                and in the dtype of `xs`, and None in place of a state list.

        """
        n_batch = ys.shape[0]
        # Same broadcast trick as in ``score``, with a leading batch dimension.
        unit = torch.tensor([1.0], device=xs.device, dtype=xs.dtype)
        bonus = unit.expand(n_batch, self.n)
        return bonus, None