Commit afb3566 (0 parents): Initial model

Files changed:
- .gitattributes +1 -0
- Fine_Tune_XLS_R_on_Common_Voice_sr_300m_CV8.ipynb +0 -0
- README.md +132 -0
- added_tokens.json +1 -0
- config.json +108 -0
- eval.py +186 -0
- preprocessor_config.json +10 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +1 -0
- tokenizer_config.json +1 -0
- vocab.json +1 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
+pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
Fine_Tune_XLS_R_on_Common_Voice_sr_300m_CV8.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
README.md
ADDED
@@ -0,0 +1,133 @@
+---
+language:
+- sr
+license: apache-2.0
+tags:
+- automatic-speech-recognition
+- mozilla-foundation/common_voice_8_0
+- generated_from_trainer
+- robust-speech-event
+- xlsr-fine-tuning-week
+datasets:
+- common_voice
+model-index:
+- name: Serbian comodoro Wav2Vec2 XLSR 300M CV8
+  results:
+  - task:
+      name: Automatic Speech Recognition
+      type: automatic-speech-recognition
+    dataset:
+      name: Common Voice 8
+      type: mozilla-foundation/common_voice_8_0
+      args: sr
+    metrics:
+    - name: Test WER
+      type: wer
+      value: 48.3
+    - name: Test CER
+      type: cer
+      value: 18.5
+---
+
+<!-- This model card has been generated automatically according to the information the Trainer had access to. You
+should probably proofread and complete it, then remove this comment. -->
+
+# Serbian wav2vec2-xls-r-300m-sr-cv8
+
+This model is a fine-tuned version of [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) on the common_voice dataset.
+It achieves the following results on the evaluation set:
+- Loss: 1.7302
+- Wer: 0.4825
+- Cer: 0.1847
+
+## Model description
+
+More information needed
+
+## Intended uses & limitations
+
+More information needed
+
+## Training and evaluation data
+
+More information needed
+
+## Training procedure
+
+### Training hyperparameters
+
+The following hyperparameters were used during training:
+- learning_rate: 0.0001
+- train_batch_size: 16
+- eval_batch_size: 8
+- seed: 42
+- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
+- lr_scheduler_type: linear
+- lr_scheduler_warmup_steps: 300
+- num_epochs: 800
+- mixed_precision_training: Native AMP
+
+### Training results
+
+| Training Loss | Epoch | Step  | Validation Loss | Wer    | Cer    |
+|:-------------:|:-----:|:-----:|:---------------:|:------:|:------:|
+| 5.6536        | 15.0  | 1200  | 2.9744          | 1.0    | 1.0    |
+| 2.7935        | 30.0  | 2400  | 1.6613          | 0.8998 | 0.4670 |
+| 1.6538        | 45.0  | 3600  | 0.9248          | 0.6918 | 0.2699 |
+| 1.2446        | 60.0  | 4800  | 0.9151          | 0.6452 | 0.2398 |
+| 1.0766        | 75.0  | 6000  | 0.9110          | 0.5995 | 0.2207 |
+| 0.9548        | 90.0  | 7200  | 1.0273          | 0.5921 | 0.2149 |
+| 0.8919        | 105.0 | 8400  | 0.9929          | 0.5646 | 0.2117 |
+| 0.8185        | 120.0 | 9600  | 1.0850          | 0.5483 | 0.2069 |
+| 0.7692        | 135.0 | 10800 | 1.1001          | 0.5394 | 0.2055 |
+| 0.7249        | 150.0 | 12000 | 1.1018          | 0.5380 | 0.1958 |
+| 0.6786        | 165.0 | 13200 | 1.1344          | 0.5114 | 0.1941 |
+| 0.6432        | 180.0 | 14400 | 1.1516          | 0.5054 | 0.1905 |
+| 0.6009        | 195.0 | 15600 | 1.3149          | 0.5324 | 0.1991 |
+| 0.5773        | 210.0 | 16800 | 1.2468          | 0.5124 | 0.1903 |
+| 0.559         | 225.0 | 18000 | 1.2186          | 0.4956 | 0.1922 |
+| 0.5298        | 240.0 | 19200 | 1.4483          | 0.5333 | 0.2085 |
+| 0.5136        | 255.0 | 20400 | 1.2871          | 0.4802 | 0.1846 |
+| 0.4824        | 270.0 | 21600 | 1.2891          | 0.4974 | 0.1885 |
+| 0.4669        | 285.0 | 22800 | 1.3283          | 0.4942 | 0.1878 |
+| 0.4511        | 300.0 | 24000 | 1.4502          | 0.5002 | 0.1994 |
+| 0.4337        | 315.0 | 25200 | 1.4714          | 0.5035 | 0.1911 |
+| 0.4221        | 330.0 | 26400 | 1.4971          | 0.5124 | 0.1962 |
+| 0.3994        | 345.0 | 27600 | 1.4473          | 0.5007 | 0.1920 |
+| 0.3892        | 360.0 | 28800 | 1.3904          | 0.4937 | 0.1887 |
+| 0.373         | 375.0 | 30000 | 1.4971          | 0.4946 | 0.1902 |
+| 0.3657        | 390.0 | 31200 | 1.4208          | 0.4900 | 0.1821 |
+| 0.3559        | 405.0 | 32400 | 1.4648          | 0.4895 | 0.1835 |
+| 0.3476        | 420.0 | 33600 | 1.4848          | 0.4946 | 0.1829 |
+| 0.3276        | 435.0 | 34800 | 1.5597          | 0.4979 | 0.1873 |
+| 0.3193        | 450.0 | 36000 | 1.7329          | 0.5040 | 0.1980 |
+| 0.3078        | 465.0 | 37200 | 1.6379          | 0.4937 | 0.1882 |
+| 0.3058        | 480.0 | 38400 | 1.5878          | 0.4942 | 0.1921 |
+| 0.2987        | 495.0 | 39600 | 1.5590          | 0.4811 | 0.1846 |
+| 0.2931        | 510.0 | 40800 | 1.6001          | 0.4825 | 0.1849 |
+| 0.276         | 525.0 | 42000 | 1.7388          | 0.4942 | 0.1918 |
+| 0.2702        | 540.0 | 43200 | 1.7037          | 0.4839 | 0.1866 |
+| 0.2619        | 555.0 | 44400 | 1.6704          | 0.4755 | 0.1840 |
+| 0.262         | 570.0 | 45600 | 1.6042          | 0.4751 | 0.1865 |
+| 0.2528        | 585.0 | 46800 | 1.6402          | 0.4821 | 0.1865 |
+| 0.2442        | 600.0 | 48000 | 1.6693          | 0.4886 | 0.1862 |
+| 0.244         | 615.0 | 49200 | 1.6203          | 0.4765 | 0.1792 |
+| 0.2388        | 630.0 | 50400 | 1.6829          | 0.4830 | 0.1828 |
+| 0.2362        | 645.0 | 51600 | 1.8100          | 0.4928 | 0.1888 |
+| 0.2224        | 660.0 | 52800 | 1.7746          | 0.4932 | 0.1899 |
+| 0.2218        | 675.0 | 54000 | 1.7752          | 0.4946 | 0.1901 |
+| 0.2201        | 690.0 | 55200 | 1.6775          | 0.4788 | 0.1844 |
+| 0.2147        | 705.0 | 56400 | 1.7085          | 0.4844 | 0.1851 |
+| 0.2103        | 720.0 | 57600 | 1.7624          | 0.4848 | 0.1864 |
+| 0.2101        | 735.0 | 58800 | 1.7213          | 0.4783 | 0.1835 |
+| 0.1983        | 750.0 | 60000 | 1.7452          | 0.4848 | 0.1856 |
+| 0.2015        | 765.0 | 61200 | 1.7525          | 0.4872 | 0.1869 |
+| 0.1969        | 780.0 | 62400 | 1.7443          | 0.4844 | 0.1852 |
+| 0.2043        | 795.0 | 63600 | 1.7302          | 0.4825 | 0.1847 |
+
+
+### Framework versions
+
+- Transformers 4.16.2
+- Pytorch 1.10.1+cu102
+- Datasets 1.18.3
+- Tokenizers 0.11.0
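For quick use of the checkpoint added in this commit, here is a minimal inference sketch with the 🤗 Transformers pipeline. The repository ID `comodoro/wav2vec2-xls-r-300m-sr-cv8` and the audio file name are assumptions, not taken from the card; the chunking values mirror the suggestions in eval.py below.

```python
# Minimal inference sketch. Assumptions: the repo ID and sample audio file
# below are illustrative, not confirmed by this commit.
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="comodoro/wav2vec2-xls-r-300m-sr-cv8",  # assumed repository ID
)

# chunk/stride values follow the defaults suggested in eval.py for long files
print(asr("sample_serbian.wav", chunk_length_s=5.0, stride_length_s=1.0)["text"])
```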
added_tokens.json
ADDED
@@ -0,0 +1 @@
+{"<s>": 33, "</s>": 34}
config.json
ADDED
@@ -0,0 +1,108 @@
+{
+  "_name_or_path": "facebook/wav2vec2-xls-r-300m",
+  "activation_dropout": 0.0,
+  "adapter_kernel_size": 3,
+  "adapter_stride": 2,
+  "add_adapter": false,
+  "apply_spec_augment": true,
+  "architectures": [
+    "Wav2Vec2ForCTC"
+  ],
+  "attention_dropout": 0.2,
+  "bos_token_id": 1,
+  "classifier_proj_size": 256,
+  "codevector_dim": 768,
+  "contrastive_logits_temperature": 0.1,
+  "conv_bias": true,
+  "conv_dim": [
+    512,
+    512,
+    512,
+    512,
+    512,
+    512,
+    512
+  ],
+  "conv_kernel": [
+    10,
+    3,
+    3,
+    3,
+    3,
+    2,
+    2
+  ],
+  "conv_stride": [
+    5,
+    2,
+    2,
+    2,
+    2,
+    2,
+    2
+  ],
+  "ctc_loss_reduction": "mean",
+  "ctc_zero_infinity": false,
+  "diversity_loss_weight": 0.1,
+  "do_stable_layer_norm": true,
+  "eos_token_id": 2,
+  "feat_extract_activation": "gelu",
+  "feat_extract_dropout": 0.0,
+  "feat_extract_norm": "layer",
+  "feat_proj_dropout": 0.1,
+  "feat_quantizer_dropout": 0.0,
+  "final_dropout": 0.0,
+  "gradient_checkpointing": false,
+  "hidden_act": "gelu",
+  "hidden_dropout": 0.2,
+  "hidden_size": 1024,
+  "initializer_range": 0.02,
+  "intermediate_size": 4096,
+  "layer_norm_eps": 1e-05,
+  "layerdrop": 0.4,
+  "mask_feature_length": 10,
+  "mask_feature_min_masks": 0,
+  "mask_feature_prob": 0.0,
+  "mask_time_length": 10,
+  "mask_time_min_masks": 2,
+  "mask_time_prob": 0.5,
+  "model_type": "wav2vec2",
+  "num_adapter_layers": 3,
+  "num_attention_heads": 16,
+  "num_codevector_groups": 2,
+  "num_codevectors_per_group": 320,
+  "num_conv_pos_embedding_groups": 16,
+  "num_conv_pos_embeddings": 128,
+  "num_feat_extract_layers": 7,
+  "num_hidden_layers": 24,
+  "num_negatives": 100,
+  "output_hidden_size": 1024,
+  "pad_token_id": 32,
+  "proj_codevector_dim": 768,
+  "tdnn_dilation": [
+    1,
+    2,
+    3,
+    1,
+    1
+  ],
+  "tdnn_dim": [
+    512,
+    512,
+    512,
+    512,
+    1500
+  ],
+  "tdnn_kernel": [
+    5,
+    3,
+    3,
+    1,
+    1
+  ],
+  "torch_dtype": "float32",
+  "transformers_version": "4.16.2",
+  "use_weighted_layer_sum": false,
+  "vocab_size": 35,
+  "xvector_output_dim": 512
+}
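Two figures in this config are worth decoding: the seven `conv_stride` values multiply to 320, so the feature encoder emits one CTC frame per 320 samples (20 ms at the 16 kHz rate set in preprocessor_config.json, i.e. 50 frames per second), and `vocab_size` 35 is the 33 entries of vocab.json plus the two tokens in added_tokens.json. A small sketch, assuming the repo files are checked out in the working directory:

```python
# Sketch: derive the CTC frame rate and vocab size from this config.
import math

from transformers import Wav2Vec2Config

config = Wav2Vec2Config.from_pretrained(".")  # assumes config.json is local

hop = math.prod(config.conv_stride)           # 5*2*2*2*2*2*2 = 320 samples/frame
print(16000 // hop, "CTC frames per second")  # -> 50
print(config.vocab_size)                      # -> 35 (33 vocab + <s>, </s>)
```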
eval.py
ADDED
@@ -0,0 +1,186 @@
+#!/usr/bin/env python3
+from datasets import load_dataset, load_metric, Audio, Dataset
+from transformers import pipeline, AutoFeatureExtractor, AutoTokenizer, Wav2Vec2ForCTC
+import os
+import re
+import argparse
+import unicodedata
+from typing import Dict
+
+
+def log_results(result: Dataset, args: Dict[str, str]):
+    """ DO NOT CHANGE. This function computes and logs the result metrics. """
+
+    log_outputs = args.log_outputs
+    dataset_id = "_".join(args.dataset.split("/") + [args.config, args.split])
+
+    # load metrics
+    wer = load_metric("wer")
+    cer = load_metric("cer")
+
+    # compute metrics
+    wer_result = wer.compute(references=result["target"], predictions=result["prediction"])
+    cer_result = cer.compute(references=result["target"], predictions=result["prediction"])
+
+    # print & log results
+    result_str = (
+        f"WER: {wer_result}\n"
+        f"CER: {cer_result}"
+    )
+    print(result_str)
+
+    with open(f"{dataset_id}_eval_results.txt", "w") as f:
+        f.write(result_str)
+
+    # log all results in a text file. Possibly interesting for analysis
+    if log_outputs:
+        pred_file = f"log_{dataset_id}_predictions.txt"
+        target_file = f"log_{dataset_id}_targets.txt"
+
+        with open(pred_file, "w") as p, open(target_file, "w") as t:
+
+            # mapping function to write output
+            def write_to_file(batch, i):
+                p.write(f"{i}" + "\n")
+                p.write(batch["prediction"] + "\n")
+                t.write(f"{i}" + "\n")
+                t.write(batch["target"] + "\n")
+
+            result.map(write_to_file, with_indices=True)
+
+
+def normalize_text(text: str) -> str:
+    """ DO ADAPT FOR YOUR USE CASE. This function normalizes the target text. """
+
+    # map foreign Latin characters to close equivalents
+    CHARS = {
+        'ü': 'ue',
+        'ö': 'oe',
+        'ï': 'i',
+        'ë': 'e',
+        'ä': 'ae',
+        'ã': 'a',
+        'à': 'á',
+        'ø': 'o',
+        'è': 'é',
+        'ê': 'é',
+        'å': 'ó',
+        'î': 'i',
+        'ñ': 'ň',
+        'ç': 's',
+        'ľ': 'l',
+        'ż': 'ž',
+        'ł': 'w',
+        'ć': 'č',
+        'þ': 't',
+        'ß': 'ss',
+        'ę': 'en',
+        'ą': 'an',
+        'æ': 'ae',
+    }
+
+    def replace_chars(sentence):
+        result = ''
+        for ch in sentence:
+            result += CHARS[ch] if ch in CHARS else ch
+
+        return result
+
+    chars_to_ignore_regex = '[\,\?\.\!\-\;\:\/\"\“\„\%\”\�\–\'\`\«\»\—\’\…]'
+
+    text = text.lower()
+    # normalize non-standard (stylized) unicode characters
+    text = unicodedata.normalize('NFKC', text)
+    # remove punctuation
+    text = re.sub(chars_to_ignore_regex, "", text)
+    text = replace_chars(text)
+
+    # Let's also make sure we split on all kinds of newlines, spaces, etc...
+    text = " ".join(text.split())
+
+    return text
+
+
+def main(args):
+    # load dataset
+    dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)
+
+    # for testing: only process the first --limit examples
+    if args.limit:
+        dataset = dataset.select(range(args.limit))
+
+    asr = None
+    feature_extractor = None
+
+    if not args.model_id and not args.path:
+        raise RuntimeError('No model given!')
+
+    if not args.model_id:
+        model = Wav2Vec2ForCTC.from_pretrained(args.path)
+        tokenizer = AutoTokenizer.from_pretrained(args.path)
+        feature_extractor = AutoFeatureExtractor.from_pretrained(args.path)
+
+        # load eval pipeline from local files
+        asr = pipeline("automatic-speech-recognition", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)
+    else:
+        feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_id)
+        asr = pipeline("automatic-speech-recognition", model=args.model_id)
+
+    # map function to decode audio
+    def map_to_pred(batch):
+        prediction = asr(batch["audio"]["array"], chunk_length_s=args.chunk_length_s, stride_length_s=args.stride_length_s)
+
+        batch["prediction"] = prediction["text"]
+        batch["target"] = normalize_text(batch["sentence"])
+        return batch
+
+    # the feature extractor defines the expected sampling rate
+    sampling_rate = feature_extractor.sampling_rate
+
+    # resample audio
+    dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
+
+    # run inference on all examples
+    result = dataset.map(map_to_pred, remove_columns=dataset.column_names)
+
+    # compute and log results; do not change the function below
+    log_results(result, args)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--model_id", type=str, default='', help="Model identifier. Should be loadable with 🤗 Transformers"
+    )
+    parser.add_argument(
+        "--dataset", type=str, required=True, help="Dataset name to evaluate the model on. Should be loadable with 🤗 Datasets"
+    )
+    parser.add_argument(
+        "--config", type=str, required=True, help="Config of the dataset. *E.g.* `'en'` for Common Voice"
+    )
+    parser.add_argument(
+        "--split", type=str, required=True, help="Split of the dataset. *E.g.* `'test'`"
+    )
+    parser.add_argument(
+        "--chunk_length_s", type=float, default=None, help="Chunk length in seconds. Defaults to None. For long audio files a good value is 5.0 seconds."
+    )
+    parser.add_argument(
+        "--stride_length_s", type=float, default=None, help="Stride of the audio chunks in seconds. Defaults to None. For long audio files a good value is 1.0 seconds."
+    )
+    parser.add_argument(
+        "--log_outputs", action='store_true', help="If set, write outputs to log files for analysis."
+    )
+    parser.add_argument(
+        "--path", type=str, default='', help="If set and --model_id is not set, load a local model from this path."
+    )
+    parser.add_argument(
+        "--limit", type=int, default=0, help="Optional. If greater than zero, evaluate only a subset of this size."
+    )
+    args = parser.parse_args()
+
+    main(args)
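A hedged usage sketch for the script above: the invocation mirrors the card's tags (dataset `mozilla-foundation/common_voice_8_0`, config `sr`), the model ID is an assumed repo name, and the import works because argument parsing only runs under the `__main__` guard.

```python
# Typical invocation (illustrative; the --model_id is an assumed repo ID):
#   python eval.py --dataset mozilla-foundation/common_voice_8_0 --config sr \
#       --split test --model_id comodoro/wav2vec2-xls-r-300m-sr-cv8 --log_outputs
from eval import normalize_text  # safe to import: argparse runs only under __main__

# punctuation is stripped and whitespace collapsed
print(normalize_text("Добар дан, свете!"))  # -> "добар дан свете"
```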
preprocessor_config.json
ADDED
@@ -0,0 +1,10 @@
+{
+  "do_normalize": true,
+  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
+  "feature_size": 1,
+  "padding_side": "right",
+  "padding_value": 0.0,
+  "processor_class": "Wav2Vec2Processor",
+  "return_attention_mask": true,
+  "sampling_rate": 16000
+}
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4837db0336473532f001e0c66f905b7e4e446b9e5c4fe5f19dbbd0f2e8184002
+size 1262067185
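The LFS pointer above records the blob's SHA-256 digest and byte size, so a downloaded checkpoint can be verified offline. A small sketch, assuming pytorch_model.bin sits in the current directory:

```python
# Sketch: check a downloaded pytorch_model.bin against the LFS pointer above.
import hashlib

EXPECTED_OID = "4837db0336473532f001e0c66f905b7e4e446b9e5c4fe5f19dbbd0f2e8184002"

digest = hashlib.sha256()
with open("pytorch_model.bin", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
        digest.update(chunk)

assert digest.hexdigest() == EXPECTED_OID, "checksum mismatch"
```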
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
+{"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "pad_token": "[PAD]", "additional_special_tokens": [{"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}]}
tokenizer_config.json
ADDED
@@ -0,0 +1 @@
+{"unk_token": "[UNK]", "bos_token": "<s>", "eos_token": "</s>", "pad_token": "[PAD]", "do_lower_case": false, "word_delimiter_token": "|", "special_tokens_map_file": null, "tokenizer_file": null, "name_or_path": "./", "tokenizer_class": "Wav2Vec2CTCTokenizer", "processor_class": "Wav2Vec2Processor"}
vocab.json
ADDED
@@ -0,0 +1 @@
+{"а": 1, "б": 2, "в": 3, "г": 4, "д": 5, "е": 6, "ж": 7, "з": 8, "и": 9, "к": 10, "л": 11, "м": 12, "н": 13, "о": 14, "п": 15, "р": 16, "с": 17, "т": 18, "у": 19, "ф": 20, "х": 21, "ц": 22, "ч": 23, "ш": 24, "ђ": 25, "ј": 26, "љ": 27, "њ": 28, "ћ": 29, "џ": 30, "|": 0, "[UNK]": 31, "[PAD]": 32}
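The vocabulary holds the 30 letters of the Serbian Cyrillic alphabet (ids 1-30) plus the word delimiter `|` (id 0), `[UNK]` (31), and `[PAD]` (32); ids 33-34 come from added_tokens.json. A sketch of building the CTC tokenizer straight from these files, assuming they are local:

```python
# Sketch: load the CTC tokenizer from the vocab file in this commit.
from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer(
    "vocab.json",
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)
print(tokenizer.decode([1, 2, 3]))  # ids for а, б, в -> "абв"
```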