Spaces:
Runtime error
Runtime error
File size: 5,589 Bytes
056c529 d320d92 056c529 6aa1052 056c529 3715573 056c529 3715573 056c529 d320d92 056c529 d320d92 056c529 d320d92 056c529 d320d92 056c529 d320d92 056c529 d320d92 056c529 71dc112 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import gradio as gr
import torchaudio
from datasets import load_dataset, load_metric
from transformers import (
Wav2Vec2ForCTC,
Wav2Vec2Processor,
AutoTokenizer,
AutoModelWithLMHead
)
import torch
import re
import sys
import soundfile as sf
# Hugging Face Hub checkpoint used for both the CTC model and its processor.
model_name = "voidful/wav2vec2-xlsr-multilingual-56"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor_name = "voidful/wav2vec2-xlsr-multilingual-56"

import pickle
from pathlib import Path

# lang_ids.pk is presumably shipped alongside this script (verify). It is
# resolved from the script's own directory so the app does not depend on the
# current working directory. It is indexed by language code; each entry is an
# iterable of vocabulary token ids (predict_lang_specific uses them as indices).
# SECURITY: pickle.load can execute arbitrary code -- only ever load this
# trusted, bundled file, never user-supplied data.
with open(Path(__file__).with_name("lang_ids.pk"), "rb") as lang_file:
    lang_ids = pickle.load(lang_file)

model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(processor_name)
# Inference only: put the model in evaluation mode (disables dropout).
model.eval()
def load_file_to_data(file, sampling_rate=16_000):
    """Read an audio file and return a mono 16 kHz waveform for the processor.

    Args:
        file: Anything ``torchaudio.load`` accepts (the Gradio callbacks pass a
            file path).
        sampling_rate: Kept for backward compatibility and ignored. It used to
            be treated as the file's native rate; that rate is now read from
            the file itself, so a wrong hint can no longer skew resampling.

    Returns:
        dict: ``"speech"`` -> 1-D float numpy array sampled at 16 kHz,
        ``"sampling_rate"`` -> ``16000`` (int, as the feature extractor
        compares it numerically).
    """
    target_rate = 16_000
    # torchaudio.load returns the waveform as (channels, time) plus the rate
    # stored in the file; the previous code discarded that rate and
    # "resampled" from the caller's hint instead (a 16k -> 16k no-op).
    speech, file_rate = torchaudio.load(file)
    # Average the channels so stereo uploads become a single 1-D signal; for
    # mono input this is identical to squeeze(0).
    speech = speech.mean(dim=0)
    if file_rate != target_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=file_rate, new_freq=target_rate)
        speech = resampler(speech)
    return {"speech": speech.numpy(), "sampling_rate": target_rate}
def _frame_logits(data):
    """Run the CTC model on one audio file and return its raw logits.

    Args:
        data: Audio file path, forwarded to ``load_file_to_data``.

    Returns:
        torch.Tensor indexed as (batch item, frame, vocabulary id); the batch
        holds the single loaded file.
    """
    audio = load_file_to_data(data, sampling_rate=16_000)
    features = processor(audio["speech"], sampling_rate=audio["sampling_rate"],
                         padding=True, return_tensors="pt")
    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)
    with torch.no_grad():
        return model(input_values, attention_mask=attention_mask).logits


def predict(data):
    """Transcribe an audio file without constraining the output language.

    Args:
        data: Path of the audio file, or None when nothing was uploaded.

    Returns:
        list[str]: One transcript per batch item (a single entry for one
        file); an empty list when ``data`` is None.
    """
    if data is None:
        # Nothing uploaded in the UI; torchaudio.load(None) would raise.
        return []
    logits = _frame_logits(data)
    # Greedy CTC decoding: keep blank (pad) frames in the id sequence so that
    # processor.decode collapses repeated ids first and drops blanks second.
    # Removing blanks beforehand (as this function used to) merges genuinely
    # doubled letters, e.g. frames "l <pad> l" decoded to a single "l".
    pred_ids = torch.argmax(logits, dim=-1)
    return [processor.decode(ids) for ids in pred_ids]


