voidful committed on
Commit
4d24677
1 Parent(s): 202a8b5

Update README.md


Update evaluation using BERT/T-TA and add evaluation time

Files changed (1)
  1. README.md +179 -17
README.md CHANGED
@@ -21,7 +21,7 @@ model-index:
21
  metrics:
22
  - name: Test CER
23
  type: cer
24
- value: 16.41
25
  ---
26
 
27
  # Wav2Vec2-Large-XLSR-53-tw-gpt
@@ -48,7 +48,8 @@ model_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
48
  device = "cuda"
49
  processor_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
50
 
51
- chars_to_ignore_regex = r"[¥•"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·'℃°•·.﹑︰〈〉─《﹖﹣﹂﹁﹔!?。。"#$%&'()*+,﹐-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏..!\\\\\\\\"#$%&()*+,\\\\\\\\-.\\\\\\\\:;<=>?@\\\\\\\\[\\\\\\\\]\\\\\\\\\\\\\\\\\\\\\\\\/^_`{|}~]"
 
52
 
53
  model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
54
  processor = Wav2Vec2Processor.from_pretrained(processor_name)
@@ -95,10 +96,17 @@ predict(load_file_to_data('voice file path'))
95
  The model can be evaluated as follows on the zh-tw test data of Common Voice.
96
  CER calculation refers to https://huggingface.co/ctl/wav2vec2-large-xlsr-cantonese
97
 
98
- ```python
 
99
  !mkdir cer
 
100
  !pip install jiwer
101
102
  import torchaudio
103
  from datasets import load_dataset, load_metric
104
  from transformers import (
@@ -113,7 +121,8 @@ model_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
113
  device = "cuda"
114
  processor_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
115
 
116
- chars_to_ignore_regex = r"[¥•"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·'℃°•·.﹑︰〈〉─《﹖﹣﹂﹁﹔!?。。"#$%&'()*+,﹐-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏..!\\\\\\\\"#$%&()*+,\\\\\\\\-.\\\\\\\\:;<=>?@\\\\\\\\[\\\\\\\\]\\\\\\\\\\\\\\\\\\\\\\\\/^_`{|}~]"
 
117
 
118
  model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
119
  processor = Wav2Vec2Processor.from_pretrained(processor_name)
@@ -148,14 +157,11 @@ cer = load_metric("./cer")
148
  print("CER: {:2f}".format(100 * cer.compute(predictions=result["predicted"], references=result["target"])))
149
  ```
150
 
151
- `CER: 28.734822`
 
152
 
153
  ## Evaluation with GPT:
154
  ```python
155
- !mkdir cer
156
- !wget -O cer/cer.py https://huggingface.co/ctl/wav2vec2-large-xlsr-cantonese/raw/main/cer.py
157
- !pip install jiwer
158
-
159
  import torchaudio
160
  from datasets import load_dataset, load_metric
161
  from transformers import (
@@ -170,10 +176,10 @@ from transformers import AutoTokenizer, AutoModelWithLMHead
170
  model_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
171
  device = "cuda"
172
  processor_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
173
- chars_to_ignore_regex = r"""[¥•"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·'℃°•·.﹑︰〈〉─《﹖﹣﹂﹁﹔!?。。"#$%&'()*+,﹐-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏..!\\\\\\\\"#$%&()*+,\\\\\\\\-.\\\\\\\\:;<=>?@\\\\\\\\[\\\\\\\\]\\\\\\\\\\\\\\\\\\\\\\\\/^_`{|}~]"""
174
 
175
  tokenizer = AutoTokenizer.from_pretrained("ckiplab/gpt2-base-chinese")
176
- gpt_model = AutoModelWithLMHead.from_pretrained("ckiplab/gpt2-base-chinese").to(device)
177
  model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
178
  processor = Wav2Vec2Processor.from_pretrained(processor_name)
179
 
@@ -196,18 +202,173 @@ def map_to_pred(batch):
196
  attention_mask = features.attention_mask.to(device)
197
  with torch.no_grad():
198
  logits = model(input_values, attention_mask=attention_mask).logits
199
-
200
  decoded_results = []
201
  for logit in logits:
202
  pred_ids = torch.argmax(logit, dim=-1)
203
  mask = pred_ids.ge(1).unsqueeze(-1).expand(logit.size())
204
  vocab_size = logit.size()[-1]
205
  voice_prob = torch.nn.functional.softmax((torch.masked_select(logit, mask).view(-1,vocab_size)),dim=-1)
206
- gpt_input = torch.cat((torch.tensor([tokenizer.cls_token_id]).to(device),pred_ids[pred_ids>0]), 0)
207
- gpt_prob = torch.nn.functional.softmax(gpt_model(gpt_input).logits, dim=-1)[:voice_prob.size()[0],:]
208
- comb_pred_ids = torch.argmax(gpt_prob*voice_prob, dim=-1)
209
  decoded_results.append(processor.decode(comb_pred_ids))
210
-
211
  batch["predicted"] = decoded_results
212
  batch["target"] = batch["sentence"]
213
  return batch
@@ -219,4 +380,5 @@ cer = load_metric("./cer")
219
  print("CER: {:2f}".format(100 * cer.compute(predictions=result["predicted"], references=result["target"])))
220
  ```
221
 
222
- `CER 25.69`
 
 
21
  metrics:
22
  - name: Test CER
23
  type: cer
24
+ value: 25.57
25
  ---
26
 
27
  # Wav2Vec2-Large-XLSR-53-tw-gpt
 
48
  device = "cuda"
49
  processor_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
50
 
51
+ chars_to_ignore_regex = r"[¥•"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·'℃°•·.﹑︰〈〉─《﹖﹣﹂﹁﹔!?。。"#$%&'()*+,﹐-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏..!\"#$%&()*+,\-.\:;<=>?@\[\]\\\/^_`{|}~]"
52
+
53
 
54
  model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
55
  processor = Wav2Vec2Processor.from_pretrained(processor_name)
 
96
  The model can be evaluated as follows on the zh-tw test data of Common Voice.
97
  CER calculation refers to https://huggingface.co/ctl/wav2vec2-large-xlsr-cantonese
98
 
99
+ Environment setup:
100
+ ```
101
  !mkdir cer
102
+ !wget -O cer/cer.py https://huggingface.co/ctl/wav2vec2-large-xlsr-cantonese/raw/main/cer.py
103
  !pip install jiwer
104
+ !pip install torchaudio
105
+ !pip install datasets transformers
106
+ ```
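CER (character error rate) is the character-level analogue of WER: edit distance over characters divided by the number of reference characters. As a quick, hedged sanity check of the metric itself, here is a minimal example using the `jiwer` package installed above (the strings are made up for illustration; the downloaded `cer.py` metric script computes the same kind of edit-distance-based score over the Common Voice references):

```python
# Minimal CER sanity check with jiwer (illustrative strings, not Common Voice data).
# CER = (substitutions + deletions + insertions) / number of reference characters.
import jiwer

reference = "今天天氣很好"
hypothesis = "今天天器很好"

# jiwer.wer works on space-separated tokens, so split each string into characters first.
cer = jiwer.wer(" ".join(reference), " ".join(hypothesis))
print(f"CER: {cer:.2%}")  # one wrong character out of six -> ~16.67%
```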
107
 
108
+ ## Evaluation without LM:
109
+ ```python
110
  import torchaudio
111
  from datasets import load_dataset, load_metric
112
  from transformers import (
 
121
  device = "cuda"
122
  processor_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
123
 
124
+ chars_to_ignore_regex = r"[¥•"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·'℃°•·.﹑︰〈〉─《﹖﹣﹂﹁﹔!?。。"#$%&'()*+,﹐-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏..!\"#$%&()*+,\-.\:;<=>?@\[\]\\\/^_`{|}~]"
125
+
126
 
127
  model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
128
  processor = Wav2Vec2Processor.from_pretrained(processor_name)
 
157
  print("CER: {:2f}".format(100 * cer.compute(predictions=result["predicted"], references=result["target"])))
158
  ```
159
 
160
+ `CER: 28.79`.
161
+ `TIME: 05:23 min`
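The no-LM evaluation boils down to greedy CTC decoding: take the argmax token at every frame and let the processor collapse repeats and blanks. A minimal sketch with a random tensor standing in for the wav2vec2 output (the real script uses the `logits` and `processor` defined above):

```python
# Greedy (argmax) CTC decoding sketch; dummy_logits stands in for the model output.
import torch

dummy_logits = torch.randn(1, 50, 100)         # (batch, time, vocab_size)
pred_ids = torch.argmax(dummy_logits, dim=-1)  # (batch, time)
# With the real objects from the script above:
# predicted = processor.batch_decode(pred_ids)  # collapses repeats and blank tokens
```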
162
 
163
  ## Evaluation with GPT:
164
  ```python
165
  import torchaudio
166
  from datasets import load_dataset, load_metric
167
  from transformers import (
 
176
  model_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
177
  device = "cuda"
178
  processor_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
179
+ chars_to_ignore_regex = r"[¥•"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·'℃°•·.﹑︰〈〉─《﹖﹣﹂﹁﹔!?。。"#$%&'()*+,﹐-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏..!\"#$%&()*+,\-.\:;<=>?@\[\]\\\/^_`{|}~]"
180
 
181
  tokenizer = AutoTokenizer.from_pretrained("ckiplab/gpt2-base-chinese")
182
+ lm_model = AutoModelWithLMHead.from_pretrained("ckiplab/gpt2-base-chinese").to(device)
183
  model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
184
  processor = Wav2Vec2Processor.from_pretrained(processor_name)
185
 
 
202
  attention_mask = features.attention_mask.to(device)
203
  with torch.no_grad():
204
  logits = model(input_values, attention_mask=attention_mask).logits
205
+
206
  decoded_results = []
207
  for logit in logits:
208
  pred_ids = torch.argmax(logit, dim=-1)
209
  mask = pred_ids.ge(1).unsqueeze(-1).expand(logit.size())
210
  vocab_size = logit.size()[-1]
211
  voice_prob = torch.nn.functional.softmax((torch.masked_select(logit, mask).view(-1,vocab_size)),dim=-1)
212
+ lm_input = torch.cat((torch.tensor([tokenizer.cls_token_id]).to(device),pred_ids[pred_ids>0]), 0)
213
+ lm_prob = torch.nn.functional.softmax(lm_model(lm_input).logits, dim=-1)[:voice_prob.size()[0],:]
214
+ comb_pred_ids = torch.argmax(lm_prob*voice_prob, dim=-1)
215
  decoded_results.append(processor.decode(comb_pred_ids))
216
+
217
+ batch["predicted"] = decoded_results
218
+ batch["target"] = batch["sentence"]
219
+ return batch
220
+
221
+
222
+ result = ds.map(map_to_pred, batched=True, batch_size=16, remove_columns=list(ds.features.keys()))
223
+
224
+ cer = load_metric("./cer")
225
+ print("CER: {:2f}".format(100 * cer.compute(predictions=result["predicted"], references=result["target"])))
226
+ ```
227
+
228
+ `CER: 25.75`.
229
+ `TIME: 06:04 min`
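The GPT rescoring above keeps only the frames whose greedy CTC prediction is non-blank (`pred_ids.ge(1)`), feeds those token ids to `ckiplab/gpt2-base-chinese` with a leading `[CLS]`, and multiplies the acoustic softmax by the LM softmax before the final argmax. A self-contained sketch of just that fusion step, with random tensors standing in for the real wav2vec2 and GPT-2 outputs:

```python
# Acoustic/LM fusion sketch: element-wise product of two per-token distributions.
import torch

seq_len, vocab_size = 5, 21128  # 21128 = bert-base-chinese vocabulary size
voice_prob = torch.softmax(torch.randn(seq_len, vocab_size), dim=-1)  # acoustic model
lm_prob = torch.softmax(torch.randn(seq_len, vocab_size), dim=-1)     # language model

# Multiplying the two distributions is an unweighted product-of-experts combination;
# argmax then picks the token both models assign high probability to.
comb_pred_ids = torch.argmax(voice_prob * lm_prob, dim=-1)  # shape: (seq_len,)
```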
230
+
231
+ ## Evaluation with BERT:
232
+ ```python
233
+ import torchaudio
234
+ from datasets import load_dataset, load_metric
235
+ from transformers import (
236
+ Wav2Vec2ForCTC,
237
+ Wav2Vec2Processor,
238
+ )
239
+ import torch
240
+ import re
241
+ import sys
242
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
243
+
244
+ model_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
245
+ device = "cuda"
246
+ processor_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
247
+ chars_to_ignore_regex = r"[¥•"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·'℃°•·.﹑︰〈〉─《﹖﹣﹂﹁﹔!?。。"#$%&'()*+,﹐-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏..!\"#$%&()*+,\-.\:;<=>?@\[\]\\\/^_`{|}~]"
248
+
249
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
250
+ lm_model = AutoModelForMaskedLM.from_pretrained("bert-base-chinese").to(device)
251
+ model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
252
+ processor = Wav2Vec2Processor.from_pretrained(processor_name)
253
+
254
+ ds = load_dataset("common_voice", 'zh-TW', data_dir="./cv-corpus-6.1-2020-12-11", split="test")
255
+
256
+ resampler = torchaudio.transforms.Resample(orig_freq=48_000, new_freq=16_000)
257
+
258
+ def map_to_array(batch):
259
+ speech, _ = torchaudio.load(batch["path"])
260
+ batch["speech"] = resampler.forward(speech.squeeze(0)).numpy()
261
+ batch["sampling_rate"] = resampler.new_freq
262
+ batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
263
+ return batch
264
+
265
+ ds = ds.map(map_to_array)
266
+
267
+ def map_to_pred(batch):
268
+ features = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0], padding=True, return_tensors="pt")
269
+ input_values = features.input_values.to(device)
270
+ attention_mask = features.attention_mask.to(device)
271
+ with torch.no_grad():
272
+ logits = model(input_values, attention_mask=attention_mask).logits
273
+
274
+ decoded_results = []
275
+ for logit in logits:
276
+ pred_ids = torch.argmax(logit, dim=-1)
277
+ mask = ~pred_ids.eq(tokenizer.pad_token_id).unsqueeze(-1).expand(logit.size())
278
+ vocab_size = logit.size()[-1]
279
+ voice_prob = torch.nn.functional.softmax((torch.masked_select(logit, mask).view(-1,vocab_size)),dim=-1)
280
+ lm_input = torch.masked_select(pred_ids, ~pred_ids.eq(tokenizer.pad_token_id)).unsqueeze(0)
281
+ mask_lm_prob = voice_prob.clone()
282
+ for i in range(lm_input.shape[-1]):
283
+ masked_lm_input = lm_input.clone()
284
+ masked_lm_input[0][i] = torch.tensor(tokenizer.mask_token_id).to('cuda')
285
+ lm_prob = torch.nn.functional.softmax(lm_model(masked_lm_input).logits, dim=-1).squeeze(0)
286
+ mask_lm_prob[i] = lm_prob[i]
287
+ comb_pred_ids = torch.argmax(mask_lm_prob*voice_prob, dim=-1)
288
+ decoded_results.append(processor.decode(comb_pred_ids))
289
+
290
+ batch["predicted"] = decoded_results
291
+ batch["target"] = batch["sentence"]
292
+ return batch
293
+
294
+
295
+ result = ds.map(map_to_pred, batched=True, batch_size=1, remove_columns=list(ds.features.keys()))
296
+
297
+ cer = load_metric("./cer")
298
+ print("CER: {:2f}".format(100 * cer.compute(predictions=result["predicted"], references=result["target"])))
299
+ ```
300
+ `CER: 25.57`.
301
+ `TIME: 09:49 min`
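The BERT variant rescores each position with a separate masked forward pass: position i is replaced by `[MASK]`, BERT is run once, and only the probability row for position i is kept. That per-token loop is why it is the slowest of the evaluated variants despite the best CER. A hedged, self-contained sketch of the loop (illustrative input sentence, not Common Voice data):

```python
# Per-position masked-LM scoring: one BERT forward pass per token.
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
lm_model = AutoModelForMaskedLM.from_pretrained("bert-base-chinese").eval()

lm_input = tokenizer("今天天氣很好", return_tensors="pt").input_ids  # (1, seq_len)
seq_len = lm_input.shape[-1]
mask_lm_prob = torch.zeros(seq_len, lm_model.config.vocab_size)

with torch.no_grad():
    for i in range(seq_len):
        masked = lm_input.clone()
        masked[0, i] = tokenizer.mask_token_id                 # hide position i
        probs = torch.softmax(lm_model(masked).logits, dim=-1).squeeze(0)
        mask_lm_prob[i] = probs[i]                             # keep the row for position i

# mask_lm_prob can then be multiplied with the acoustic probabilities as in the script above.
```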
302
+
303
+ ## Evaluation with T-TA:
304
+ Setup:
305
+ ```
306
+ !git clone https://github.com/voidful/pytorch-tta.git
307
+ !mv ./pytorch-tta/tta ./tta
308
+ !wget https://github.com/voidful/pytorch-tta/releases/download/wiki_zh/wiki_zh.pt
309
+ ```
310
+
311
+ ```python
312
+ import torchaudio
313
+ from datasets import load_dataset, load_metric
314
+ from transformers import (
315
+ Wav2Vec2ForCTC,
316
+ Wav2Vec2Processor,
317
+ )
318
+ import torch
319
+ import re
320
+ import sys
321
+ from tta.modeling_tta import TTALMModel
322
+ from transformers import AutoTokenizer
323
+ import torch
324
+
325
+
326
+
327
+ model_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
328
+ device = "cuda"
329
+ processor_name = "voidful/wav2vec2-large-xlsr-53-tw-gpt"
330
+ chars_to_ignore_regex = r"[¥•"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·'℃°•·.﹑︰〈〉─《﹖﹣﹂﹁﹔!?。。"#$%&'()*+,﹐-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏..!\"#$%&()*+,\-.\:;<=>?@\[\]\\\/^_`{|}~]"
331
+
332
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
333
+ lm_model = TTALMModel("bert-base-chinese")
334
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
335
+ lm_model.load_state_dict(torch.load("./wiki_zh.pt",map_location=torch.device('cuda')))
336
+ lm_model.to('cuda')
337
+ lm_model.eval()
338
+ model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
339
+ processor = Wav2Vec2Processor.from_pretrained(processor_name)
340
+
341
+ ds = load_dataset("common_voice", 'zh-TW', data_dir="./cv-corpus-6.1-2020-12-11", split="test")
342
+
343
+ resampler = torchaudio.transforms.Resample(orig_freq=48_000, new_freq=16_000)
344
+
345
+ def map_to_array(batch):
346
+ speech, _ = torchaudio.load(batch["path"])
347
+ batch["speech"] = resampler.forward(speech.squeeze(0)).numpy()
348
+ batch["sampling_rate"] = resampler.new_freq
349
+ batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
350
+ return batch
351
+
352
+ ds = ds.map(map_to_array)
353
+
354
+ def map_to_pred(batch):
355
+ features = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0], padding=True, return_tensors="pt")
356
+ input_values = features.input_values.to(device)
357
+ attention_mask = features.attention_mask.to(device)
358
+ with torch.no_grad():
359
+ logits = model(input_values, attention_mask=attention_mask).logits
360
+
361
+ decoded_results = []
362
+ for logit in logits:
363
+ pred_ids = torch.argmax(logit, dim=-1)
364
+ mask = ~pred_ids.eq(tokenizer.pad_token_id).unsqueeze(-1).expand(logit.size())
365
+ vocab_size = logit.size()[-1]
366
+ voice_prob = torch.nn.functional.softmax((torch.masked_select(logit, mask).view(-1,vocab_size)),dim=-1)
367
+ lm_input = torch.masked_select(pred_ids, ~pred_ids.eq(tokenizer.pad_token_id)).unsqueeze(0)
368
+ lm_prob = torch.nn.functional.softmax(lm_model.forward(lm_input)[0], dim=-1).squeeze(0)
369
+ comb_pred_ids = torch.argmax(lm_prob*voice_prob, dim=-1)
370
+ decoded_results.append(processor.decode(comb_pred_ids))
371
+
372
  batch["predicted"] = decoded_results
373
  batch["target"] = batch["sentence"]
374
  return batch
 
380
  print("CER: {:2f}".format(100 * cer.compute(predictions=result["predicted"], references=result["target"])))
381
  ```
382
 
383
+ `CER: 25.77`.
384
+ `TIME: 06:01 min`