File size: 14,997 Bytes
c0f8a54 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 |
# pylint: skip-file
# Copied from: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py
import torch
from transformers import LogitsProcessor
class CTCPrefixScoreTH:
    """Batch processing of CTCPrefixScore.

    Based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the label probabilities for multiple
    hypotheses simultaneously.
    See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based
    Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019.
    """

    def __init__(self, x, xlens, blank, eos, margin=0):
        """Construct CTC prefix scorer.

        Frames of ``x`` at or beyond each utterance length are overwritten
        in place (all labels set to ``logzero``, blank set to 0).

        :param torch.Tensor x: input label posterior sequences (B, T, O)
        :param torch.Tensor xlens: input lengths (B,)
        :param int blank: blank label id
        :param int eos: end-of-sequence id
        :param int margin: margin parameter for windowing (0 means no windowing)
        """
        # In the comment lines,
        # we assume T: input_length, B: batch size, W: beam width, O: output dim.
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.batch = x.size(0)
        self.input_length = x.size(1)
        self.odim = x.size(2)
        self.dtype = x.dtype
        # Take the device straight from the tensor so that every backend
        # (not only CPU and CUDA) gets its helper tensors on the same device.
        self.device = x.device
        # Pad the rest of posteriors in the batch
        # TODO(takaaki-hori): need a better way without for-loops
        for i, l in enumerate(xlens):
            if l < self.input_length:
                x[i, l:, :] = self.logzero
                x[i, l:, blank] = 0
        # Reshape input x
        xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
        xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
        self.x = torch.stack([xn, xb])  # (2, T, B, O)
        self.end_frames = torch.as_tensor(xlens) - 1
        # Setup CTC windowing
        self.margin = margin
        if margin > 0:
            self.frame_ids = torch.arange(self.input_length, dtype=self.dtype, device=self.device)
        # Base indices for index conversion
        self.idx_bh = None
        self.idx_b = torch.arange(self.batch, device=self.device)
        self.idx_bo = (self.idx_b * self.odim).unsqueeze(1)

    def __call__(self, y, state, scoring_ids=None, att_w=None):
        """Compute CTC prefix scores for next labels.

        :param list y: prefix label sequences (one per hypothesis, sos first)
        :param tuple state: previous CTC state as returned by
            ``index_select_state`` (``None`` for the first step)
        :param torch.Tensor scoring_ids: optional label ids to score per
            hypothesis; when ``None`` all O labels are scored
        :param torch.Tensor att_w: attention weights to decide CTC window
        :return: ``(token_scores, new_state)`` where ``token_scores`` is
            (BW, O) and ``new_state`` is
            ``(r, log_psi, f_min, f_max, scoring_idmap)``
        """
        output_length = len(y[0]) - 1  # ignore sos
        last_ids = [yi[-1] for yi in y]  # last output label ids
        n_bh = len(last_ids)  # batch * hyps
        n_hyps = n_bh // self.batch  # assuming each utterance has the same # of hyps
        self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0
        # prepare state info
        if state is None:
            r_prev = torch.full(
                (self.input_length, 2, self.batch, n_hyps),
                self.logzero,
                dtype=self.dtype,
                device=self.device,
            )
            # blank-ending probabilities of the empty prefix: cumulative blank scores
            r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2)
            r_prev = r_prev.view(-1, 2, n_bh)
            s_prev = 0.0
            f_min_prev = 0
            f_max_prev = 1
        else:
            r_prev, s_prev, f_min_prev, f_max_prev = state
        # select input dimensions for partial scoring
        if self.scoring_num > 0:
            scoring_idmap = torch.full((n_bh, self.odim), -1, dtype=torch.long, device=self.device)
            snum = self.scoring_num
            if self.idx_bh is None or n_bh > len(self.idx_bh):
                self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1)
            scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange(snum, device=self.device)
            scoring_idx = (scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1)).view(-1)
            x_ = torch.index_select(self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx).view(
                2, -1, n_bh, snum
            )
        else:
            scoring_ids = None
            scoring_idmap = None
            snum = self.odim
            x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum)
        # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor
        # that corresponds to r_t^n(h) and r_t^b(h) in a batch.
        r = torch.full(
            (self.input_length, 2, n_bh, snum),
            self.logzero,
            dtype=self.dtype,
            device=self.device,
        )
        if output_length == 0:
            r[0, 0] = x_[0, 0]
        r_sum = torch.logsumexp(r_prev, 1)
        log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum)
        # For the label equal to the last prefix label, only the blank-ending
        # path of the previous step may be extended.
        if scoring_ids is not None:
            for idx in range(n_bh):
                pos = scoring_idmap[idx, last_ids[idx]]
                if pos >= 0:
                    log_phi[:, idx, pos] = r_prev[:, 1, idx]
        else:
            for idx in range(n_bh):
                log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx]
        # decide start and end frames based on attention weights
        if att_w is not None and self.margin > 0:
            f_arg = torch.matmul(att_w, self.frame_ids)
            f_min = max(int(f_arg.min().cpu()), f_min_prev)
            f_max = max(int(f_arg.max().cpu()), f_max_prev)
            start = min(f_max_prev, max(f_min - self.margin, output_length, 1))
            end = min(f_max + self.margin, self.input_length)
        else:
            f_min = f_max = 0
            start = max(output_length, 1)
            end = self.input_length
        if start > end:
            # The prefix is longer than the frames available, so r[start - 1]
            # below would be out of range: report every label as impossible.
            # s_prev is a plain float when state is None, so fall back to an
            # explicit (BW, O) template in that case.
            if torch.is_tensor(s_prev):
                template = s_prev
            else:
                template = torch.empty((n_bh, self.odim), dtype=self.dtype, device=self.device)
            return torch.full_like(template, self.logzero), (
                r,
                torch.full_like(template, self.logzero),
                f_min,
                f_max,
                scoring_idmap,
            )
        # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h))
        for t in range(start, end):
            rp = r[t - 1]
            rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view(2, 2, n_bh, snum)
            r[t] = torch.logsumexp(rr, 1) + x_[:, t]
        # compute log prefix probabilities log(psi)
        log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0]
        if scoring_ids is not None:
            log_psi = torch.full((n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device)
            log_psi_ = torch.logsumexp(
                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
                dim=0,
            )
            for si in range(n_bh):
                log_psi[si, scoring_ids[si]] = log_psi_[si]
        else:
            log_psi = torch.logsumexp(
                torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0),
                dim=0,
            )
        # The eos column is not given a dedicated end-of-utterance score here;
        # it keeps the regular prefix score computed above.
        # exclude blank probs
        log_psi[:, self.blank] = self.logzero
        token_scores = log_psi - s_prev
        # an increment of exactly zero is replaced by logzero
        token_scores[token_scores == 0] = self.logzero
        return token_scores, (r, log_psi, f_min, f_max, scoring_idmap)

    def index_select_state(self, state, best_ids):
        """Select CTC states according to best ids.

        :param state: CTC state ``(r, s, f_min, f_max, scoring_idmap)`` as
            returned by ``__call__``
        :param best_ids: index numbers selected by beam pruning (B, W);
            each value is interpreted as ``hyp * O + label``
        :return: selected state ``(r_new, s_new, f_min, f_max)``
        """
        r, s, f_min, f_max, scoring_idmap = state
        # convert ids to BHO space
        n_bh = len(s)
        n_hyps = n_bh // self.batch
        vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1)
        # select hypothesis scores
        s_new = torch.index_select(s.view(-1), 0, vidx)
        s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim)
        # convert ids to BHS space (S: scoring_num)
        if scoring_idmap is not None:
            snum = self.scoring_num
            hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view(-1)
            label_ids = torch.fmod(best_ids, self.odim).view(-1)
            score_idx = scoring_idmap[hyp_idx, label_ids]
            # labels outside the scored subset fall back to slot 0
            score_idx[score_idx == -1] = 0
            vidx = score_idx + hyp_idx * snum
        else:
            snum = self.odim
        # select forward probabilities
        r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view(-1, 2, n_bh)
        return r_new, s_new, f_min, f_max

    def extend_prob(self, x):
        """Extend CTC prob.

        Nothing happens unless ``x`` has more frames than the posteriors
        currently stored.

        :param torch.Tensor x: input label posterior sequences (B, T, O)
        """
        if self.x.shape[1] < x.shape[1]:  # self.x (2,T,B,O); x (B,T,O)
            # The only length used is x's own time dimension, which is larger
            # than the stored input_length here, so no padding is required.
            xlens = [x.size(1)]
            tmp_x = self.x
            xn = x.transpose(0, 1)  # (B, T, O) -> (T, B, O)
            xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim)
            self.x = torch.stack([xn, xb])  # (2, T, B, O)
            # keep the frames that were already stored
            self.x[:, : tmp_x.shape[1], :, :] = tmp_x
            self.input_length = x.size(1)
            self.end_frames = torch.as_tensor(xlens) - 1

    def extend_state(self, state):
        """Compute CTC prefix state.

        Extends the forward probabilities of ``state`` to the current
        ``input_length`` by accumulating blank scores on the new frames.

        :param state: CTC state ``(r_prev, s_prev, f_min_prev, f_max_prev)``
            or ``None``
        :return: extended CTC state (``None`` is returned unchanged)
        """
        if state is None:
            # nothing to do
            return state
        r_prev, s_prev, f_min_prev, f_max_prev = state
        # NOTE(review): the (T, 2) buffer presumably expects a single
        # hypothesis with r_prev of shape (T_prev, 2) -- confirm with callers.
        r_prev_new = torch.full(
            (self.input_length, 2),
            self.logzero,
            dtype=self.dtype,
            device=self.device,
        )
        start = max(r_prev.shape[0], 1)
        r_prev_new[0:start] = r_prev
        for t in range(start, self.input_length):
            r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank]
        return (r_prev_new, s_prev, f_min_prev, f_max_prev)
