# SSMTDemo / app.py
# load the libraries for the application
# -------------------------------------------
import os
import re
import nltk
import torch
import librosa
import tempfile
import subprocess
import gradio as gr
from scipy.io import wavfile
from nnet import utils, commons
from transformers import pipeline
from scipy.io.wavfile import write
from faster_whisper import WhisperModel
from nnet.models import SynthesizerTrn as vitsTRN
from nnet.models_vc import SynthesizerTrn as freeTRN
from nnet.mel_processing import mel_spectrogram_torch
from configurations.get_constants import constantConfig
from speaker_encoder.voice_encoder import SpeakerEncoder
from df_local.enhance import enhance, init_df, load_audio, save_audio
from configurations.get_hyperparameters import hyperparameterConfig
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
nltk.download('punkt')
from nltk.tokenize import sent_tokenize
# FreeVC voice conversion model wrapper
# ---------------------------------
class FreeVCModel:
def __init__(self, config, ptfile, speaker_model, wavLM_model, device='cpu'):
self.hps = utils.get_hparams_from_file(config)
self.net_g = freeTRN(
self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length,
**self.hps.model
        ).to(device)
_ = self.net_g.eval()
_ = utils.load_checkpoint(ptfile, self.net_g, None, True)
self.cmodel = utils.get_cmodel(device, wavLM_model)
if self.hps.model.use_spk:
self.smodel = SpeakerEncoder(speaker_model)
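    # Convert a source utterance into the target speaker's voice: WavLM content features come from the source,
    # the speaker embedding (or mel spectrogram) from the target, and FreeVC synthesizes the result.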
def convert(self, src, tgt):
fs_src, src_audio = src
fs_tgt, tgt_audio = tgt
src = f"{constants.temp_audio_folder}/src.wav"
tgt = f"{constants.temp_audio_folder}/tgt.wav"
out = f"{constants.temp_audio_folder}/cnvr.wav"
with torch.no_grad():
wavfile.write(tgt, fs_tgt, tgt_audio)
wav_tgt, _ = librosa.load(tgt, sr=self.hps.data.sampling_rate)
wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
if self.hps.model.use_spk:
g_tgt = self.smodel.embed_utterance(wav_tgt)
g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).to(hyperparameters.device.type)
else:
wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(hyperparameters.device.type)
mel_tgt = mel_spectrogram_torch(
wav_tgt,
self.hps.data.filter_length,
self.hps.data.n_mel_channels,
self.hps.data.sampling_rate,
self.hps.data.hop_length,
self.hps.data.win_length,
self.hps.data.mel_fmin,
self.hps.data.mel_fmax,
)
wavfile.write(src, fs_src, src_audio)
wav_src, _ = librosa.load(src, sr=self.hps.data.sampling_rate)
wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(hyperparameters.device.type)
c = utils.get_content(self.cmodel, wav_src)
if self.hps.model.use_spk:
audio = self.net_g.infer(c, g=g_tgt)
else:
audio = self.net_g.infer(c, mel=mel_tgt)
audio = audio[0][0].data.cpu().float().numpy()
            write(out, self.hps.data.sampling_rate, audio)  # write at the model's sampling rate
return out
# load the system configurations
constants = constantConfig()
hyperparameters = hyperparameterConfig()
# load the models
model, df_state, _ = init_df(hyperparameters.voice_enhacing_model, config_allow_defaults=True) # voice enhancing model
stt_model = WhisperModel(hyperparameters.stt_model, device=hyperparameters.device.type, compute_type="float32") #speech to text model
trans_model = AutoModelForSeq2SeqLM.from_pretrained(constants.model_name_dict[hyperparameters.nllb_model], torch_dtype=torch.bfloat16).to(hyperparameters.device)
trans_tokenizer = AutoTokenizer.from_pretrained(constants.model_name_dict[hyperparameters.nllb_model])
modelConvertSpeech = FreeVCModel(config=hyperparameters.text2speech_config, ptfile=hyperparameters.text2speech_model,
speaker_model=hyperparameters.text2speech_encoder, wavLM_model=hyperparameters.wavlm_model,
device=hyperparameters.device.type)
# download the language model if it doesn't already exist
# ----------------------------------------------------
def download(lang, lang_directory):
if not os.path.exists(f"{lang_directory}/{lang}"):
cmd = ";".join([
f"wget {constants.language_download_web}/{lang}.tar.gz -O {lang_directory}/{lang}.tar.gz",
f"tar zxvf {lang_directory}/{lang}.tar.gz -C {lang_directory}"
])
subprocess.check_output(cmd, shell=True)
        try:
            os.remove(f"{lang_directory}/{lang}.tar.gz")
        except OSError:
            pass
return f"{lang_directory}/{lang}"
def preprocess_char(text, lang=None):
"""
    Special treatment of characters in certain languages
"""
if lang == 'ron':
text = text.replace("ț", "ţ")
return text
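# Prepare raw text for the VITS TTS model: language-specific character fixes, optional uromanization, lowercasing, and out-of-vocabulary filtering.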
def preprocess_text(txt, text_mapper, hps, uroman_dir=None, lang=None):
txt = preprocess_char(txt, lang=lang)
is_uroman = hps.data.training_files.split('.')[-1] == 'uroman'
if is_uroman:
txt = text_mapper.uromanize(txt, f'{uroman_dir}/bin/uroman.pl')
txt = txt.lower()
txt = text_mapper.filter_oov(txt)
return txt
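# Detect the language of a text snippet with a fastText-style LID model (predictions use labels of the form "__label__<code>").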
def detect_language(text, LID):
predictions = LID.predict(text)
detected_lang_code = predictions[0][0].replace("__label__", "")
return detected_lang_code
# text-to-speech helper: maps text to the symbol sequences expected by the VITS models
class TextMapper(object):
def __init__(self, vocab_file):
        with open(vocab_file, encoding="utf-8") as f:
            self.symbols = [x.replace("\n", "") for x in f.readlines()]
self.SPACE_ID = self.symbols.index(" ")
self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
self._id_to_symbol = {i: s for i, s in enumerate(self.symbols)}
def text_to_sequence(self, text, cleaner_names):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
Args:
text: string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through
Returns:
List of integers corresponding to the symbols in the text
'''
sequence = []
clean_text = text.strip()
for symbol in clean_text:
symbol_id = self._symbol_to_id[symbol]
sequence += [symbol_id]
return sequence
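    # Romanize text by piping it through the external uroman.pl Perl script (used for checkpoints trained on uromanized text).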
def uromanize(self, text, uroman_pl):
with tempfile.NamedTemporaryFile() as tf, \
tempfile.NamedTemporaryFile() as tf2:
with open(tf.name, "w") as f:
f.write("\n".join([text]))
cmd = f"perl " + uroman_pl
cmd += f" -l xxx "
cmd += f" < {tf.name} > {tf2.name}"
os.system(cmd)
outtexts = []
with open(tf2.name) as f:
for line in f:
line = re.sub(r"\s+", " ", line).strip()
outtexts.append(line)
outtext = outtexts[0]
return outtext
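    # Convert text into a LongTensor of symbol IDs, interspersing blank tokens when the model configuration requires it.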
def get_text(self, text, hps):
text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
if hps.data.add_blank:
text_norm = commons.intersperse(text_norm, 0)
text_norm = torch.LongTensor(text_norm)
return text_norm
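    # Drop any characters that are not in the model's symbol set.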
def filter_oov(self, text):
val_chars = self._symbol_to_id
txt_filt = "".join(list(filter(lambda x: x in val_chars, text)))
return txt_filt
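# Speech to text: enhance the recording with the DeepFilterNet model, transcribe it with faster-whisper, and guess the source language for the dropdown.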
def speech_to_text(audio_file):
try:
fs, audio = audio_file
wavfile.write(constants.input_speech_file, fs, audio)
audio0, _ = load_audio(constants.input_speech_file, sr=df_state.sr())
# Enhance the SNR of the audio
enhanced = enhance(model, df_state, audio0)
save_audio(constants.enhanced_speech_file, enhanced, df_state.sr())
segments, info = stt_model.transcribe(constants.enhanced_speech_file)
speech_text = ''
for segment in segments:
speech_text = f'{speech_text}{segment.text}'
        try:
            source_lang_nllb = [k for k, v in constants.flores_codes_to_tts_codes.items() if v[:2] == info.language][0]
        except IndexError:
            source_lang_nllb = "language can't be determined, select manually"
# text translation
return speech_text, gr.Dropdown.update(value=source_lang_nllb)
    except Exception:
        return '', gr.Dropdown.update(value='English')
# Text to speech
def text_to_speech(text, target_lang):
txt = text
# LANG = get_target_tts_lang(target_lang)
LANG = constants.flores_codes_to_tts_codes[target_lang]
ckpt_dir = download(LANG, lang_directory=constants.language_directory)
vocab_file = f"{ckpt_dir}/{constants.language_vocab_text}"
config_file = f"{ckpt_dir}/{constants.language_vocab_configuration}"
hps = utils.get_hparams_from_file(config_file)
text_mapper = TextMapper(vocab_file)
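    # Build the VITS synthesizer for the target language and load its pretrained checkpoint.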
net_g = vitsTRN(
len(text_mapper.symbols),
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
**hps.model)
net_g.to(hyperparameters.device)
_ = net_g.eval()
g_pth = f"{ckpt_dir}/{constants.language_vocab_model}"
_ = utils.load_checkpoint(g_pth, net_g, None)
txt = preprocess_text(txt, text_mapper, hps, lang=LANG, uroman_dir=constants.uroman_directory)
stn_tst = text_mapper.get_text(txt, hps)
with torch.no_grad():
x_tst = stn_tst.unsqueeze(0).to(hyperparameters.device)
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(hyperparameters.device)
hyp = net_g.infer(
x_tst, x_tst_lengths, noise_scale=.667,
noise_scale_w=0.8, length_scale=1.0
)[0][0,0].cpu().float().numpy()
return hps.data.sampling_rate, hyp
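# Full pipeline: translate the text with the NLLB model, synthesize speech in the target language, and optionally transfer it into the speaker's own voice.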
def translation(audio, text, source_lang_nllb, target_code_nllb, output_type, sentence_mode):
target_code = constants.flores_codes[target_code_nllb]
    translator = pipeline('translation', model=trans_model, tokenizer=trans_tokenizer, src_lang=constants.flores_codes[source_lang_nllb], tgt_lang=target_code, device=hyperparameters.device)
# output = translator(text, max_length=400)[0]['translation_text']
if sentence_mode == "Sentence-wise":
sentences = sent_tokenize(text)
translated_sentences = []
for sentence in sentences:
translated_sentence = translator(sentence, max_length=400)[0]['translation_text']
translated_sentences.append(translated_sentence)
output = ' '.join(translated_sentences)
else:
output = translator(text, max_length=1024)[0]['translation_text']
# get the text to speech
fs_out, audio_out = text_to_speech(output, target_code_nllb)
    if output_type == 'voice transfer':
out_file = modelConvertSpeech.convert((fs_out, audio_out), audio)
return output, out_file
wavfile.write(constants.text2speech_wavfile, fs_out, audio_out)
return output, constants.text2speech_wavfile
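# Gradio user interface
# ----------------------------------------------------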
with gr.Blocks(title="Octopus Translation App") as octopus_translator:
with gr.Row():
audio_file = gr.Audio(source="microphone")
with gr.Row():
input_text = gr.Textbox(label="Input text")
source_language = gr.Dropdown(list(constants.flores_codes.keys()), value='English', label='Source (Autoselected)', interactive=True)
with gr.Row():
output_text = gr.Textbox(label='Translated text')
target_language = gr.Dropdown(list(constants.flores_codes.keys()), value='German', label='Target', interactive=True)
with gr.Row():
output_speech = gr.Audio(label='Translated speech')
translate_button = gr.Button('Translate')
with gr.Row():
enhance_audio = gr.Radio(['yes', 'no'], value='yes', label='Enhance input voice', interactive=True)
        input_type = gr.Radio(['Whole text', 'Sentence-wise'], value='Sentence-wise', label="Translation Mode", interactive=True)
        output_audio_type = gr.Radio(['standard speaker', 'voice transfer'], value='voice transfer', label='Output voice', interactive=True)
audio_file.change(speech_to_text,
inputs=[audio_file],
outputs=[input_text, source_language])
translate_button.click(translation,
inputs=[audio_file, input_text,
source_language, target_language,
output_audio_type, input_type],
outputs=[output_text, output_speech])
octopus_translator.launch(share=False)