File size: 4,146 Bytes
a1c5b19
eef5961
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1c5b19
 
 
 
eef5961
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1c5b19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
from transformers import PretrainedConfig

class RebornUASRConfig(PretrainedConfig):
    """
    Configuration of the reborn UASR model.

    The reborn UASR is composed of a segmenter, a discriminator, and a generator.
    Only the configurations required by the discriminator and the generator are
    taken over from fairseq's wav2vec-U model configuration.

    Every constructor argument is stored as an attribute of the same name,
    except ``dict_fpath``, which is only used to load ``phones``.

    Args:
        segmenter_*: settings of the segmenter (type, input/hidden dimension,
            dropout, kernel size).
        discriminator_*: settings of the discriminator.
        generator_*: settings of the generator.
        phones: list of phone symbols. Used only when ``dict_fpath`` is not an
            existing file. ``None`` (the default) means an empty list.
        dict_fpath: path to a phone dictionary file. When it points to an
            existing file, the phones are read from it and ``phones`` is ignored.
        special_token_nums: number of special tokens added on top of the phones
            when the generator output dimension is derived.
        **kwargs: forwarded to ``PretrainedConfig.__init__``.

    When the resulting phone list is non-empty, ``generator_output_dim`` is
    overridden with ``len(phones) + special_token_nums`` and
    ``discriminator_input_dim`` is set to that same value, regardless of the
    values passed for those two arguments.
    """
    model_type = "reborn_uasr"

    def __init__(self,
        segmenter_type: str = "cnn",
        segmenter_input_dim: int = 512,
        segmenter_hidden_dim: int = 512,
        segmenter_dropout: float = 0.1,
        segmenter_kernel_size: int = 7,

        discriminator_input_dim: int = 512,
        discriminator_kernel: int = 3,
        discriminator_dilation: int = 1,
        discriminator_dim: int = 256,
        discriminator_causal: bool = True,
        discriminator_linear_emb: bool = False,
        discriminator_depth: int = 1,
        discriminator_max_pool: bool = False,
        discriminator_act_after_linear: bool = False,
        discriminator_dropout: float = 0.0,
        discriminator_spectral_norm: bool = False,
        discriminator_weight_norm: bool = False,

        generator_input_dim: int = 512,
        generator_output_dim: int = 40,
        generator_kernel: int = 4,
        generator_dilation: int = 1,
        generator_stride: int = 1,
        generator_bias: bool = False,
        generator_dropout: float = 0.0,
        generator_bn_apply: bool = False,
        generator_bn_init_weight: float = 30.0,

        phones: list = None,
        dict_fpath: str = "",
        special_token_nums: int = 4, # [<s>, <pad>, </s>, <unk>]
        **kwargs
    ):
        super().__init__(**kwargs)
        # read in all the configurations
        self.segmenter_type = segmenter_type
        self.segmenter_input_dim = segmenter_input_dim
        self.segmenter_hidden_dim = segmenter_hidden_dim
        self.segmenter_dropout = segmenter_dropout
        self.segmenter_kernel_size = segmenter_kernel_size

        self.discriminator_input_dim = discriminator_input_dim
        self.discriminator_kernel = discriminator_kernel
        self.discriminator_dilation = discriminator_dilation
        self.discriminator_dim = discriminator_dim
        self.discriminator_causal = discriminator_causal
        self.discriminator_linear_emb = discriminator_linear_emb
        self.discriminator_depth = discriminator_depth
        self.discriminator_max_pool = discriminator_max_pool
        self.discriminator_act_after_linear = discriminator_act_after_linear
        self.discriminator_dropout = discriminator_dropout
        self.discriminator_spectral_norm = discriminator_spectral_norm
        self.discriminator_weight_norm = discriminator_weight_norm

        self.generator_input_dim = generator_input_dim
        self.generator_output_dim = generator_output_dim
        self.generator_kernel = generator_kernel
        self.generator_dilation = generator_dilation
        self.generator_stride = generator_stride
        self.generator_bias = generator_bias
        self.generator_dropout = generator_dropout
        self.generator_bn_apply = generator_bn_apply
        self.generator_bn_init_weight = generator_bn_init_weight

        self.special_token_nums = special_token_nums
        if os.path.isfile(dict_fpath):
            # An existing dictionary file takes precedence over `phones`.
            self.phones = self.read_phns_dict_from_fpath(dict_fpath)
        else:
            # Copy so the instance never aliases the caller's list (and no
            # list object is shared between default-constructed configs).
            self.phones = [] if phones is None else list(phones)
        if len(self.phones) > 0:
            # Derive the vocabulary-dependent dimensions from the phone set.
            self.generator_output_dim = len(self.phones) + self.special_token_nums
            self.discriminator_input_dim = self.generator_output_dim

    def read_phns_dict_from_fpath(self, fpath: str):
        """
        Read the phone symbols from a dictionary file.

        The phone is the first field of each line, where fields are split on a
        tab first and then on a space. Blank lines are skipped so they do not
        end up as an empty phone.

        Args:
            fpath: path of the dictionary file, read as UTF-8 text.

        Returns:
            The list of phones in file order.
        """
        phns = []
        with open(fpath, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                phns.append(line.split('\t')[0].split(' ')[0])
        return phns

def main(
    dict_fpath: str = "/home/andybi7676/Desktop/uasr-rl/data/ls_100h_new/text/prep/phones/dict.phn.txt",
    save_dir: str = "reborn_uasr",
):
    """
    Build a RebornUASRConfig from a phone dictionary, print it, and save it.

    Args:
        dict_fpath: path of the phone dictionary file handed to the config.
            The default is the original hard-coded location.
        save_dir: directory passed to ``save_pretrained``.
    """
    config = RebornUASRConfig(dict_fpath=dict_fpath)
    print(config)
    config.save_pretrained(save_dir)


if __name__ == "__main__":
    main()