# Source: OFA-OCR / fairseq / tests/test_transformer.py
# Author: JustinLin610 - commit ee21b96 ("first commit"), 1.94 kB.
import argparse
import unittest
from typing import Any, Dict, Optional, Sequence

import torch
from fairseq.models import transformer

from tests.test_roberta import FakeTask
def mk_sample(
    tok: Optional[Sequence[int]] = None, batch_size: int = 2
) -> Dict[str, Any]:
    """Build a tiny fake sample for model smoke tests.

    One token sequence is repeated ``batch_size`` times. The resulting
    tensor is used (as the same object) for both ``src_tokens`` and
    ``prev_output_tokens``.

    Args:
        tok: Token ids for a single sequence. ``None`` or an empty sequence
            falls back to ``[10, 11, 12, 13, 14, 15, 2]`` (the trailing 2 is
            presumably the dictionary's EOS index -- confirm against the task).
        batch_size: Number of identical rows in the batch.

    Returns:
        A dict with a ``"net_input"`` sub-dict holding ``src_tokens``,
        ``prev_output_tokens`` (both ``(batch_size, len(tok))`` long tensors)
        and ``src_lengths`` (``len(tok)`` per row), plus a ``"target"`` tensor
        equal to the batch with its first column dropped.
    """
    # Falsy check (not ``is None``) so an empty sequence also gets the default.
    if not tok:
        tok = [10, 11, 12, 13, 14, 15, 2]
    batch = torch.stack([torch.tensor(tok, dtype=torch.long)] * batch_size)
    sample = {
        "net_input": {
            "src_tokens": batch,
            "prev_output_tokens": batch,
            "src_lengths": torch.tensor(
                [len(tok)] * batch_size, dtype=torch.long, device=batch.device
            ),
        },
        "target": batch[:, 1:],
    }
    return sample
def mk_transformer(**extra_args: Any):
    """Build a small, seeded TransformerModel for smoke tests.

    Keyword arguments override the baseline hyperparameters below. The merged
    namespace is then passed through ``transformer.tiny_architecture``
    (presumably filling in any architecture fields still unset -- confirm),
    the torch RNG is seeded with 0, and the model is built against a
    ``FakeTask``.
    """
    baseline = dict(
        # Small, distinct sizes for the embedding vs. FFN dimensions.
        encoder_embed_dim=12,
        encoder_ffn_embed_dim=14,
        decoder_embed_dim=12,
        decoder_ffn_embed_dim=14,
        # Every dropout-style knob is off so repeated runs are comparable.
        dropout=0,
        attention_dropout=0,
        activation_dropout=0,
        encoder_layerdrop=0,
    )
    # Caller-supplied values win over the baseline.
    namespace = argparse.Namespace(**{**baseline, **extra_args})
    transformer.tiny_architecture(namespace)
    # Seed immediately before construction so parameter init is reproducible.
    torch.manual_seed(0)
    return transformer.TransformerModel.build_model(
        namespace, FakeTask(namespace)
    )
class TransformerTestCase(unittest.TestCase):
    """Smoke tests: a tiny TransformerModel runs forward and backward."""

    def _check_forward_backward(self, model) -> None:
        """Run ``model`` on the default sample and backprop a dummy loss.

        Asserts that the output's leading dims match the decoder input
        ``(batch, tgt_len)``, that the summed loss is finite, and that
        ``backward()`` populated at least one parameter gradient.
        """
        net_input = mk_sample()["net_input"]
        # Call the module rather than ``.forward`` so nn.Module hooks run,
        # matching how the model is used outside tests.
        out, _ = model(**net_input)
        self.assertEqual(
            tuple(out.shape[:2]), tuple(net_input["prev_output_tokens"].shape)
        )
        loss = out.sum()
        self.assertTrue(torch.isfinite(loss).item(), "loss is not finite")
        loss.backward()
        self.assertTrue(
            any(p.grad is not None for p in model.parameters()),
            "backward() produced no parameter gradients",
        )

    def test_forward_backward(self):
        """Encoder and decoder share the same embedding dimension."""
        model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=12)
        self._check_forward_backward(model)

    def test_different_encoder_decoder_embed_dim(self):
        """Decoder embedding dimension differs from the encoder's."""
        model = mk_transformer(encoder_embed_dim=12, decoder_embed_dim=16)
        self._check_forward_backward(model)