OFA-OCR / fairseq /tests /speech_recognition /test_vggtransformer.py
JustinLin610's picture
first commit
ee21b96
raw
history blame
4.58 kB
#!/usr/bin/env python3
# import models/encoder/decoder to be tested
from examples.speech_recognition.models.vggtransformer import (
TransformerDecoder,
VGGTransformerEncoder,
VGGTransformerModel,
vggtransformer_1,
vggtransformer_2,
vggtransformer_base,
)
# import base test class
from .asr_test_base import (
DEFAULT_TEST_VOCAB_SIZE,
TestFairseqDecoderBase,
TestFairseqEncoderBase,
TestFairseqEncoderDecoderModelBase,
get_dummy_dictionary,
get_dummy_encoder_output,
get_dummy_input,
)
class VGGTransformerModelTest_mid(TestFairseqEncoderDecoderModelBase):
    """Encoder-decoder model test driven by the ``vggtransformer_1`` setter."""

    def setUp(self):
        """Build a shrunken ``vggtransformer_1`` model and a dummy input batch."""
        super().setUp()

        def shrink_encoder(args):
            """
            vggtransformer_1 uses 14 layers of transformer, which is too
            expensive for testing purposes. For fast turn-around tests,
            reduce the number of layers to 3.
            """
            args.transformer_enc_config = (
                "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
            )

        # The architecture setter is listed first and the override second;
        # presumably they are applied in list order -- confirm in
        # asr_test_base.
        self.setUpModel(VGGTransformerModel, [vggtransformer_1, shrink_encoder])

        dummy_batch = get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE)
        self.setUpInput(dummy_batch)
class VGGTransformerModelTest_big(TestFairseqEncoderDecoderModelBase):
    """Encoder-decoder model test driven by the ``vggtransformer_2`` setter."""

    def setUp(self):
        """Build a shrunken ``vggtransformer_2`` model and a dummy input batch."""
        super().setUp()

        def shrink_encoder(args):
            """
            vggtransformer_2 uses 16 layers of transformer, which is too
            expensive for testing purposes. For fast turn-around tests,
            reduce the number of layers to 3.
            """
            args.transformer_enc_config = (
                "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
            )

        # The architecture setter is listed first and the override second;
        # presumably they are applied in list order -- confirm in
        # asr_test_base.
        self.setUpModel(VGGTransformerModel, [vggtransformer_2, shrink_encoder])

        dummy_batch = get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE)
        self.setUpInput(dummy_batch)
class VGGTransformerModelTest_base(TestFairseqEncoderDecoderModelBase):
    """Encoder-decoder model test driven by the ``vggtransformer_base`` setter."""

    def setUp(self):
        """Build a shrunken ``vggtransformer_base`` model and a dummy input batch."""
        super().setUp()

        def shrink_encoder(args):
            """
            vggtransformer_base uses 12 layers of transformer, which is too
            expensive for testing purposes. For fast turn-around tests,
            reduce the number of layers to 3.
            """
            args.transformer_enc_config = (
                "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 3"
            )

        # The architecture setter is listed first and the override second;
        # presumably they are applied in list order -- confirm in
        # asr_test_base.
        self.setUpModel(VGGTransformerModel, [vggtransformer_base, shrink_encoder])

        dummy_batch = get_dummy_input(T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE)
        self.setUpInput(dummy_batch)
class VGGTransformerEncoderTest(TestFairseqEncoderBase):
    """Forward-pass checks for ``VGGTransformerEncoder`` in several setups."""

    def setUp(self):
        """Run the base fixture setup and install the shared dummy input."""
        super().setUp()
        self.setUpInput(get_dummy_input(T=50, D=80, B=5))

    def test_forward(self):
        """Run the inherited forward test once per encoder configuration.

        Each scenario prints its banner, installs a freshly built encoder
        through ``setUpEncoder`` and then delegates to the base class'
        ``test_forward``.

        Previously the fifth scenario (windowed context with sampling) only
        built its encoder and never ran the inherited check, so that
        configuration was silently untested. Driving all scenarios from one
        table guarantees each of them is actually exercised.
        """
        # (banner, extra keyword arguments for VGGTransformerEncoder)
        scenarios = (
            ("1. test standard vggtransformer", {}),
            (
                "2. test vggtransformer with limited right context",
                {"transformer_context": (-1, 5)},
            ),
            (
                "3. test vggtransformer with limited left context",
                {"transformer_context": (5, -1)},
            ),
            (
                "4. test vggtransformer with limited right context and sampling",
                {
                    "transformer_context": (-1, 12),
                    "transformer_sampling": (2, 2),
                },
            ),
            (
                "5. test vggtransformer with windowed context and sampling",
                {
                    "transformer_context": (12, 12),
                    "transformer_sampling": (2, 2),
                },
            ),
        )

        for banner, encoder_kwargs in scenarios:
            print(banner)
            # Encoders are built one at a time, right before they are tested,
            # matching the original construct-then-test ordering.
            self.setUpEncoder(
                VGGTransformerEncoder(input_feat_per_channel=80, **encoder_kwargs)
            )
            super().test_forward()
class TransformerDecoderTest(TestFairseqDecoderBase):
    """Decoder test fixture for ``TransformerDecoder`` with a dummy vocabulary."""

    def setUp(self):
        """Create the decoder, a dummy encoder output and previous output tokens."""
        super().setUp()

        # Build everything first, in the same order as before, then register
        # it with the base-class fixture helpers.
        vocab = get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE)
        decoder_under_test = TransformerDecoder(vocab)
        encoder_output = get_dummy_encoder_output(encoder_out_shape=(50, 5, 256))

        self.setUpDecoder(decoder_under_test)
        self.setUpInput(encoder_output)
        self.setUpPrevOutputTokens()