correcting CER score
Browse files
README.md
CHANGED
@@ -21,10 +21,14 @@ model-index:
|
|
21 |
metrics:
|
22 |
- name: Test CER
|
23 |
type: cer
|
24 |
-
value:
|
25 |
---
|
26 |
|
27 |
-
|
|
|
|
|
|
|
|
|
28 |
[Colab trial](https://colab.research.google.com/drive/1nBRLf4Pwiply_y5rXWoaIB8LxX41tfEI?usp=sharing)
|
29 |
|
30 |
```
|
@@ -73,8 +77,15 @@ Predict
|
|
73 |
predict(load_file_to_data('voice file path'))
|
74 |
```
|
75 |
|
76 |
-
## Evaluation
|
|
|
|
|
|
|
77 |
```python
|
|
|
|
|
|
|
|
|
78 |
import torchaudio
|
79 |
from datasets import load_dataset, load_metric
|
80 |
from transformers import (
|
@@ -85,6 +96,7 @@ import torch
|
|
85 |
import re
|
86 |
import sys
|
87 |
|
|
|
88 |
model_name = "voidful/wav2vec2-large-xlsr-53-hk"
|
89 |
device = "cuda"
|
90 |
processor_name = "voidful/wav2vec2-large-xlsr-53-hk"
|
@@ -94,7 +106,7 @@ chars_to_ignore_regex = r"[¥•"#$%&'()*+,-/:;<
|
|
94 |
model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
|
95 |
processor = Wav2Vec2Processor.from_pretrained(processor_name)
|
96 |
|
97 |
-
ds = load_dataset("common_voice", 'zh-HK', split="test")
|
98 |
|
99 |
resampler = torchaudio.transforms.Resample(orig_freq=48_000, new_freq=16_000)
|
100 |
|
@@ -120,9 +132,7 @@ def map_to_pred(batch):
|
|
120 |
|
121 |
result = ds.map(map_to_pred, batched=True, batch_size=16, remove_columns=list(ds.features.keys()))
|
122 |
|
123 |
-
|
124 |
-
|
125 |
-
print(wer.compute(predictions=result["predicted"], references=result["target"]))
|
126 |
```
|
127 |
|
128 |
-
`CER
|
|
|
21 |
metrics:
|
22 |
- name: Test CER
|
23 |
type: cer
|
24 |
+
value: 16.41
|
25 |
---
|
26 |
|
27 |
+
# Wav2Vec2-Large-XLSR-53-hk
|
28 |
+
Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on Cantonese using the [Common Voice](https://huggingface.co/datasets/common_voice) dataset.
|
29 |
+
When using this model, make sure that your speech input is sampled at 16kHz.
|
30 |
+
|
31 |
+
## Usage
|
32 |
[Colab trial](https://colab.research.google.com/drive/1nBRLf4Pwiply_y5rXWoaIB8LxX41tfEI?usp=sharing)
|
33 |
|
34 |
```
|
|
|
77 |
predict(load_file_to_data('voice file path'))
|
78 |
```
|
79 |
|
80 |
+
## Evaluation
|
81 |
+
The model can be evaluated as follows on the Chinese (Hong Kong) test data of Common Voice.
|
82 |
+
CER calculation refers to https://huggingface.co/ctl/wav2vec2-large-xlsr-cantonese
|
83 |
+
|
84 |
```python
|
85 |
+
!mkdir cer
|
86 |
+
!wget -O cer/cer.py https://huggingface.co/ctl/wav2vec2-large-xlsr-cantonese/raw/main/cer.py
|
87 |
+
!pip install jiwer
|
88 |
+
|
89 |
import torchaudio
|
90 |
from datasets import load_dataset, load_metric
|
91 |
from transformers import (
|
|
|
96 |
import re
|
97 |
import sys
|
98 |
|
99 |
+
cer = load_metric("./cer")
|
100 |
model_name = "voidful/wav2vec2-large-xlsr-53-hk"
|
101 |
device = "cuda"
|
102 |
processor_name = "voidful/wav2vec2-large-xlsr-53-hk"
|
|
|
106 |
model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
|
107 |
processor = Wav2Vec2Processor.from_pretrained(processor_name)
|
108 |
|
109 |
+
ds = load_dataset("common_voice", 'zh-HK', data_dir="./cv-corpus-6.1-2020-12-11", split="test")
|
110 |
|
111 |
resampler = torchaudio.transforms.Resample(orig_freq=48_000, new_freq=16_000)
|
112 |
|
|
|
132 |
|
133 |
result = ds.map(map_to_pred, batched=True, batch_size=16, remove_columns=list(ds.features.keys()))
|
134 |
|
135 |
+
print("CER: {:.2f}".format(100 * cer.compute(predictions=result["predicted"], references=result["target"])))
|
|
|
|
|
136 |
```
|
137 |
|
138 |
+
`CER 16.41`
|