Spaces:
Runtime error
Runtime error
File size: 4,578 Bytes
10b0761 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
#!/usr/bin/env python3
# import models/encoder/decoder to be tested
from examples.speech_recognition.models.vggtransformer import (
TransformerDecoder,
VGGTransformerEncoder,
VGGTransformerModel,
vggtransformer_1,
vggtransformer_2,
vggtransformer_base,
)
# import base test class
from .asr_test_base import (
DEFAULT_TEST_VOCAB_SIZE,
TestFairseqDecoderBase,
TestFairseqEncoderBase,
TestFairseqEncoderDecoderModelBase,
get_dummy_dictionary,
get_dummy_encoder_output,
get_dummy_input,
)
class VGGTransformerModelTest_mid(TestFairseqEncoderDecoderModelBase):
    """Encoder-decoder model tests for the ``vggtransformer_1`` architecture.

    The encoder's transformer stack is cut down to 3 layers so the test
    turns around quickly.
    """

    def setUp(self):
        super().setUp()

        def shrink_encoder(args):
            """Replace the encoder config with 3 copies of the same layer.

            The stock vggtransformer_1 encoder uses 14 transformer layers,
            which is too expensive for a fast unit test.
            """
            args.transformer_enc_config = (
                "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
            )

        # The override is listed after the architecture function; presumably
        # setUpModel applies the setters in order, so the override must stay last.
        arch_setters = [vggtransformer_1, shrink_encoder]
        self.setUpModel(VGGTransformerModel, arch_setters)

        dummy_input = get_dummy_input(
            T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE
        )
        self.setUpInput(dummy_input)
class VGGTransformerModelTest_big(TestFairseqEncoderDecoderModelBase):
    """Encoder-decoder model tests for the ``vggtransformer_2`` architecture.

    The encoder's transformer stack is cut down to 3 layers so the test
    turns around quickly.
    """

    def setUp(self):
        super().setUp()

        def shrink_encoder(args):
            """Replace the encoder config with 3 copies of the same layer.

            The stock vggtransformer_2 encoder uses 16 transformer layers,
            which is too expensive for a fast unit test.
            """
            args.transformer_enc_config = (
                "((1024, 16, 4096, True, 0.15, 0.15, 0.15),) * 3"
            )

        # The override is listed after the architecture function; presumably
        # setUpModel applies the setters in order, so the override must stay last.
        arch_setters = [vggtransformer_2, shrink_encoder]
        self.setUpModel(VGGTransformerModel, arch_setters)

        dummy_input = get_dummy_input(
            T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE
        )
        self.setUpInput(dummy_input)
class VGGTransformerModelTest_base(TestFairseqEncoderDecoderModelBase):
    """Encoder-decoder model tests for the ``vggtransformer_base`` architecture.

    The encoder's transformer stack is cut down to 3 layers so the test
    turns around quickly.
    """

    def setUp(self):
        super().setUp()

        def shrink_encoder(args):
            """Replace the encoder config with 3 copies of the same layer.

            The stock vggtransformer_base encoder uses 12 transformer layers,
            which is too expensive for a fast unit test.
            """
            args.transformer_enc_config = (
                "((512, 8, 2048, True, 0.15, 0.15, 0.15),) * 3"
            )

        # The override is listed after the architecture function; presumably
        # setUpModel applies the setters in order, so the override must stay last.
        arch_setters = [vggtransformer_base, shrink_encoder]
        self.setUpModel(VGGTransformerModel, arch_setters)

        dummy_input = get_dummy_input(
            T=50, D=80, B=5, K=DEFAULT_TEST_VOCAB_SIZE
        )
        self.setUpInput(dummy_input)
class VGGTransformerEncoderTest(TestFairseqEncoderBase):
    """Forward-pass tests for ``VGGTransformerEncoder``.

    Each case builds a fresh encoder with a different context / sampling
    configuration and runs the base-class ``test_forward`` check on it.
    """

    def setUp(self):
        super().setUp()
        self.setUpInput(get_dummy_input(T=50, D=80, B=5))

    def test_forward(self):
        """Run the base forward check once per encoder configuration."""
        # (label, extra VGGTransformerEncoder kwargs). Going by the labels,
        # transformer_context looks like (left, right) with -1 meaning
        # unbounded -- NOTE(review): confirm against VGGTransformerEncoder.
        cases = [
            ("1. test standard vggtransformer", {}),
            (
                "2. test vggtransformer with limited right context",
                {"transformer_context": (-1, 5)},
            ),
            (
                "3. test vggtransformer with limited left context",
                {"transformer_context": (5, -1)},
            ),
            (
                "4. test vggtransformer with limited right context and sampling",
                {"transformer_context": (-1, 12), "transformer_sampling": (2, 2)},
            ),
            (
                "5. test vggtransformer with windowed context and sampling",
                {"transformer_context": (12, 12), "transformer_sampling": (2, 2)},
            ),
        ]
        for label, extra_kwargs in cases:
            print(label)
            self.setUpEncoder(
                VGGTransformerEncoder(input_feat_per_channel=80, **extra_kwargs)
            )
            # Every configuration, including the last one, must actually be
            # exercised; building the encoder alone checks nothing.
            super().test_forward()
class TransformerDecoderTest(TestFairseqDecoderBase):
    """Decoder tests for ``TransformerDecoder`` driven by a dummy encoder output."""

    def setUp(self):
        super().setUp()
        # Named ``dictionary`` rather than ``dict`` so the builtin is not shadowed.
        dictionary = get_dummy_dictionary(vocab_size=DEFAULT_TEST_VOCAB_SIZE)
        decoder = TransformerDecoder(dictionary)
        # NOTE(review): shape presumably (T=50, B=5, C=256), matching the
        # T/B used by the other tests in this file -- confirm the axis order.
        dummy_encoder_output = get_dummy_encoder_output(
            encoder_out_shape=(50, 5, 256)
        )
        self.setUpDecoder(decoder)
        self.setUpInput(dummy_encoder_output)
        self.setUpPrevOutputTokens()
|