import argparse
import unittest
from typing import Any, Dict

import torch

from examples.simultaneous_translation.models import (
    transformer_monotonic_attention
)
from tests.test_roberta import FakeTask

DEFAULT_CONFIG = {
    "attention_eps": 1e-6,
    "mass_preservation": True,
    "noise_type": "flat",
    "noise_mean": 0.0,
    "noise_var": 1.0,
    "energy_bias_init": -2,
    "energy_bias": True,
}

PAD_INDEX = 1


def generate_config(overrides_kv):
    # Start from a copy of the defaults, then apply the per-test overrides.
    new_dict = dict(DEFAULT_CONFIG)
    for key, value in overrides_kv.items():
        new_dict[key] = value
    return new_dict
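

# The helper below builds a tiny two-sentence batch with right padding, so the
# expected-value checks exercise both padded and unpadded source positions.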
def make_sample_with_padding(longer_src=False) -> Dict[str, Any]:
    tokens_1 = torch.LongTensor(
        [
            [2, 10, 11, 12, 13, 14, 15, 10, 11, 12, 13, 14, 15, 2],
            [
                2, 11, 12, 14, 15, 10, 11, 12, 13, 14, 15, 2,
                PAD_INDEX, PAD_INDEX,
            ],
        ]
    )
    tokens_2 = torch.LongTensor(
        [
            [2, 11, 12, 13, 14, 2, PAD_INDEX, PAD_INDEX],
            [2, 11, 22, 33, 2, PAD_INDEX, PAD_INDEX, PAD_INDEX],
        ]
    )
    if longer_src:
        src_tokens = tokens_1[:, 1:]
        prev_output_tokens = tokens_2
    else:
        src_tokens = tokens_2[:, 1:8]
        prev_output_tokens = tokens_1

    src_lengths = src_tokens.ne(PAD_INDEX).sum(dim=1).long()

    sample = {
        "net_input": {
            "src_tokens": src_tokens,
            "prev_output_tokens": prev_output_tokens,
            "src_lengths": src_lengths,
        },
        "target": prev_output_tokens[:, 1:],
    }
    return sample


def build_transformer_monotonic_attention(**extra_args: Any):
    overrides = {
        # Use characteristically small dimensions to keep the tests fast.
        "encoder_embed_dim": 12,
        "encoder_ffn_embed_dim": 14,
        "decoder_embed_dim": 12,
        "decoder_ffn_embed_dim": 14,
        # Disable dropout so we have comparable tests.
        "dropout": 0,
        "attention_dropout": 0,
        "activation_dropout": 0,
        "encoder_layerdrop": 0,
    }
    overrides.update(extra_args)
    # Overrides the defaults from the parser
    args = argparse.Namespace(**overrides)
    transformer_monotonic_attention.monotonic_tiny_architecture(args)

    torch.manual_seed(0)
    task = FakeTask(args)
    return (
        transformer_monotonic_attention
        .TransformerModelSimulTrans
        .build_model(args, task)
    )


def expected_alignment_formula(
    p_choose,
    mass_preservation=True,
    padding_mask=None
):
    # Online and Linear-Time Attention by Enforcing Monotonic Alignments
    # https://arxiv.org/pdf/1704.00784.pdf
    # Eq 18, 19
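    # Concretely, the loops below evaluate, per batch entry b:
    #   alpha[b, 0, j] = p[b, 0, j] * prod_{l<j} (1 - p[b, 0, l])
    #   alpha[b, i, j] = p[b, i, j] * sum_{k<=j} alpha[b, i-1, k]
    #                                * prod_{l=k}^{j-1} (1 - p[b, i, l])
    # as a slow but readable reference implementation.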
    bsz, tgt_len, src_len = p_choose.size()
    alpha = torch.zeros_like(p_choose)

    if padding_mask is not None:
        bsz_pad = padding_mask.size(0)
        num_heads = int(bsz / bsz_pad)
        # Expand the padding mask over the head dimension.
        padding_mask = (
            padding_mask
            .unsqueeze(1)
            .expand([bsz_pad, num_heads, src_len])
            .contiguous()
            .view(-1, src_len)
        )
        p_choose = p_choose.masked_fill(padding_mask.unsqueeze(1), 0)

    for bsz_i in range(bsz):
        for i in range(tgt_len):
            for j in range(src_len):
                if i == 0:
                    if j == 0:
                        # First source token
                        alpha[bsz_i, i, j] = p_choose[bsz_i, i, j]
                    else:
                        # First target token
                        alpha[bsz_i, i, j] = (
                            p_choose[bsz_i, i, j]
                            * torch.prod(
                                1 - p_choose[bsz_i, i, :j]
                            )
                        )
                else:
                    alpha[bsz_i, i, j] = alpha[bsz_i, i - 1, j]
                    for k in range(j):
                        alpha[bsz_i, i, j] += (
                            alpha[bsz_i, i - 1, k]
                            * torch.prod(
                                1 - p_choose[bsz_i, i, k:j]
                            )
                        )
                    alpha[bsz_i, i, j] *= p_choose[bsz_i, i, j]

    if padding_mask is not None:
        alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0)

    if mass_preservation:
        alpha = mass_preservation_formula(alpha, False, padding_mask)

    return alpha
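

# Mass preservation redistributes any leftover probability onto the last
# unpadded source position, so each target step's alignment sums to one.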
def mass_preservation_formula(alpha, left_padding=False, padding_mask=None):
    if padding_mask is None or alpha.size(-1) == 1:
        if alpha.size(-1) > 1:
            alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1)
        return alpha

    src_lens = (padding_mask.logical_not()).sum(dim=1).long()

    bsz, tgt_len, src_len = alpha.size()

    assert (
        not left_padding
        or (left_padding and (not padding_mask[:, 0].any()))
    )

    alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0)

    for bsz_i in range(bsz):
        if left_padding:
            alpha[bsz_i, :, -1] = (
                1 - alpha[bsz_i, :, :-1].sum(dim=-1)
            )
        else:
            alpha[bsz_i, :, src_lens[bsz_i] - 1] = (
                1 - alpha[bsz_i, :, :src_lens[bsz_i] - 1].sum(dim=-1)
            )

    return alpha


def expected_soft_attention_formula(
    alpha,
    soft_energy,
    padding_mask=None,
    chunksize=int(1e10),
):
    # Monotonic Infinite Lookback Attention for Simultaneous Machine Translation
    # https://arxiv.org/pdf/1906.05218.pdf
    # Eq 14
    # Monotonic Chunkwise Attention
    # https://arxiv.org/abs/1712.05382
    # Eq 17
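    # With chunk size w and soft energies u, the loops below evaluate
    #   beta[b, i, j] = sum_{k=j}^{j+w-1} alpha[b, i, k] * exp(u[b, i, j])
    #                   / sum_{l=k-w+1}^{k} exp(u[b, i, l])
    # with indices clipped to [0, src_len); a very large w recovers the
    # infinite-lookback case of Eq 14.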
    bsz, tgt_len, src_len = alpha.size()
    beta = torch.zeros_like(alpha)

    if padding_mask is not None:
        bsz_pad = padding_mask.size(0)
        num_heads = int(bsz / bsz_pad)
        # Expanding for potential head dimension
        padding_mask = (
            padding_mask
            .unsqueeze(1)
            .expand([bsz_pad, num_heads, src_len])
            .contiguous()
            .view(-1, src_len)
        )
        soft_energy = soft_energy.masked_fill(
            padding_mask.unsqueeze(1), float("-inf")
        )

    for bsz_i in range(bsz):
        for i in range(tgt_len):
            for j in range(src_len):
                if padding_mask is not None and padding_mask[bsz_i, j]:
                    continue
                for k in range(j, min([src_len, j + chunksize])):
                    beta[bsz_i, i, j] += (
                        alpha[bsz_i, i, k] * torch.exp(soft_energy[bsz_i, i, j])
                        / torch.sum(torch.exp(
                            soft_energy[bsz_i, i, max([0, k - chunksize + 1]):k + 1]
                        ))
                    )
    return beta


class MonotonicAttentionTestAbstractClass(object):
    def test_forward(self):
        sample = make_sample_with_padding()
        out, _ = self.model.forward(**sample["net_input"])
        loss = out.sum()
        loss.backward()

    def test_p_choose(self):
        sample = make_sample_with_padding()
        _, extra_out = self.model.forward(**sample["net_input"])
        for item in extra_out.attn_list:
            p_choose = item["p_choose"]
            self.assertTrue(p_choose.le(1.0).all())
            self.assertTrue(p_choose.ge(0.0).all())

    def test_expected_alignment(self):
        for longer_src in [True, False]:
            sample = make_sample_with_padding(longer_src)
            _, extra_out = self.model.forward(**sample["net_input"])
            for item in extra_out.attn_list:
                p_choose = item["p_choose"]
                alpha_system = item["alpha"]
                self.assertTrue(p_choose.size() == alpha_system.size())

                bsz, num_head, tgt_len, src_len = alpha_system.size()
                alpha_system = alpha_system.view(-1, tgt_len, src_len)
                p_choose = p_choose.view(-1, tgt_len, src_len)

                alpha_real = expected_alignment_formula(
                    p_choose,
                    self.model.decoder.layers[0].encoder_attn.mass_preservation,
                    sample["net_input"]["src_tokens"].eq(PAD_INDEX)
                )

                self.assertTrue(
                    torch.abs(alpha_system - alpha_real).le(5e-5).all(),
                )


class HardMonotonicAttentionTestCase(
    unittest.TestCase,
    MonotonicAttentionTestAbstractClass
):
    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config({"simul_type": "hard_aligned"})
        )


class InfiniteLookbackTestCase(
    unittest.TestCase,
    MonotonicAttentionTestAbstractClass
):
    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config(
                {
                    "simul_type": "infinite_lookback"
                }
            )
        )
        self.model.train()
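
    # Assumed guard (not in the original test): skip rather than error on
    # machines without a CUDA device, since the test below runs in fp16 on GPU.
    @unittest.skipUnless(torch.cuda.is_available(), "requires a CUDA device")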
    def test_fp16_for_long_input(self):
        sample = {
            "net_input": {
                "src_tokens": torch.LongTensor([7] * 1000 + [2]).cuda().unsqueeze(0),
                "prev_output_tokens": torch.LongTensor([7] * 1000 + [2]).cuda().unsqueeze(0),
                "src_lengths": torch.LongTensor([1000]).cuda(),
            },
            "target": torch.LongTensor([2] + [7] * 1000).unsqueeze(0).cuda()
        }
        self.model.cuda().half()
        _, extra_out = self.model.forward(**sample["net_input"])
        for item in extra_out.attn_list:
            for key in ["p_choose", "alpha", "beta", "soft_energy"]:
                self.assertFalse(torch.isnan(item[key]).any())

    def test_expected_attention(self):
        for longer_src in [True, False]:
            sample = make_sample_with_padding(longer_src)
            _, extra_out = self.model.forward(**sample["net_input"])
            for item in extra_out.attn_list:
                p_choose = item["p_choose"]
                alpha_system = item["alpha"]
                beta_system = item["beta"]
                soft_energy_system = item["soft_energy"]
                self.assertTrue(beta_system.size() == alpha_system.size())
                self.assertTrue(p_choose.size() == alpha_system.size())

                bsz, num_head, tgt_len, src_len = alpha_system.size()

                alpha_system = alpha_system.view(-1, tgt_len, src_len)
                beta_system = beta_system.view(-1, tgt_len, src_len)
                p_choose = p_choose.view(-1, tgt_len, src_len)
                soft_energy_system = soft_energy_system.view(-1, tgt_len, src_len)

                alpha_real = expected_alignment_formula(
                    p_choose,
                    self.model.decoder.layers[0].encoder_attn.mass_preservation,
                    sample["net_input"]["src_tokens"].eq(PAD_INDEX)
                )

                beta_real = expected_soft_attention_formula(
                    alpha_real,
                    soft_energy_system,
                    sample["net_input"]["src_tokens"].eq(PAD_INDEX),
                    chunksize=getattr(
                        self.model.decoder.layers[0].encoder_attn,
                        "chunk_size",
                        int(1e10)
                    )
                )

                self.assertTrue(
                    torch.abs(beta_system - beta_real).le(1e-5).all(),
                )


class ChunkwiseTestCase(
    InfiniteLookbackTestCase
):
    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config(
                {
                    "simul_type": "chunkwise",
                    "mocha_chunk_size": 3
                }
            )
        )


class WaitkTestCase(InfiniteLookbackTestCase):
    def setUp(self):
        self.model = build_transformer_monotonic_attention(
            **generate_config(
                {
                    "simul_type": "waitk",
                    "waitk_lagging": 3,
                }
            )
        )
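
    # For wait-k, p_choose should form a shifted diagonal: target step i reads
    # exactly source position i + lagging - 1 whenever that position is not padding.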
    def check_waitk(self, p_choose, lagging, padding_mask):
        bsz, tgt_len, src_len = p_choose.size()
        for bsz_i in range(bsz):
            for i in range(tgt_len):
                for j in range(src_len):
                    if not padding_mask[bsz_i, j]:
                        if j - i == lagging - 1:
                            self.assertTrue(p_choose[bsz_i, i, j] == 1)
                        else:
                            self.assertTrue(p_choose[bsz_i, i, j] == 0)

    def test_waitk_p_choose(self):
        for longer_src in [True, False]:
            for k in [1, 3, 10, 20, 100]:
                sample = make_sample_with_padding(longer_src)
                model = build_transformer_monotonic_attention(
                    **generate_config(
                        {
                            "simul_type": "waitk",
                            "waitk_lagging": k,
                        }
                    )
                )
                model.train()
                _, extra_out = model.forward(**sample["net_input"])
                for item in extra_out.attn_list:
                    p_choose = item["p_choose"]
                    bsz, num_heads, tgt_len, src_len = p_choose.size()
                    padding_mask = sample["net_input"]["src_tokens"].eq(PAD_INDEX)
                    padding_mask = (
                        padding_mask
                        .unsqueeze(1)
                        .expand([bsz, num_heads, src_len])
                        .contiguous()
                        .view(-1, src_len)
                    )
                    p_choose = p_choose.view(bsz * num_heads, tgt_len, src_len)
                    self.check_waitk(p_choose, k, padding_mask)
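

# Assumed convenience entry point (not part of the original snippet) so the
# suite can be run directly with `python`; running through pytest also works.
if __name__ == "__main__":
    unittest.main()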