# s2s / app.py
# Web file-viewer residue kept for provenance (not part of the program):
# "frogcho123's picture" · "Update app.py" · bc7920f · raw · history blame · 2.16 kB
import functools
import os
import tempfile

import gradio as gr
import scipy.io.wavfile as wav
import soundfile as sf
import whisper
from gtts import gTTS
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
@functools.lru_cache(maxsize=1)
def _load_whisper_model():
    """Load the Whisper "base" model once and reuse it on later calls."""
    return whisper.load_model("base")


@functools.lru_cache(maxsize=1)
def _load_translation_model():
    """Load the "alirezamsh/small100" seq2seq model once and reuse it."""
    return AutoModelForSeq2SeqLM.from_pretrained("alirezamsh/small100")


def _resolve_audio_path(input_audio, work_dir):
    """Return a filesystem path for *input_audio*, saving into *work_dir* if needed."""
    if input_audio is None:
        raise ValueError("No input audio was provided.")
    if isinstance(input_audio, (str, os.PathLike)):
        return os.fspath(input_audio)

    name = getattr(input_audio, "name", None)
    if isinstance(name, str):
        save = getattr(input_audio, "save", None)
        if callable(save):
            # Upload-style object: persist it under the original extension,
            # but inside the per-call scratch directory instead of the CWD.
            saved_path = os.path.join(
                work_dir, "input_audio" + os.path.splitext(name)[1]
            )
            save(saved_path)
            return saved_path
        # Temp-file style object: ``name`` already is the path on disk.
        return name

    raise TypeError(
        "input_audio must be a file path or a file-like object with a "
        f"'name' attribute, got {type(input_audio).__name__}"
    )


def translate_speech_to_speech(input_audio):
    """Transcribe speech, translate it to Russian, and synthesize the result.

    The pipeline is: Whisper (language detection + transcription) ->
    "alirezamsh/small100" (text translation) -> gTTS (speech synthesis).

    Args:
        input_audio: Path to an audio file (``str`` or ``os.PathLike``), or a
            file-like object exposing ``name`` (and optionally ``save(path)``).

    Returns:
        A ``(sample_rate, samples)`` tuple where ``samples`` is an ``int16``
        array. This is the order Gradio's numpy audio output expects.

    Raises:
        ValueError: If no audio was supplied or no speech could be transcribed.
        TypeError: If ``input_audio`` is not a path or a named file object.
    """
    to_lang = 'ru'

    # All intermediate files live in a private directory so that concurrent
    # calls cannot overwrite each other's files, and nothing is left behind.
    with tempfile.TemporaryDirectory() as work_dir:
        input_file = _resolve_audio_path(input_audio, work_dir)

        # --- Speech recognition -------------------------------------------
        speech_model = _load_whisper_model()
        audio = whisper.load_audio(input_file)
        audio = whisper.pad_or_trim(audio)
        mel = whisper.log_mel_spectrogram(audio).to(speech_model.device)
        _, probs = speech_model.detect_language(mel)
        lang = max(probs, key=probs.get)
        # Request half precision only on CUDA; the fp16 default is a known
        # source of failures when decoding on CPU.
        options = whisper.DecodingOptions(
            fp16=speech_model.device.type == "cuda"
        )
        text = whisper.decode(speech_model, mel, options).text
        if not text.strip():
            raise ValueError("No speech could be transcribed from the input audio.")

        # --- Translation --------------------------------------------------
        # The tokenizer is created per call because its language attributes
        # are mutated below; only the (heavy) model is shared between calls.
        tokenizer = AutoTokenizer.from_pretrained("alirezamsh/small100")
        tokenizer.src_lang = lang
        # NOTE(review): the SMALL-100 model card selects the output language
        # through ``tgt_lang``; confirm the tokenizer class that AutoTokenizer
        # resolves to honors this attribute.
        tokenizer.tgt_lang = to_lang
        translation_model = _load_translation_model()
        encoded = tokenizer(text, return_tensors="pt")
        generated_tokens = translation_model.generate(**encoded)
        translated_text = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]

        # --- Text-to-speech -----------------------------------------------
        # gTTS produces MP3 data, so the file gets a matching extension.
        output_file = os.path.join(work_dir, "translated_speech.mp3")
        gTTS(text=translated_text, lang=to_lang).save(output_file)
        samples, sample_rate = sf.read(output_file, dtype="float32")

    # Clip before scaling so decoder overshoot cannot wrap around in int16.
    pcm = (samples.clip(-1.0, 1.0) * 32767).astype("int16")
    return sample_rate, pcm
title = "Speech-to-Speech Translator"

# ``type`` selects how audio is exchanged with the handler and must be a single
# mode string, not a list of file formats:
#   * input  "filepath": the handler is given the path of the recorded/uploaded file
#   * output "numpy":    the handler's result is read as (sample_rate, samples)
# The output component takes no ``sample_rate`` argument; the rate travels with
# the returned tuple.
input_audio = gr.inputs.Audio(type="filepath")
output_audio = gr.outputs.Audio(type="numpy")

stt_demo = gr.Interface(
    fn=translate_speech_to_speech,
    inputs=input_audio,
    outputs=output_audio,
    title=title,
    description="Speak in any language, and the translator will convert it to speech in the target language.",
)

if __name__ == "__main__":
    stt_demo.launch()