def predict_lang_specific(data, lang_code):
    """Transcribe an audio file using only the tokens of one language.

    Args:
        data: Path of the audio file, or None when nothing was uploaded.
        lang_code: Key into ``lang_ids`` selecting the allowed token ids.

    Returns:
        list[str]: One transcript per batch item (a single entry for one
        file); an empty list when ``data`` is None.

    Raises:
        KeyError: If ``lang_code`` is not present in ``lang_ids``.
    """
    if data is None:
        return []
    logits = _frame_logits(data)
    pad_id = processor.tokenizer.pad_token_id
    # Vocabulary mask: True only for token ids that belong to `lang_code`.
    allowed = torch.zeros(logits.size(-1), dtype=torch.bool, device=logits.device)
    allowed[torch.tensor(sorted(lang_ids[lang_code]), dtype=torch.long,
                         device=logits.device)] = True
    # Unconstrained best id per frame; used only to locate blank frames.
    pred_ids = torch.argmax(logits, dim=-1)
    # Best id per frame restricted to the language's tokens. Softmax is
    # monotonic, so this picks the same token as argmax over masked probabilities.
    lang_pred_ids = torch.argmax(logits.masked_fill(~allowed, float("-inf")), dim=-1)
    # Blank frames stay blank so CTC decoding still separates repeated letters;
    # every other frame takes the language-constrained choice.
    comb_pred_ids = torch.where(pred_ids.eq(pad_id), pred_ids, lang_pred_ids)
    return [processor.decode(ids) for ids in comb_pred_ids]
# Usage outside the web UI. Both entry points take the *path* of an audio file
# and call load_file_to_data on it themselves, so pass the path directly (not
# the dict returned by load_file_to_data):
#   predict("path/to/audio.wav")
#   predict_lang_specific("path/to/audio.wav", "en")
# Language codes offered in the "manual" tab; each must be a key of lang_ids,
# because predict_lang_specific looks the selected code up there.
LANGUAGE_CODES = [
    "ar", "as", "br", "ca", "cnh", "cs", "cv", "cy", "de", "dv", "el", "en",
    "eo", "es", "et", "eu", "fa", "fi", "fr", "fy-NL", "ga-IE", "hi", "hsb",
    "hu", "ia", "id", "it", "ja", "ka", "ky", "lg", "lt", "lv", "mn", "mt",
    "nl", "or", "pa-IN", "pl", "pt", "rm-sursilv", "rm-vallader", "ro", "ru",
    "sah", "sl", "sv-SE", "ta", "th", "tr", "tt", "uk", "vi", "zh-CN",
    "zh-HK", "zh-TW",
]

with gr.Blocks() as demo:
    gr.Markdown("multilingual Speech Recognition")
    with gr.Tab("Auto"):
        gr.Markdown("automatically detects your language")
        inputs_speech = gr.Audio(source="upload", type="filepath")
        # Plain Textbox (as in the manual tab): an HTML component would
        # interpret the transcript text as markup.
        output_transcribe = gr.Textbox(label="output")
        transcribe_audio = gr.Button("Submit")
    with gr.Tab("manual"):
        gr.Markdown("set your speech language")
        inputs_speech1 = [
            gr.Audio(source="upload", type="filepath"),
            gr.Dropdown(choices=LANGUAGE_CODES, value="fa", label="language code"),
        ]
        output_transcribe1 = gr.Textbox(label="output")
        transcribe_audio1 = gr.Button("Submit")

    # predict / predict_lang_specific return a list with one transcript per
    # batch item; join it so the output box shows plain text rather than a
    # Python list repr.
    transcribe_audio.click(
        fn=lambda path: "\n".join(predict(path)),
        inputs=inputs_speech,
        outputs=output_transcribe,
    )
    transcribe_audio1.click(
        fn=lambda path, lang_code: "\n".join(predict_lang_specific(path, lang_code)),
        inputs=inputs_speech1,
        outputs=output_transcribe1,
    )

if __name__ == "__main__":
    demo.launch()
|