class CTCRescorerLogitsProcessor(LogitsProcessor):
    """Logits processor that interpolates decoder scores with CTC prefix scores.

    A ``CTCPrefixScoreTH`` scorer is built from the encoder logits and its
    state is kept on the instance between calls, so one processor instance
    serves exactly one generation run.
    """

    def __init__(
        self,
        encoder_logits: torch.FloatTensor,
        encoder_output_lens: torch.LongTensor,
        pad_token_id: int,
        eos_token_id: int,
        ctc_margin: int,
        ctc_weight: float,
        num_beams: int,
        space_token_id: int,
        apply_eos_space_trick: bool,
        eos_space_trick_weight: float,
        debug: bool = False,
    ):
        """Build the CTC prefix scorer and store the rescoring options.

        :param encoder_logits: unnormalised CTC logits; log-softmax over the
            last dimension is applied before they reach the scorer
        :param encoder_output_lens: number of valid frames per utterance
        :param pad_token_id: padding id, also used as the CTC blank label
        :param eos_token_id: end-of-sequence id
        :param ctc_margin: windowing margin forwarded to the scorer
        :param ctc_weight: interpolation weight of the CTC scores
        :param num_beams: number of hypotheses per utterance
        :param space_token_id: token id compared against EOS by the trick
        :param apply_eos_space_trick: enable the EOS/space conflict handling
        :param eos_space_trick_weight: factor applied to the EOS score when
            the trick fires
        :param debug: print a top-k analysis on every call
        """
        super().__init__()
        self.pad_token_id = pad_token_id
        self.ctc_prefix_scorer = CTCPrefixScoreTH(
            torch.nn.functional.log_softmax(encoder_logits, dim=-1),
            encoder_output_lens,
            pad_token_id,
            eos_token_id,
            ctc_margin,
        )
        self.ctc_weight = ctc_weight
        self.ctc_states = None
        self.num_beams = num_beams
        self.eos_token_id = eos_token_id
        self.apply_eos_space_trick = apply_eos_space_trick
        self.space_token_id = space_token_id
        self.eos_space_trick_weight = eos_space_trick_weight
        self.debug = debug

    @staticmethod
    def analyze_predictions(
        scores,
        ctc_scores,
        next_token_scores,
        input_ids,
        k=10,
        tokenizer="Lakoc/english_corpus_uni5000_normalized",
        separator_token_id=4976,
        eos_token_id=1,
    ):
        """Print the prefixes and the top-k candidates of each score source.

        :param scores: decoder scores (hyps, vocab)
        :param ctc_scores: CTC prefix scores (hyps, vocab)
        :param next_token_scores: combined scores (hyps, vocab)
        :param input_ids: current prefixes, decoded for display
        :param k: number of candidates shown per hypothesis
        :param tokenizer: name or path given to ``AutoTokenizer.from_pretrained``
        :param separator_token_id: token id interleaved between the candidates
            before decoding so that they stay visually separated
        :param eos_token_id: vocabulary column reported as ``CTC_EOS``
        """
        from transformers import AutoTokenizer

        # The tokenizer is loaded on every call; this is a debugging aid only.
        tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        best_att_ids = scores.topk(k=k, dim=1)
        best_ctc_ids = ctc_scores.topk(k=k, dim=1)
        best_ids = next_token_scores.topk(k=k, dim=1)

        def print_prediction(best_ids, name):
            # Even columns hold the candidates, odd columns the separator.
            new_tensor = torch.zeros(
                (best_ids.indices.shape[0], best_ids.indices.shape[1] * 2), dtype=torch.long
            )
            new_tensor[:, 0::2] = best_ids.indices
            new_tensor[:, 1::2] = separator_token_id
            print(f"{name}:")
            decoded = tokenizer.batch_decode(new_tensor)
            for index, (next_ids, values) in enumerate(zip(decoded, best_ids.values)):
                print(f"HYP {index}:\n{next_ids} {values}")

        print("PREFIX:")
        for index, prefix in enumerate(tokenizer.batch_decode(input_ids)):
            print(f"HYP {index}:\n{prefix}")
        print_prediction(best_att_ids, "ATT_SCORES")
        print()
        print_prediction(best_ctc_ids, "CTC_SCORES")
        print()
        print(f"CTC_EOS: {ctc_scores[:, eos_token_id]}")
        print_prediction(best_ids, "NEXT_TOKEN_SCORES")
        print()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Combine decoder scores with CTC prefix scores for the next token.

        Side effects: the padding column of ``scores`` is overwritten in
        place and the stored CTC state is advanced by one step.

        :param input_ids: current prefixes (hyps, length)
        :param scores: decoder scores for the next token (hyps, vocab)
        :return: ``(1 - ctc_weight) * scores + ctc_weight * ctc_scores``,
            optionally adjusted by the EOS/space trick
        """
        scores[:, self.pad_token_id] = self.ctc_prefix_scorer.logzero
        if self.ctc_states is not None:
            # NOTE(review): only the last token ids are passed, so
            # index_select_state computes hypothesis index ids // O == 0 and
            # appears to pick the state of the first hypothesis of each
            # utterance for every beam -- verify for num_beams > 1.
            self.ctc_states = self.ctc_prefix_scorer.index_select_state(
                self.ctc_states, input_ids[:, -1].reshape(-1, self.num_beams)
            )
        ctc_scores, ctc_states = self.ctc_prefix_scorer(input_ids, self.ctc_states)
        self.ctc_states = ctc_states
        next_token_scores = (1 - self.ctc_weight) * scores + self.ctc_weight * ctc_scores
        if self.apply_eos_space_trick:
            # Conflict: the decoder's best token is EOS while CTC's best is space.
            space_eos_conflict = torch.logical_and(
                scores.argmax(dim=1) == self.eos_token_id,
                ctc_scores.argmax(dim=1) == self.space_token_id,
            )
            if space_eos_conflict.any():
                eos_scores = next_token_scores[:, self.eos_token_id]
                space_scores = next_token_scores[:, self.space_token_id]
                # Rescale EOS only where it currently loses to space but the
                # rescaled value would exceed the space score.
                apply_trick_on = torch.logical_and(
                    torch.logical_and(space_eos_conflict, eos_scores < space_scores),
                    self.eos_space_trick_weight * eos_scores > space_scores,
                )
                if apply_trick_on.any():
                    next_token_scores[apply_trick_on, self.eos_token_id] = (
                        next_token_scores[apply_trick_on, self.eos_token_id] * self.eos_space_trick_weight
                    )
        if self.debug:
            self.analyze_predictions(
                scores, ctc_scores, next_token_scores, input_ids, eos_token_id=self.eos_token_id
            )
        return next_token_scores
class LogSoftmaxProcessor(LogitsProcessor):
    """Logits processor that turns raw scores into log-probabilities."""

    def __init__(self):
        super().__init__()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` normalised with log-softmax over the last dimension.

        ``input_ids`` is accepted for interface compatibility and is not used.
        """
        return scores.log_softmax(dim=-1)
|