menevsem commited on
Commit
3652f8d
·
1 Parent(s): e64c9be

Create params.yaml

Browse files

This parameters are used while training wav2vec model on NMSQA dataset.

Files changed (1) hide show
  1. params.yaml +44 -0
params.yaml ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ mode: "train"
2
+ pretrained_model_name: "facebook/wav2vec2-base-960h"
3
+ freeze_pretrained: True
4
+ output_dir: "out/model_asr_base_alldata"
5
+ dataset_script: "dataset/asr_dataset.py"
6
+ part_name: "nmsqa_all_asr"
7
+ window_size: 321
8
+ learning_rate: 0.0002
9
+ group_by_length: True
10
+ per_device_train_batch_size: 2
11
+ num_train_epochs: 10
12
+ fp16: True
13
+ save_steps: 1000
14
+ logging_steps: 500
15
+ warmup_steps: 500
16
+ weight_decay: 0.0005
17
+ load_best_model_at_end: True
18
+ metric_for_best_model: wer
19
+ greater_is_better: True
20
+ save_total_limit: 2
21
+ eval: True
22
+ per_device_eval_batch_size: 2
23
+ eval_steps: 1000
24
+ evaluation_strategy: "steps"
25
+ per_device_test_batch_size: 2
26
+ shuffle: False
27
+ attention_dropout: 0.0
28
+ hidden_dropout: 0.0
29
+ feat_proj_dropout: 0.0
30
+ mask_time_prob: 0.05
31
+ layer_dropout: 0.0
32
+ ctc_loss_reduction: "mean"
33
+ unk: "[UNK]"
34
+ pad: "[PAD]"
35
+ word_delimited: "|"
36
+ sampling_rate: 16000
37
+ padding_value: 0.0
38
+ feature_size: 1
39
+ normalize: True
40
+ gradient_accumulation_steps: 2
41
+ gradient_checkpointing: True
42
+ max_answer_length : 500
43
+ test_weight: 0.6
44
+ eval_part_name: validation