---
language:
- multilingual
- ar
- as
- br
- ca
- cnh
- cs
- cv
- cy
- de
- dv
- el
- en
- eo
- es
- et
- eu
- fa
- fi
- fr
- hi
- hsb
- hu
- ia
- id
- it
- ja
- ka
- ky
- lg
- lt
- lv
- mn
- mt
- nl
- or
- pl
- pt
- ro
- ru
- sah
- sl
- ta
- th
- tr
- tt
- uk
- vi
language_bcp47:
- fy-NL
- ga-IE
- pa-IN
- rm-sursilv
- rm-vallader
- sv-SE
- zh-CN
- zh-HK
- zh-TW
datasets:
- common_voice
tags:
- audio
- automatic-speech-recognition
- hf-asr-leaderboard
- robust-speech-event
- speech
- xlsr-fine-tuning-week
license: apache-2.0
model-index:
- name: XLSR Wav2Vec2 for 56 languages by Voidful
  results:
  - task:
      name: Speech Recognition
      type: automatic-speech-recognition
    dataset:
      name: Common Voice
      type: common_voice
    metrics:
    - name: Test CER
      type: cer
      value: 23.21
---

# wav2vec2-xlsr-multilingual-56

*56 languages, 1 model: multilingual ASR*

Fine-tuned [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) on 56 languages using the [Common Voice](https://huggingface.co/datasets/common_voice) dataset.
When using this model, make sure that your speech input is sampled at 16 kHz.

For more details, see [https://github.com/voidful/wav2vec2-xlsr-multilingual-56](https://github.com/voidful/wav2vec2-xlsr-multilingual-56).

## Env setup

```
!pip install torchaudio
!pip install datasets transformers
!pip install asrp
!wget -O lang_ids.pk https://huggingface.co/voidful/wav2vec2-xlsr-multilingual-56/raw/main/lang_ids.pk
```

## Usage

```
import pickle

import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

model_name = "voidful/wav2vec2-xlsr-multilingual-56"
processor_name = "voidful/wav2vec2-xlsr-multilingual-56"
device = "cuda"

# Maps each language code to the set of vocabulary ids used by that language.
with open("lang_ids.pk", "rb") as f:
    lang_ids = pickle.load(f)

model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(processor_name)
model.eval()

def load_file_to_data(file, sampling_rate=16_000):
    # `sampling_rate` is the rate of the input file; resample to 16 kHz if needed.
    batch = {}
    speech, _ = torchaudio.load(file)
    if sampling_rate != 16_000:
        resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16_000)
        speech = resampler(speech)
    batch["speech"] = speech.squeeze(0).numpy()
    batch["sampling_rate"] = 16_000
    return batch

def predict(data):
    features = processor(data["speech"], sampling_rate=data["sampling_rate"], padding=True, return_tensors="pt")
    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits
    decoded_results = []
    for logit in logits:
        pred_ids = torch.argmax(logit, dim=-1)
        # Keep only the frames whose top prediction is not the padding token (id 0).
        mask = pred_ids.ge(1).unsqueeze(-1).expand(logit.size())
        vocab_size = logit.size()[-1]
        voice_prob = torch.nn.functional.softmax(torch.masked_select(logit, mask).view(-1, vocab_size), dim=-1)
        comb_pred_ids = torch.argmax(voice_prob, dim=-1)
        decoded_results.append(processor.decode(comb_pred_ids))
    return decoded_results

def predict_lang_specific(data, lang_code):
    features = processor(data["speech"], sampling_rate=data["sampling_rate"], padding=True, return_tensors="pt")
    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits
    decoded_results = []
    for logit in logits:
        pred_ids = torch.argmax(logit, dim=-1)
        # Keep only the frames whose top prediction is not the padding token.
        mask = ~pred_ids.eq(processor.tokenizer.pad_token_id).unsqueeze(-1).expand(logit.size())
        vocab_size = logit.size()[-1]
        voice_prob = torch.nn.functional.softmax(torch.masked_select(logit, mask).view(-1, vocab_size), dim=-1)
        filtered_input = pred_ids[pred_ids != processor.tokenizer.pad_token_id].view(1, -1).to(device)
        if len(filtered_input[0]) == 0:
            decoded_results.append("")
        else:
            # Zero out every vocabulary entry that does not belong to the target
            # language, so the argmax can only pick tokens from that language.
            lang_mask = torch.zeros(voice_prob.shape[-1])
            lang_index = torch.tensor(sorted(lang_ids[lang_code]))
            lang_mask.index_fill_(0, lang_index, 1)
            lang_mask = lang_mask.to(device)
            comb_pred_ids = torch.argmax(lang_mask * voice_prob, dim=-1)
            decoded_results.append(processor.decode(comb_pred_ids))
    return decoded_results

predict(load_file_to_data('audio file path', sampling_rate=16_000))  # beware of the audio file sampling rate

predict_lang_specific(load_file_to_data('audio file path', sampling_rate=16_000), 'en')  # beware of the audio file sampling rate
```
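For a quick test without the helper functions above, the generic `transformers` ASR pipeline can also load this checkpoint. This is a minimal sketch, not part of the original recipe: `sample.wav` is a placeholder path, and the pipeline does plain greedy CTC decoding over the full 56-language vocabulary, similar in spirit to `predict()` above rather than `predict_lang_specific()`.

```
from transformers import pipeline

# Minimal sketch: load the checkpoint through the generic ASR pipeline.
# Reading audio from a file path requires ffmpeg to be installed.
asr = pipeline("automatic-speech-recognition", model="voidful/wav2vec2-xlsr-multilingual-56")
print(asr("sample.wav"))  # "sample.wav" is a placeholder audio file path
```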
## Results

| Common Voice language | Num. of samples | Hours  | WER (%) | CER (%) |
|-----------------------|-----------------|--------|---------|---------|
| ar                    | 21744           | 81.5   | 75.29   | 31.23   |
| as                    | 394             | 1.1    | 95.37   | 46.05   |
| br                    | 4777            | 7.4    | 93.79   | 41.16   |
| ca                    | 301308          | 692.8  | 24.80   | 10.39   |
| cnh                   | 1563            | 2.4    | 68.11   | 23.10   |
| cs                    | 9773            | 39.5   | 67.86   | 12.57   |
| cv                    | 1749            | 5.9    | 95.43   | 34.03   |
| cy                    | 11615           | 106.7  | 67.03   | 23.97   |
| de                    | 262113          | 822.8  | 27.03   | 6.50    |
| dv                    | 4757            | 18.6   | 92.16   | 30.15   |
| el                    | 3717            | 11.1   | 94.48   | 58.67   |
| en                    | 580501          | 1763.6 | 34.87   | 14.84   |
| eo                    | 28574           | 162.3  | 37.77   | 6.23    |
| es                    | 176902          | 337.7  | 19.63   | 5.41    |
| et                    | 5473            | 35.9   | 86.87   | 20.79   |
| eu                    | 12677           | 90.2   | 44.80   | 7.32    |
| fa                    | 12806           | 290.6  | 53.81   | 15.09   |
| fi                    | 875             | 2.6    | 93.78   | 27.57   |
| fr                    | 314745          | 664.1  | 33.16   | 13.94   |
| fy-NL                 | 6717            | 27.2   | 72.54   | 26.58   |
| ga-IE                 | 1038            | 3.5    | 92.57   | 51.02   |
| hi                    | 292             | 2.0    | 90.95   | 57.43   |
| hsb                   | 980             | 2.3    | 89.44   | 27.19   |
| hu                    | 4782            | 9.3    | 97.15   | 36.75   |
| ia                    | 5078            | 10.4   | 52.00   | 11.35   |
| id                    | 3965            | 9.9    | 82.50   | 22.82   |
| it                    | 70943           | 178.0  | 39.09   | 8.72    |
| ja                    | 1308            | 8.2    | 99.21   | 62.06   |
| ka                    | 1585            | 4.0    | 90.53   | 18.57   |
| ky                    | 3466            | 12.2   | 76.53   | 19.80   |
| lg                    | 1634            | 17.1   | 98.95   | 43.84   |
| lt                    | 1175            | 3.9    | 92.61   | 26.81   |
| lv                    | 4554            | 6.3    | 90.34   | 30.81   |
| mn                    | 4020            | 11.6   | 82.68   | 30.14   |
| mt                    | 3552            | 7.8    | 84.18   | 22.96   |
| nl                    | 14398           | 71.8   | 57.18   | 19.01   |
| or                    | 517             | 0.9    | 90.93   | 27.34   |
| pa-IN                 | 255             | 0.8    | 87.95   | 42.03   |
| pl                    | 12621           | 112.0  | 56.14   | 12.06   |
| pt                    | 11106           | 61.3   | 53.24   | 16.32   |
| rm-sursilv            | 2589            | 5.9    | 78.17   | 23.31   |
| rm-vallader           | 931             | 2.3    | 73.67   | 21.76   |
| ro                    | 4257            | 8.7    | 83.84   | 21.95   |
| ru                    | 23444           | 119.1  | 61.83   | 15.18   |
| sah                   | 1847            | 4.4    | 94.38   | 38.46   |
| sl                    | 2594            | 6.7    | 84.21   | 20.54   |
| sv-SE                 | 4350            | 20.8   | 83.68   | 30.79   |
| ta                    | 3788            | 18.4   | 84.19   | 21.60   |
| th                    | 4839            | 11.7   | 141.87  | 37.16   |
| tr                    | 3478            | 22.3   | 66.77   | 15.55   |
| tt                    | 13338           | 26.7   | 86.80   | 33.57   |
| uk                    | 7271            | 39.4   | 70.23   | 14.34   |
| vi                    | 421             | 1.7    | 96.06   | 66.25   |
| zh-CN                 | 27284           | 58.7   | 89.67   | 23.96   |
| zh-HK                 | 12678           | 92.1   | 81.77   | 18.82   |
| zh-TW                 | 6402            | 56.6   | 85.08   | 29.07   |
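The figures above come from the evaluation scripts in the linked GitHub repository. As a rough illustration of how a per-language CER can be computed, here is a minimal sketch reusing `load_file_to_data()` and `predict()` from the Usage section; the language code, test-slice size, 48 kHz source rate, and lower-casing normalization are illustrative assumptions, not the repository's exact protocol.

```
from datasets import load_dataset, load_metric

cer_metric = load_metric("cer")  # backed by jiwer (pip install jiwer)
# Illustrative choices: English, first 100 test clips. Common Voice audio is
# distributed at 48 kHz, hence sampling_rate=48_000 below.
test_set = load_dataset("common_voice", "en", split="test[:100]")

predictions, references = [], []
for sample in test_set:
    predictions.append(predict(load_file_to_data(sample["path"], sampling_rate=48_000))[0])
    references.append(sample["sentence"].lower())  # rough normalization (assumption)

print(cer_metric.compute(predictions=predictions, references=references))
```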