File size: 2,523 Bytes
8aa4e50
b444210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0467843
b444210
 
 
0467843
b444210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8aa4e50
9ceb090
b444210
9ceb090
b444210
 
 
 
 
9ceb090
b444210
 
 
 
 
 
 
0467843
9ceb090
b444210
9ceb090
b444210
 
 
 
 
 
 
9ceb090
 
 
b444210
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# HuggingFace hub path of the pretrained SpeechBrain checkpoint bundle that
# the Pretrainer at the bottom of this file fetches from
# (wav2vec2.ckpt / model.ckpt / mBART.ckpt).
pretrained_path: HaNguyen/IWSLT-ast-w2v2-mbart

lang: fr   # for the BLEU score detokenization
target_lang: fr_XX   # for mbart initialization
sample_rate: 16000   # audio sample rate in Hz expected by the wav2vec 2.0 encoder


# URL for the HuggingFace model we want to load (BASE here)
wav2vec2_hub: LIA-AvignonUniversity/IWSLT2022-tamasheq-only

# wav2vec 2.0 specific parameters
wav2vec2_frozen: False   # False -> the wav2vec 2.0 encoder is fine-tuned

# Feature parameters (W2V2 etc)
features_dim: 768 # base wav2vec output dimension, for large replace by 1024

# projection for w2v: small DNN mapping encoder features to the mBART width
enc_dnn_layers: 1
enc_dnn_neurons: 1024

# Transformer
embedding_size: 256
d_model: 1024
activation: !name:torch.nn.GELU

# Outputs
blank_index: 1
label_smoothing: 0.1
pad_index: 1      # pad_index defined by mbart model
bos_index: 250008 # fr_XX bos_index defined by mbart model
eos_index: 2      # eos_index defined by mbart model

# Decoding parameters
# Be sure that the bos and eos index match with the BPEs ones
min_decode_ratio: 0.0
max_decode_ratio: 0.25   # max output length as a fraction of the input length
valid_beam_size: 5
test_beam_size: 5


############################## models ################################
# wav2vec model: HuggingFace wav2vec 2.0 encoder wrapped by SpeechBrain.
# Weights are downloaded from <wav2vec2_hub> and cached in save_path.
wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
  source: !ref <wav2vec2_hub>
  output_norm: True
  freeze: !ref <wav2vec2_frozen>
  save_path: wav2vec2_checkpoint

# linear projection: maps wav2vec 2.0 features (<features_dim>) to the mBART
# input width (<enc_dnn_neurons>) with a small DNN.
# NOTE: previously hard-coded 768/1/1024 here, which silently ignored the
# features_dim / enc_dnn_layers / enc_dnn_neurons hyperparameters declared
# above; now wired through !ref so editing them takes effect.
enc: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
  input_shape: [null, null, !ref <features_dim>]
  activation: !ref <activation>
  dnn_blocks: !ref <enc_dnn_layers>
  dnn_neurons: !ref <enc_dnn_neurons>

# mBART: pretrained multilingual seq2seq decoder (mbart-large-50).
mbart_path: facebook/mbart-large-50-many-to-many-mmt
mbart_frozen: False   # False -> mBART is fine-tuned
# NOTE: removed the serializer leftover anchor `&id004` — no `*id004` alias
# exists anywhere in this file, so it was dead syntax.
mBART: !new:speechbrain.lobes.models.huggingface_transformers.mbart.mBART
  source: !ref <mbart_path>
  freeze: !ref <mbart_frozen>
  save_path: mbart_checkpoint
  target_lang: !ref <target_lang>

# Log-softmax applied over decoder logits (apply_log: True -> log-probabilities).
log_softmax: !new:speechbrain.nnet.activations.Softmax
  apply_log: True

# Identity in place of an output linear layer — presumably the mBART head
# already produces vocabulary-size logits; verify against the training recipe.
seq_lin: !new:torch.nn.Identity

# Modules dictionary handed to the Brain class (moved to device, train/eval
# toggled together).
modules:
  wav2vec2: !ref <wav2vec2>
  enc: !ref <enc>
  mBART: !ref <mBART>

# Parameter container matching model.ckpt — only the projection DNN;
# wav2vec2 and mBART have their own checkpoint files.
# (Sequence now indented to match the file's 2-space block style.)
model: !new:torch.nn.ModuleList
  - [!ref <enc>]

# Beam search over the mBART decoder for validation/inference.
# NOTE: bos/eos indices, decode ratios and beam size were duplicated here as
# literals; now !ref'd to the declared decoding parameters above so the two
# cannot drift apart.
valid_search: !new:speechbrain.decoders.S2SHFTextBasedBeamSearcher
  modules: [!ref <mBART>, null, null]
  vocab_size: 250054   # mbart-large-50 vocabulary size
  bos_index: !ref <bos_index>
  eos_index: !ref <eos_index>
  min_decode_ratio: !ref <min_decode_ratio>
  max_decode_ratio: !ref <max_decode_ratio>
  beam_size: !ref <valid_beam_size>
  using_eos_threshold: True
  length_normalization: True

# Pretrainer: downloads each checkpoint listed under `paths` from
# <pretrained_path> and loads it into the matching object under `loadables`.
# (Indentation normalized to the file's 2-space convention; previously mixed
# 4-/6-/8-space — semantically identical, just consistent.)
pretrainer: !new:speechbrain.utils.parameter_transfer.Pretrainer
  loadables:
    model: !ref <model>
    wav2vec2: !ref <wav2vec2>
    mBART: !ref <mBART>
  paths:
    wav2vec2: !ref <pretrained_path>/wav2vec2.ckpt
    model: !ref <pretrained_path>/model.ckpt
    mBART: !ref <pretrained_path>/mBART.ckpt