# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch

from fairseq.search import Search


class NoisyChannelBeamSearch(Search):
    """Beam search that ranks candidates by a combination of forward (direct
    model) and backward (channel / language model) scores, as used by
    fairseq's fast_noisy_channel example."""

    def __init__(self, tgt_dict):
        super().__init__(tgt_dict)
        self.fw_scores_buf = None
        self.lm_scores_buf = None

    def _init_buffers(self, t):
        # lazily allocate the output buffers on first use, on the same device as the input
        # super()._init_buffers(t)
        if self.fw_scores_buf is None:
            self.scores_buf = t.new()
            self.indices_buf = torch.LongTensor().to(device=t.device)
            self.beams_buf = torch.LongTensor().to(device=t.device)
            self.fw_scores_buf = t.new()
            self.lm_scores_buf = t.new()

    def combine_fw_bw(self, combine_method, fw_cum, bw, step):
        if combine_method == "noisy_channel":
            fw_norm = fw_cum.div(step + 1)
            lprobs = bw + fw_norm
        elif combine_method == "lm_only":
            lprobs = bw + fw_cum

        return lprobs
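
    # The combined score above drives the top-k selection in `step` below. For
    # "noisy_channel" it is bw + fw_cum / (step + 1): the cumulative forward
    # (direct-model) score is length-normalised before being added to the
    # backward score `bw`, which the caller (presumably the fast_noisy_channel
    # generator) supplies with the channel-model and language-model
    # contributions already folded in. "lm_only" adds the un-normalised
    # cumulative forward score instead.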

    def step(self, step, fw_lprobs, scores, bw_lprobs, lm_lprobs, combine_method):
        self._init_buffers(fw_lprobs)
        bsz, beam_size, vocab_size = fw_lprobs.size()

        if step == 0:
            # at the first step all hypotheses are equally likely, so use
            # only the first beam
            fw_lprobs = fw_lprobs[:, ::beam_size, :].contiguous()
            bw_lprobs = bw_lprobs[:, ::beam_size, :].contiguous()
            # nothing to add since we are at the first step
            fw_lprobs_cum = fw_lprobs
        else:
            # make probs contain cumulative scores for each hypothesis
            raw_scores = scores[:, :, step - 1].unsqueeze(-1)
            fw_lprobs_cum = fw_lprobs.add(raw_scores)

        combined_lprobs = self.combine_fw_bw(combine_method, fw_lprobs_cum, bw_lprobs, step)

        # choose the top k according to the combined noisy channel model score
        torch.topk(
            combined_lprobs.view(bsz, -1),
            k=min(
                # Take the best 2 x beam_size predictions. We'll choose the first
                # beam_size of these which don't predict eos to continue with.
                beam_size * 2,
                combined_lprobs.view(bsz, -1).size(1) - 1,  # -1 so we never select pad
            ),
            out=(self.scores_buf, self.indices_buf),
        )

        # save corresponding fw and lm scores
        self.fw_scores_buf = torch.gather(fw_lprobs_cum.view(bsz, -1), 1, self.indices_buf)
        self.lm_scores_buf = torch.gather(lm_lprobs.view(bsz, -1), 1, self.indices_buf)

        # Project back into relative indices and beams
        self.beams_buf = self.indices_buf // vocab_size
        self.indices_buf.fmod_(vocab_size)
        return self.scores_buf, self.fw_scores_buf, self.lm_scores_buf, self.indices_buf, self.beams_buf
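

if __name__ == "__main__":
    # Minimal usage sketch, for illustration only: it exercises `step` with a
    # toy dictionary and made-up random log-probabilities just to show the
    # expected tensor shapes. In practice this class is driven by the
    # noisy-channel sequence generator rather than called directly.
    from fairseq.data import Dictionary

    tgt_dict = Dictionary()
    for tok in ["a", "b", "c"]:
        tgt_dict.add_symbol(tok)

    bsz, beam_size, vocab_size = 2, 3, len(tgt_dict)
    search = NoisyChannelBeamSearch(tgt_dict)

    # (bsz, beam_size, vocab_size) log-probabilities for the forward (direct)
    # model, the backward (channel + LM) combination, and the LM alone
    fw_lprobs = torch.randn(bsz, beam_size, vocab_size).log_softmax(dim=-1)
    bw_lprobs = torch.randn(bsz, beam_size, vocab_size).log_softmax(dim=-1)
    lm_lprobs = torch.randn(bsz, beam_size, vocab_size).log_softmax(dim=-1)
    scores = fw_lprobs.new_zeros(bsz, beam_size, 0)  # no accumulated scores at step 0

    scores_buf, fw_buf, lm_buf, indices, beams = search.step(
        0, fw_lprobs, scores, bw_lprobs, lm_lprobs, "noisy_channel"
    )
    # each output has shape (bsz, k) with k = min(2 * beam_size, vocab_size - 1)
    print(scores_buf.shape, indices.shape, beams.shape)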