Update README.md
update evaluation using BERT/T-TA and add evaluation time
README.md
CHANGED
@@ -21,7 +21,7 @@ model-index:
    metrics:
    - name: Test CER
      type: cer
      value: 25.57
---

# Wav2Vec2-Large-XLSR-53-tw-gpt
@@ -48,7 +48,8 @@ model_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
device = "cuda"
processor_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"

chars_to_ignore_regex = r"[¥•"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·'℃°•·.﹑︰〈〉─《﹖﹣﹂﹁﹔!?。。"#$%&'()*+,﹐-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏..!\"#$%&()*+,\-.\:;<=>?@\[\]\\\/^_`{|}~]"

model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(processor_name)
@@ -95,10 +96,17 @@ predict(load_file_to_data('voice file path'))
The model can be evaluated as follows on the zh-TW test data of Common Voice.
The CER calculation refers to https://huggingface.co/ctl/wav2vec2-large-xlsr-cantonese

env setup:
```
!mkdir cer
!wget -O cer/cer.py https://huggingface.co/ctl/wav2vec2-large-xlsr-cantonese/raw/main/cer.py
!pip install jiwer
!pip install torchaudio
!pip install datasets transformers
```
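The CER reported below is a character-level edit-distance rate: the Levenshtein distance between prediction and reference divided by the reference length. A minimal illustration of what the downloaded metric computes (a sketch only, not the actual `cer.py`):

```python
# Illustrative CER: Levenshtein distance over the reference length.
def cer(prediction: str, reference: str) -> float:
    prev = list(range(len(reference) + 1))
    for i, p in enumerate(prediction, 1):
        cur = [i]
        for j, r in enumerate(reference, 1):
            cur.append(min(prev[j] + 1,              # deletion
                           cur[j - 1] + 1,           # insertion
                           prev[j - 1] + (p != r)))  # substitution
        prev = cur
    return prev[-1] / len(reference)

print(cer("今天天氣很好", "今天天氣真好"))  # one substitution over six characters ≈ 0.167
```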
## Evaluation without LM:
```python
import torchaudio
from datasets import load_dataset, load_metric
from transformers import (
@@ -113,7 +121,8 @@ model_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
device = "cuda"
processor_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"

chars_to_ignore_regex = r"[¥•"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·'℃°•·.﹑︰〈〉─《﹖﹣﹂﹁﹔!?。。"#$%&'()*+,﹐-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏..!\"#$%&()*+,\-.\:;<=>?@\[\]\\\/^_`{|}~]"

model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(processor_name)
@@ -148,14 +157,11 @@ cer = load_metric("./cer")
print("CER: {:.2f}".format(100 * cer.compute(predictions=result["predicted"], references=result["target"])))
```

`CER: 28.79`.
`TIME: 05:23 min`

## Evaluation with GPT:
```python
import torchaudio
from datasets import load_dataset, load_metric
from transformers import (
@@ -170,10 +176,10 @@ from transformers import AutoTokenizer, AutoModelWithLMHead
model_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
device = "cuda"
processor_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
chars_to_ignore_regex = r"[¥•"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·'℃°•·.﹑︰〈〉─《﹖﹣﹂﹁﹔!?。。"#$%&'()*+,﹐-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏..!\"#$%&()*+,\-.\:;<=>?@\[\]\\\/^_`{|}~]"

tokenizer = AutoTokenizer.from_pretrained("ckiplab/gpt2-base-chinese")
lm_model = AutoModelWithLMHead.from_pretrained("ckiplab/gpt2-base-chinese").to(device)
model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(processor_name)
@@ -196,18 +202,173 @@ def map_to_pred(batch):
    attention_mask = features.attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits

    decoded_results = []
    for logit in logits:
        pred_ids = torch.argmax(logit, dim=-1)
        # keep only frames whose argmax is not the blank/padding id (0)
        mask = pred_ids.ge(1).unsqueeze(-1).expand(logit.size())
        vocab_size = logit.size()[-1]
        voice_prob = torch.nn.functional.softmax((torch.masked_select(logit, mask).view(-1,vocab_size)),dim=-1)
        # prepend the CLS id and feed the non-blank predictions to the GPT-2 LM
        lm_input = torch.cat((torch.tensor([tokenizer.cls_token_id]).to(device),pred_ids[pred_ids>0]), 0)
        lm_prob = torch.nn.functional.softmax(lm_model(lm_input).logits, dim=-1)[:voice_prob.size()[0],:]
        # fuse acoustic and LM posteriors, then re-take the argmax
        comb_pred_ids = torch.argmax(lm_prob*voice_prob, dim=-1)
        decoded_results.append(processor.decode(comb_pred_ids))

    batch["predicted"] = decoded_results
    batch["target"] = batch["sentence"]
    return batch


result = ds.map(map_to_pred, batched=True, batch_size=16, remove_columns=list(ds.features.keys()))

cer = load_metric("./cer")
print("CER: {:.2f}".format(100 * cer.compute(predictions=result["predicted"], references=result["target"])))
```

`CER 25.75`.
`TIME: 06:04 min`
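The fusion step above multiplies the frame-wise acoustic posterior by the GPT-2 posterior and takes the argmax again. A toy sketch with made-up probabilities (not taken from the models) shows how the LM can flip a low-margin acoustic decision:

```python
import torch

# Hypothetical posteriors over a 3-token vocabulary for two frames.
voice_prob = torch.tensor([[0.6, 0.3, 0.1],
                           [0.4, 0.5, 0.1]])  # acoustic (CTC) posteriors
lm_prob = torch.tensor([[0.2, 0.7, 0.1],
                        [0.3, 0.6, 0.1]])     # language-model posteriors

# Same fusion as comb_pred_ids above: element-wise product, then argmax.
print(torch.argmax(lm_prob * voice_prob, dim=-1))  # tensor([1, 1])
```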
## Evaluation with BERT:
```python
import torchaudio
from datasets import load_dataset, load_metric
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)
import torch
import re
import sys
from transformers import AutoTokenizer, AutoModelForMaskedLM

model_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
device = "cuda"
processor_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
chars_to_ignore_regex = r"[¥•"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·'℃°•·.﹑︰〈〉─《﹖﹣﹂﹁﹔!?。。"#$%&'()*+,﹐-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏..!\"#$%&()*+,\-.\:;<=>?@\[\]\\\/^_`{|}~]"

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
lm_model = AutoModelForMaskedLM.from_pretrained("bert-base-chinese").to(device)
model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(processor_name)

ds = load_dataset("common_voice", 'zh-TW', data_dir="./cv-corpus-6.1-2020-12-11", split="test")

resampler = torchaudio.transforms.Resample(orig_freq=48_000, new_freq=16_000)

def map_to_array(batch):
    speech, _ = torchaudio.load(batch["path"])
    batch["speech"] = resampler.forward(speech.squeeze(0)).numpy()
    batch["sampling_rate"] = resampler.new_freq
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
    return batch

ds = ds.map(map_to_array)

def map_to_pred(batch):
    features = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0], padding=True, return_tensors="pt")
    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits

    decoded_results = []
    for logit in logits:
        pred_ids = torch.argmax(logit, dim=-1)
        # drop frames predicted as the blank/padding id
        mask = ~pred_ids.eq(tokenizer.pad_token_id).unsqueeze(-1).expand(logit.size())
        vocab_size = logit.size()[-1]
        voice_prob = torch.nn.functional.softmax((torch.masked_select(logit, mask).view(-1,vocab_size)),dim=-1)
        lm_input = torch.masked_select(pred_ids, ~pred_ids.eq(tokenizer.pad_token_id)).unsqueeze(0)
        mask_lm_prob = voice_prob.clone()
        # mask each position in turn; keep only the distribution BERT predicts at the masked slot
        for i in range(lm_input.shape[-1]):
            masked_lm_input = lm_input.clone()
            masked_lm_input[0][i] = torch.tensor(tokenizer.mask_token_id).to('cuda')
            lm_prob = torch.nn.functional.softmax(lm_model(masked_lm_input).logits, dim=-1).squeeze(0)
            mask_lm_prob[i] = lm_prob[i]
        comb_pred_ids = torch.argmax(mask_lm_prob*voice_prob, dim=-1)
        decoded_results.append(processor.decode(comb_pred_ids))

    batch["predicted"] = decoded_results
    batch["target"] = batch["sentence"]
    return batch


result = ds.map(map_to_pred, batched=True, batch_size=1, remove_columns=list(ds.features.keys()))

cer = load_metric("./cer")
print("CER: {:.2f}".format(100 * cer.compute(predictions=result["predicted"], references=result["target"])))
```

`CER 25.57`.
`TIME: 09:49 min`
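Unlike the single GPT-2 pass above, the masked-LM scoring re-runs BERT once per output position (position `i` is replaced by `[MASK]` and only the distribution predicted at `i` is kept), which is why this variant takes the longest. A self-contained sketch of that loop, with a hypothetical stand-in for `lm_model`:

```python
import torch

VOCAB_SIZE = 8
MASK_TOKEN_ID = 0  # hypothetical [MASK] id for the stand-in model

def fake_masked_lm_logits(input_ids: torch.Tensor) -> torch.Tensor:
    # Stand-in for lm_model(masked_lm_input).logits: random, correctly shaped logits.
    return torch.randn(1, input_ids.shape[-1], VOCAB_SIZE)

lm_input = torch.tensor([[5, 2, 7, 3]])                   # ids from the CTC argmax
mask_lm_prob = torch.zeros(lm_input.shape[-1], VOCAB_SIZE)
for i in range(lm_input.shape[-1]):                       # one LM pass per position
    masked_lm_input = lm_input.clone()
    masked_lm_input[0][i] = MASK_TOKEN_ID                 # hide position i
    lm_prob = torch.softmax(fake_masked_lm_logits(masked_lm_input), dim=-1).squeeze(0)
    mask_lm_prob[i] = lm_prob[i]                          # keep only the masked slot
print(mask_lm_prob.shape)  # torch.Size([4, 8])
```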
## Evaluation with T-TA:
setup:
```
!git clone https://github.com/voidful/pytorch-tta.git
!mv ./pytorch-tta/tta ./tta
!wget https://github.com/voidful/pytorch-tta/releases/download/wiki_zh/wiki_zh.pt
```

```python
import torchaudio
from datasets import load_dataset, load_metric
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)
import torch
import re
import sys
from tta.modeling_tta import TTALMModel
from transformers import AutoTokenizer
import torch


model_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
device = "cuda"
processor_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
chars_to_ignore_regex = r"[¥•"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·'℃°•·.﹑︰〈〉─《﹖﹣﹂﹁﹔!?。。"#$%&'()*+,﹐-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏..!\"#$%&()*+,\-.\:;<=>?@\[\]\\\/^_`{|}~]"

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
lm_model = TTALMModel("bert-base-chinese")
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
lm_model.load_state_dict(torch.load("./wiki_zh.pt",map_location=torch.device('cuda')))
lm_model.to('cuda')
lm_model.eval()
model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(processor_name)

ds = load_dataset("common_voice", 'zh-TW', data_dir="./cv-corpus-6.1-2020-12-11", split="test")

resampler = torchaudio.transforms.Resample(orig_freq=48_000, new_freq=16_000)

def map_to_array(batch):
    speech, _ = torchaudio.load(batch["path"])
    batch["speech"] = resampler.forward(speech.squeeze(0)).numpy()
    batch["sampling_rate"] = resampler.new_freq
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
    return batch

ds = ds.map(map_to_array)

def map_to_pred(batch):
    features = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0], padding=True, return_tensors="pt")
    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits

    decoded_results = []
    for logit in logits:
        pred_ids = torch.argmax(logit, dim=-1)
        mask = ~pred_ids.eq(tokenizer.pad_token_id).unsqueeze(-1).expand(logit.size())
        vocab_size = logit.size()[-1]
        voice_prob = torch.nn.functional.softmax((torch.masked_select(logit, mask).view(-1,vocab_size)),dim=-1)
        lm_input = torch.masked_select(pred_ids, ~pred_ids.eq(tokenizer.pad_token_id)).unsqueeze(0)
        # T-TA scores every position in a single forward pass (no per-position masking)
        lm_prob = torch.nn.functional.softmax(lm_model.forward(lm_input)[0], dim=-1).squeeze(0)
        comb_pred_ids = torch.argmax(lm_prob*voice_prob, dim=-1)
        decoded_results.append(processor.decode(comb_pred_ids))

    batch["predicted"] = decoded_results
    batch["target"] = batch["sentence"]
    return batch

@@ -219,4 +380,5 @@ cer = load_metric("./cer")
print("CER: {:.2f}".format(100 * cer.compute(predictions=result["predicted"], references=result["target"])))
```

`CER: 25.77`.
`TIME: 06:01 min`