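"""Gradio demo for Meta's MMS (Massively Multilingual Speech) models:
automatic speech recognition (ASR), text-to-speech (TTS), and language
identification (LID), each on its own tab."""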
import gradio as gr
import numpy as np
import torch
import torchaudio
from transformers import (
    AutoFeatureExtractor,
    AutoTokenizer,
    VitsModel,
    Wav2Vec2ForCTC,
    Wav2Vec2ForSequenceClassification,
    Wav2Vec2Processor,
)

# ASR: MMS multilingual recognizer with a CTC head. The processor bundles the
# feature extractor and the CTC tokenizer used for decoding.
asr_model = Wav2Vec2ForCTC.from_pretrained("facebook/mms-1b-all")
asr_processor = Wav2Vec2Processor.from_pretrained("facebook/mms-1b-all")

# TTS: MMS text-to-speech is a VITS model; Speech2Text2 is an ASR decoder and
# cannot synthesize speech. MMS-TTS checkpoints are per-language, so the
# English checkpoint "facebook/mms-tts-eng" is used here.
tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")

# LID: audio classification over 1024 languages. The checkpoint ships a
# feature extractor but no tokenizer, so it is loaded with
# AutoFeatureExtractor rather than Wav2Vec2Processor.
lid_model = Wav2Vec2ForSequenceClassification.from_pretrained("facebook/mms-lid-1024")
lid_processor = AutoFeatureExtractor.from_pretrained("facebook/mms-lid-1024")

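# Helper (not in the original draft): Gradio's numpy Audio component hands the
# callback a (sample_rate, ndarray) tuple, while the MMS models expect mono
# float32 audio at 16 kHz. This minimal sketch bridges the two.
def preprocess_audio(audio):
    sample_rate, data = audio
    if np.issubdtype(data.dtype, np.integer):
        # Integer PCM (e.g. int16) -> floats in [-1, 1].
        data = data.astype(np.float32) / np.iinfo(data.dtype).max
    else:
        data = data.astype(np.float32)
    if data.ndim > 1:
        data = data.mean(axis=1)  # downmix stereo to mono
    if sample_rate != 16000:
        data = torchaudio.functional.resample(
            torch.from_numpy(data), orig_freq=sample_rate, new_freq=16000
        ).numpy()
    return data
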

def asr_transcribe(audio):
    # MMS-1B-all loads its English adapter by default; to transcribe another
    # language, switch via asr_processor.tokenizer.set_target_lang(...) and
    # asr_model.load_adapter(...).
    speech = preprocess_audio(audio)
    inputs = asr_processor(speech, sampling_rate=16000, return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = asr_model(**inputs).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.batch_decode(predicted_ids)
    return transcription[0]

def tts_synthesize(text):
    inputs = tts_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        # VITS produces the waveform directly in the forward pass; there is no
        # generate()/token-decoding step as in the original draft.
        waveform = tts_model(**inputs).waveform
    # gr.Audio accepts a (sample_rate, ndarray) tuple.
    return tts_model.config.sampling_rate, waveform.squeeze(0).numpy()

def identify_language(audio):
    speech = preprocess_audio(audio)
    inputs = lid_processor(speech, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        logits = lid_model(**inputs).logits
    # Classification logits map to language labels through the model config,
    # not through a tokenizer's batch_decode.
    predicted_id = torch.argmax(logits, dim=-1).item()
    return lid_model.config.id2label[predicted_id]

with gr.Blocks() as demo:
    with gr.Tab("ASR"):
        gr.Markdown("## Automatic Speech Recognition (ASR)")
        # gr.Audio takes `sources=` in Gradio 4.x (`source=` in 3.x).
        asr_audio = gr.Audio(sources=["microphone"], type="numpy")
        asr_text = gr.Textbox(label="Transcription")
        gr.ClearButton([asr_audio, asr_text])
        asr_button = gr.Button("Submit")
        # Buttons are wired with .click(); the constructor does not accept
        # fn/inputs/outputs arguments.
        asr_button.click(fn=asr_transcribe, inputs=asr_audio, outputs=asr_text)

    with gr.Tab("TTS"):
        gr.Markdown("## Text-to-Speech (TTS)")
        tts_text = gr.Textbox(label="Text")
        tts_audio = gr.Audio(label="Audio Output")
        gr.ClearButton([tts_text, tts_audio])
        tts_button = gr.Button("Submit")
        tts_button.click(fn=tts_synthesize, inputs=tts_text, outputs=tts_audio)

    with gr.Tab("Language ID"):
        gr.Markdown("## Language Identification (LangID)")
        # Distinct variable names per tab; reusing `audio_input` across tabs
        # invites wiring the wrong component.
        lid_audio = gr.Audio(sources=["microphone"], type="numpy")
        lid_text = gr.Textbox(label="Identified Language")
        gr.ClearButton([lid_audio, lid_text])
        lid_button = gr.Button("Submit")
        lid_button.click(fn=identify_language, inputs=lid_audio, outputs=lid_text)

demo.launch()
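
# To run locally (assuming Gradio 4.x and a recent transformers release):
#   pip install gradio transformers torch torchaudio
# then execute this file with Python and open the local URL Gradio prints.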