---
# Training configuration for fine-tuning a wav2vec2 CTC model for ASR.
# NOTE(review): key names follow Hugging Face `TrainingArguments` /
# `Wav2Vec2Config` conventions; how each key is consumed is decided by the
# training script, which is not visible here -- confirm against the loader.

# --- Run / model / data -----------------------------------------------------
mode: "train"

pretrained_model_name: "facebook/wav2vec2-base-960h"

freeze_pretrained: true

output_dir: "out/model_asr_base_alldata"

dataset_script: "dataset/asr_dataset.py"

part_name: "nmsqa_all_asr"

window_size: 321

# --- Optimisation -----------------------------------------------------------
learning_rate: 0.0002

group_by_length: true

per_device_train_batch_size: 2

num_train_epochs: 10

fp16: true

save_steps: 1000

logging_steps: 500

warmup_steps: 500

weight_decay: 0.0005

# --- Best-checkpoint selection ----------------------------------------------
load_best_model_at_end: true

metric_for_best_model: wer

# WER is an error rate: a LOWER value is a better model. This must be false,
# otherwise "best model" selection keeps the checkpoint with the worst WER.
# (Assumes `wer` here is word error rate -- confirm against compute_metrics.)
greater_is_better: false

save_total_limit: 2

# --- Evaluation / test ------------------------------------------------------
eval: true

per_device_eval_batch_size: 2

# Kept equal to save_steps so every saved checkpoint has an eval score.
eval_steps: 1000

evaluation_strategy: "steps"

per_device_test_batch_size: 2

shuffle: false

# --- Model regularisation / CTC ---------------------------------------------
attention_dropout: 0.0

hidden_dropout: 0.0

feat_proj_dropout: 0.0

mask_time_prob: 0.05

layer_dropout: 0.0

ctc_loss_reduction: "mean"

# --- Tokenizer special tokens -----------------------------------------------
unk: "[UNK]"

pad: "[PAD]"

# The pipe is the literal value here; it must stay quoted, since a bare `|`
# is YAML's literal block-scalar indicator.
word_delimited: "|"

# --- Feature extractor ------------------------------------------------------
sampling_rate: 16000

padding_value: 0.0

feature_size: 1

normalize: true

# --- Memory / misc ----------------------------------------------------------
gradient_accumulation_steps: 2

gradient_checkpointing: true

max_answer_length: 500

test_weight: 0.6

eval_part_name: validation