AndrewMcDowell
commited on
Commit
•
5b14b52
1
Parent(s):
d3aab78
Revert "Model save"
Browse filesThis reverts commit b5f991fc8e218ffe66de3f71361915d16ff06459.
.ipynb_checkpoints/README-checkpoint.md
DELETED
@@ -1,100 +0,0 @@
|
|
1 |
-
---
|
2 |
-
language:
|
3 |
-
- de
|
4 |
-
license: apache-2.0
|
5 |
-
tags:
|
6 |
-
- automatic-speech-recognition
|
7 |
-
- mozilla-foundation/common_voice_8_0
|
8 |
-
- generated_from_trainer
|
9 |
-
datasets:
|
10 |
-
- common_voice
|
11 |
-
model-index:
|
12 |
-
- name: ''
|
13 |
-
results: []
|
14 |
-
---
|
15 |
-
|
16 |
-
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
|
17 |
-
should probably proofread and complete it, then remove this comment. -->
|
18 |
-
|
19 |
-
#
|
20 |
-
|
21 |
-
This model is a fine-tuned version of [facebook/wav2vec2-xls-r-1b](https://huggingface.co/facebook/wav2vec2-xls-r-1b) on the MOZILLA-FOUNDATION/COMMON_VOICE_8_0 - DE dataset.
|
22 |
-
It achieves the following results on the evaluation set:
|
23 |
-
- Loss: 0.1355
|
24 |
-
- Wer: 0.1532
|
25 |
-
|
26 |
-
## Model description
|
27 |
-
|
28 |
-
More information needed
|
29 |
-
|
30 |
-
## Intended uses & limitations
|
31 |
-
|
32 |
-
More information needed
|
33 |
-
|
34 |
-
## Training and evaluation data
|
35 |
-
|
36 |
-
More information needed
|
37 |
-
|
38 |
-
## Training procedure
|
39 |
-
|
40 |
-
### Training hyperparameters
|
41 |
-
|
42 |
-
The following hyperparameters were used during training:
|
43 |
-
- learning_rate: 7.5e-05
|
44 |
-
- train_batch_size: 8
|
45 |
-
- eval_batch_size: 8
|
46 |
-
- seed: 42
|
47 |
-
- gradient_accumulation_steps: 4
|
48 |
-
- total_train_batch_size: 32
|
49 |
-
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
|
50 |
-
- lr_scheduler_type: linear
|
51 |
-
- lr_scheduler_warmup_steps: 2000
|
52 |
-
- num_epochs: 2.5
|
53 |
-
- mixed_precision_training: Native AMP
|
54 |
-
|
55 |
-
### Training results
|
56 |
-
|
57 |
-
| Training Loss | Epoch | Step | Validation Loss | Wer |
|
58 |
-
|:-------------:|:-----:|:-----:|:---------------:|:------:|
|
59 |
-
| 1.0826 | 0.07 | 1000 | 0.4637 | 0.4654 |
|
60 |
-
| 1.118 | 0.15 | 2000 | 0.2595 | 0.2687 |
|
61 |
-
| 1.1268 | 0.22 | 3000 | 0.2635 | 0.2661 |
|
62 |
-
| 1.0919 | 0.29 | 4000 | 0.2417 | 0.2566 |
|
63 |
-
| 1.1013 | 0.37 | 5000 | 0.2414 | 0.2567 |
|
64 |
-
| 1.0898 | 0.44 | 6000 | 0.2546 | 0.2731 |
|
65 |
-
| 1.0808 | 0.51 | 7000 | 0.2399 | 0.2535 |
|
66 |
-
| 1.0719 | 0.59 | 8000 | 0.2353 | 0.2528 |
|
67 |
-
| 1.0446 | 0.66 | 9000 | 0.2427 | 0.2545 |
|
68 |
-
| 1.0347 | 0.73 | 10000 | 0.2266 | 0.2402 |
|
69 |
-
| 1.0457 | 0.81 | 11000 | 0.2290 | 0.2448 |
|
70 |
-
| 1.0124 | 0.88 | 12000 | 0.2295 | 0.2448 |
|
71 |
-
| 1.025 | 0.95 | 13000 | 0.2138 | 0.2345 |
|
72 |
-
| 1.0107 | 1.03 | 14000 | 0.2108 | 0.2294 |
|
73 |
-
| 0.9758 | 1.1 | 15000 | 0.2019 | 0.2204 |
|
74 |
-
| 0.9547 | 1.17 | 16000 | 0.2000 | 0.2178 |
|
75 |
-
| 0.986 | 1.25 | 17000 | 0.2018 | 0.2200 |
|
76 |
-
| 0.9588 | 1.32 | 18000 | 0.1992 | 0.2138 |
|
77 |
-
| 0.9413 | 1.39 | 19000 | 0.1898 | 0.2049 |
|
78 |
-
| 0.9339 | 1.47 | 20000 | 0.1874 | 0.2056 |
|
79 |
-
| 0.9268 | 1.54 | 21000 | 0.1797 | 0.1976 |
|
80 |
-
| 0.9194 | 1.61 | 22000 | 0.1743 | 0.1905 |
|
81 |
-
| 0.8987 | 1.69 | 23000 | 0.1738 | 0.1932 |
|
82 |
-
| 0.8884 | 1.76 | 24000 | 0.1703 | 0.1873 |
|
83 |
-
| 0.8939 | 1.83 | 25000 | 0.1633 | 0.1831 |
|
84 |
-
| 0.8629 | 1.91 | 26000 | 0.1549 | 0.1750 |
|
85 |
-
| 0.8607 | 1.98 | 27000 | 0.1550 | 0.1738 |
|
86 |
-
| 0.8316 | 2.05 | 28000 | 0.1512 | 0.1709 |
|
87 |
-
| 0.8321 | 2.13 | 29000 | 0.1481 | 0.1657 |
|
88 |
-
| 0.825 | 2.2 | 30000 | 0.1446 | 0.1627 |
|
89 |
-
| 0.8115 | 2.27 | 31000 | 0.1396 | 0.1583 |
|
90 |
-
| 0.7959 | 2.35 | 32000 | 0.1389 | 0.1569 |
|
91 |
-
| 0.7835 | 2.42 | 33000 | 0.1362 | 0.1545 |
|
92 |
-
| 0.7959 | 2.49 | 34000 | 0.1355 | 0.1531 |
|
93 |
-
|
94 |
-
|
95 |
-
### Framework versions
|
96 |
-
|
97 |
-
- Transformers 4.17.0.dev0
|
98 |
-
- Pytorch 1.10.2+cu102
|
99 |
-
- Datasets 1.18.2.dev0
|
100 |
-
- Tokenizers 0.11.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.ipynb_checkpoints/eval-checkpoint.py
DELETED
@@ -1,137 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python3
|
2 |
-
import argparse
|
3 |
-
import re
|
4 |
-
from typing import Dict
|
5 |
-
|
6 |
-
import torch
|
7 |
-
from datasets import Audio, Dataset, load_dataset, load_metric
|
8 |
-
|
9 |
-
from transformers import AutoFeatureExtractor, pipeline
|
10 |
-
|
11 |
-
|
12 |
-
def log_results(result: Dataset, args: Dict[str, str]):
|
13 |
-
"""DO NOT CHANGE. This function computes and logs the result metrics."""
|
14 |
-
|
15 |
-
log_outputs = args.log_outputs
|
16 |
-
dataset_id = "_".join(args.dataset.split("/") + [args.config, args.split])
|
17 |
-
|
18 |
-
# load metric
|
19 |
-
wer = load_metric("wer")
|
20 |
-
cer = load_metric("cer")
|
21 |
-
|
22 |
-
# compute metrics
|
23 |
-
wer_result = wer.compute(references=result["target"], predictions=result["prediction"])
|
24 |
-
cer_result = cer.compute(references=result["target"], predictions=result["prediction"])
|
25 |
-
|
26 |
-
# print & log results
|
27 |
-
result_str = f"WER: {wer_result}\n" f"CER: {cer_result}"
|
28 |
-
print(result_str)
|
29 |
-
|
30 |
-
with open(f"{dataset_id}_eval_results.txt", "w") as f:
|
31 |
-
f.write(result_str)
|
32 |
-
|
33 |
-
# log all results in text file. Possibly interesting for analysis
|
34 |
-
if log_outputs is not None:
|
35 |
-
pred_file = f"log_{dataset_id}_predictions.txt"
|
36 |
-
target_file = f"log_{dataset_id}_targets.txt"
|
37 |
-
|
38 |
-
with open(pred_file, "w") as p, open(target_file, "w") as t:
|
39 |
-
|
40 |
-
# mapping function to write output
|
41 |
-
def write_to_file(batch, i):
|
42 |
-
p.write(f"{i}" + "\n")
|
43 |
-
p.write(batch["prediction"] + "\n")
|
44 |
-
t.write(f"{i}" + "\n")
|
45 |
-
t.write(batch["target"] + "\n")
|
46 |
-
|
47 |
-
result.map(write_to_file, with_indices=True)
|
48 |
-
|
49 |
-
|
50 |
-
def normalize_text(text: str) -> str:
|
51 |
-
"""DO ADAPT FOR YOUR USE CASE. this function normalizes the target text."""
|
52 |
-
|
53 |
-
chars_to_ignore_regex = '[,?.!\-\;\:"“%‘”�—’…–$&()*+.\/=@\[\]_`¡§«°´µ·»×àáâãåæçèéêëìíîïðñòóôõøùúûýþāăąćčďđēėęěğġħīıłńņňōŏőœřśşšťūůźżžơǐǔșțəʻʾʿ̥̆̇авеикморсфчшѹאבנעש་ནḫṟṣṭạảắằếễệọồộụứ‑‚„‟′″‹›→−≡⟨⟩カ东临乡关合城孙尣幺支比毛泽無生臣辶道镇黃]' # noqa: W605 IMPORTANT: this should correspond to the chars that were ignored during training
|
54 |
-
|
55 |
-
text = re.sub(chars_to_ignore_regex, "", text.lower())
|
56 |
-
|
57 |
-
# In addition, we can normalize the target text, e.g. removing new lines characters etc...
|
58 |
-
# note that order is important here!
|
59 |
-
token_sequences_to_ignore = ["\n\n", "\n", " ", " "]
|
60 |
-
|
61 |
-
for t in token_sequences_to_ignore:
|
62 |
-
text = " ".join(text.split(t))
|
63 |
-
|
64 |
-
return text
|
65 |
-
|
66 |
-
|
67 |
-
def main(args):
|
68 |
-
# load dataset
|
69 |
-
dataset = load_dataset(args.dataset, args.config, split=args.split, use_auth_token=True)
|
70 |
-
|
71 |
-
# for testing: only process the first two examples as a test
|
72 |
-
# dataset = dataset.select(range(10))
|
73 |
-
|
74 |
-
# load processor
|
75 |
-
feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_id)
|
76 |
-
sampling_rate = feature_extractor.sampling_rate
|
77 |
-
|
78 |
-
# resample audio
|
79 |
-
dataset = dataset.cast_column("audio", Audio(sampling_rate=sampling_rate))
|
80 |
-
|
81 |
-
# load eval pipeline
|
82 |
-
if args.device is None:
|
83 |
-
args.device = 0 if torch.cuda.is_available() else -1
|
84 |
-
asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
|
85 |
-
|
86 |
-
# map function to decode audio
|
87 |
-
def map_to_pred(batch):
|
88 |
-
prediction = asr(
|
89 |
-
batch["audio"]["array"], chunk_length_s=args.chunk_length_s, stride_length_s=args.stride_length_s
|
90 |
-
)
|
91 |
-
|
92 |
-
batch["prediction"] = prediction["text"]
|
93 |
-
batch["target"] = normalize_text(batch["sentence"])
|
94 |
-
return batch
|
95 |
-
|
96 |
-
# run inference on all examples
|
97 |
-
result = dataset.map(map_to_pred, remove_columns=dataset.column_names)
|
98 |
-
|
99 |
-
# compute and log_results
|
100 |
-
# do not change function below
|
101 |
-
log_results(result, args)
|
102 |
-
|
103 |
-
|
104 |
-
if __name__ == "__main__":
|
105 |
-
parser = argparse.ArgumentParser()
|
106 |
-
|
107 |
-
parser.add_argument(
|
108 |
-
"--model_id", type=str, required=True, help="Model identifier. Should be loadable with 🤗 Transformers"
|
109 |
-
)
|
110 |
-
parser.add_argument(
|
111 |
-
"--dataset",
|
112 |
-
type=str,
|
113 |
-
required=True,
|
114 |
-
help="Dataset name to evaluate the `model_id`. Should be loadable with 🤗 Datasets",
|
115 |
-
)
|
116 |
-
parser.add_argument(
|
117 |
-
"--config", type=str, required=True, help="Config of the dataset. *E.g.* `'en'` for Common Voice"
|
118 |
-
)
|
119 |
-
parser.add_argument("--split", type=str, required=True, help="Split of the dataset. *E.g.* `'test'`")
|
120 |
-
parser.add_argument(
|
121 |
-
"--chunk_length_s", type=float, default=None, help="Chunk length in seconds. Defaults to 5 seconds."
|
122 |
-
)
|
123 |
-
parser.add_argument(
|
124 |
-
"--stride_length_s", type=float, default=None, help="Stride of the audio chunks. Defaults to 1 second."
|
125 |
-
)
|
126 |
-
parser.add_argument(
|
127 |
-
"--log_outputs", action="store_true", help="If defined, write outputs to log file for analysis."
|
128 |
-
)
|
129 |
-
parser.add_argument(
|
130 |
-
"--device",
|
131 |
-
type=int,
|
132 |
-
default=None,
|
133 |
-
help="The device to run the pipeline on. -1 for CPU (default), 0 for the first GPU and so on.",
|
134 |
-
)
|
135 |
-
args = parser.parse_args()
|
136 |
-
|
137 |
-
main(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.ipynb_checkpoints/eval_results-checkpoint.json
DELETED
@@ -1,9 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"epoch": 2.5,
|
3 |
-
"eval_loss": 0.13549980521202087,
|
4 |
-
"eval_runtime": 1012.9633,
|
5 |
-
"eval_samples": 15995,
|
6 |
-
"eval_samples_per_second": 15.79,
|
7 |
-
"eval_steps_per_second": 1.974,
|
8 |
-
"eval_wer": 0.15316115109878853
|
9 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.ipynb_checkpoints/run_training-checkpoint.sh
CHANGED
@@ -3,10 +3,11 @@ python run_speech_recognition_ctc.py \
|
|
3 |
--model_name_or_path="facebook/wav2vec2-xls-r-1b" \
|
4 |
--dataset_config_name="de" \
|
5 |
--output_dir="./" \
|
6 |
-
|
7 |
-
--
|
|
|
8 |
--per_device_eval_batch_size="8" \
|
9 |
-
--gradient_accumulation_steps="
|
10 |
--learning_rate="7.5e-5" \
|
11 |
--warmup_steps="2000" \
|
12 |
--length_column_name="input_length" \
|
|
|
3 |
--model_name_or_path="facebook/wav2vec2-xls-r-1b" \
|
4 |
--dataset_config_name="de" \
|
5 |
--output_dir="./" \
|
6 |
+
--overwrite_output_dir \
|
7 |
+
--num_train_epochs="2.5" \
|
8 |
+
--per_device_train_batch_size="8" \
|
9 |
--per_device_eval_batch_size="8" \
|
10 |
+
--gradient_accumulation_steps="4" \
|
11 |
--learning_rate="7.5e-5" \
|
12 |
--warmup_steps="2000" \
|
13 |
--length_column_name="input_length" \
|
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 3850681649
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:adab4f8eb5a7f62adcf02be9f38b9ff41939347692ddf7d2d0fc29a4f749467b
|
3 |
size 3850681649
|
run_training.sh
CHANGED
@@ -3,10 +3,11 @@ python run_speech_recognition_ctc.py \
|
|
3 |
--model_name_or_path="facebook/wav2vec2-xls-r-1b" \
|
4 |
--dataset_config_name="de" \
|
5 |
--output_dir="./" \
|
6 |
-
|
7 |
-
--
|
|
|
8 |
--per_device_eval_batch_size="8" \
|
9 |
-
--gradient_accumulation_steps="
|
10 |
--learning_rate="7.5e-5" \
|
11 |
--warmup_steps="2000" \
|
12 |
--length_column_name="input_length" \
|
|
|
3 |
--model_name_or_path="facebook/wav2vec2-xls-r-1b" \
|
4 |
--dataset_config_name="de" \
|
5 |
--output_dir="./" \
|
6 |
+
--overwrite_output_dir \
|
7 |
+
--num_train_epochs="2.5" \
|
8 |
+
--per_device_train_batch_size="8" \
|
9 |
--per_device_eval_batch_size="8" \
|
10 |
+
--gradient_accumulation_steps="4" \
|
11 |
--learning_rate="7.5e-5" \
|
12 |
--warmup_steps="2000" \
|
13 |
--length_column_name="input_length" \
|
special_tokens_map.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "pad_token": "[PAD]", "additional_special_tokens": [{"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}
|
|
|
1 |
+
{"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "pad_token": "[PAD]", "additional_special_tokens": [{"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, {"content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}]}
|
training_args.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 2991
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3d671fb0f181e146452d1d68a46c3b54df59aa573465bc6cf0a59cb0e02b849a
|
3 |
size 2991
|