|
|
|
|
|
|
|
import os |
|
import re |
|
import nltk |
|
import torch |
|
import librosa |
|
import tempfile |
|
import subprocess |
|
|
|
import gradio as gr |
|
|
|
from scipy.io import wavfile |
|
from nnet import utils, commons |
|
from transformers import pipeline |
|
from scipy.io.wavfile import write |
|
from faster_whisper import WhisperModel |
|
from nnet.models import SynthesizerTrn as vitsTRN |
|
from nnet.models_vc import SynthesizerTrn as freeTRN |
|
from nnet.mel_processing import mel_spectrogram_torch |
|
from configurations.get_constants import constantConfig |
|
|
|
from speaker_encoder.voice_encoder import SpeakerEncoder |
|
|
|
from df_local.enhance import enhance, init_df, load_audio, save_audio |
|
from configurations.get_hyperparameters import hyperparameterConfig |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline |
|
|
|
nltk.download('punkt') |
|
from nltk.tokenize import sent_tokenize |
|
|
|
|
|
|
|
class FreeVCModel:
    """Voice-conversion wrapper around a FreeVC synthesizer.

    Re-voices a source utterance with the voice of a target speaker: content
    features of the source are extracted with ``utils.get_content`` (content
    model built from ``wavLM_model``), and the target speaker is represented
    either by a speaker-encoder embedding (when ``hps.model.use_spk`` is set)
    or by a mel spectrogram of the target clip.
    """

    def __init__(self, config, ptfile, speaker_model, wavLM_model, device='cpu',
                 output_sampling_rate=24000):
        """Load the synthesizer, the content model and (optionally) the speaker encoder.

        Args:
            config: path to the FreeVC hyper-parameter file.
            ptfile: path to the synthesizer checkpoint.
            speaker_model: path to the speaker-encoder weights; only loaded
                when the config's ``model.use_spk`` is set.
            wavLM_model: checkpoint passed to ``utils.get_cmodel`` for the
                content model.
            device: device used for the models and for every tensor created
                in ``convert`` (e.g. ``'cpu'``, ``'cuda'``).
            output_sampling_rate: sample rate written into the converted wav.
                Defaults to 24000, the value previously hard-coded here —
                presumably matching a 24 kHz FreeVC variant; confirm against
                ``config`` when swapping checkpoints.
        """
        # One device for everything: previously the synthesizer used a global
        # device while inputs used ``device.type``, which can disagree
        # (e.g. 'cuda:1' vs 'cuda').
        self.device = device
        self.output_sampling_rate = output_sampling_rate
        self.hps = utils.get_hparams_from_file(config)

        self.net_g = freeTRN(
            self.hps.data.filter_length // 2 + 1,
            self.hps.train.segment_size // self.hps.data.hop_length,
            **self.hps.model,
        ).to(self.device)
        self.net_g.eval()
        utils.load_checkpoint(ptfile, self.net_g, None, True)

        self.cmodel = utils.get_cmodel(self.device, wavLM_model)

        if self.hps.model.use_spk:
            self.smodel = SpeakerEncoder(speaker_model)

    def convert(self, src, tgt):
        """Convert the speech in ``src`` to the voice of the speaker in ``tgt``.

        Args:
            src: ``(sample_rate, samples)`` of the utterance whose content is kept.
            tgt: ``(sample_rate, samples)`` of the reference speaker.

        Returns:
            Path of the wav holding the converted audio
            (``{constants.temp_audio_folder}/cnvr.wav``, overwritten on every call).
        """
        fs_src, src_audio = src
        fs_tgt, tgt_audio = tgt

        src_path = f"{constants.temp_audio_folder}/src.wav"
        tgt_path = f"{constants.temp_audio_folder}/tgt.wav"
        out_path = f"{constants.temp_audio_folder}/cnvr.wav"
        model_sr = self.hps.data.sampling_rate

        with torch.no_grad():
            # Target speaker: round-trip through a wav so librosa can resample
            # it to the model rate, then trim quiet leading/trailing audio.
            wavfile.write(tgt_path, fs_tgt, tgt_audio)
            wav_tgt, _ = librosa.load(tgt_path, sr=model_sr)
            wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
            if self.hps.model.use_spk:
                g_tgt = self.smodel.embed_utterance(wav_tgt)
                g_tgt = torch.from_numpy(g_tgt).unsqueeze(0).to(self.device)
            else:
                wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(self.device)
                mel_tgt = mel_spectrogram_torch(
                    wav_tgt,
                    self.hps.data.filter_length,
                    self.hps.data.n_mel_channels,
                    model_sr,
                    self.hps.data.hop_length,
                    self.hps.data.win_length,
                    self.hps.data.mel_fmin,
                    self.hps.data.mel_fmax,
                )

            # Source utterance: resample and extract content features.
            wavfile.write(src_path, fs_src, src_audio)
            wav_src, _ = librosa.load(src_path, sr=model_sr)
            wav_src = torch.from_numpy(wav_src).unsqueeze(0).to(self.device)
            c = utils.get_content(self.cmodel, wav_src)

            if self.hps.model.use_spk:
                audio = self.net_g.infer(c, g=g_tgt)
            else:
                audio = self.net_g.infer(c, mel=mel_tgt)
            audio = audio[0][0].data.cpu().float().numpy()
            write(out_path, self.output_sampling_rate, audio)

        return out_path
|
|
|
|
|
# Shared configuration objects. Code above (e.g. FreeVCModel) reads module
# globals such as `constants` at call time, so these must exist before the
# models below are instantiated.
constants = constantConfig()
hyperparameters = hyperparameterConfig()

# Input-speech enhancement model and its state; the third value returned by
# init_df is not needed here.
model, df_state, _ = init_df(hyperparameters.voice_enhacing_model, config_allow_defaults=True)

# Speech-to-text model (faster-whisper).
stt_model = WhisperModel(
    hyperparameters.stt_model,
    device=hyperparameters.device.type,
    compute_type="float32",
)

# Translation model and tokenizer, both loaded from the same checkpoint name.
_nllb_checkpoint = constants.model_name_dict[hyperparameters.nllb_model]
trans_model = AutoModelForSeq2SeqLM.from_pretrained(
    _nllb_checkpoint,
    torch_dtype=torch.bfloat16,
).to(hyperparameters.device)
trans_tokenizer = AutoTokenizer.from_pretrained(_nllb_checkpoint)

# Voice-conversion model used by translation().
modelConvertSpeech = FreeVCModel(
    config=hyperparameters.text2speech_config,
    ptfile=hyperparameters.text2speech_model,
    speaker_model=hyperparameters.text2speech_encoder,
    wavLM_model=hyperparameters.wavlm_model,
    device=hyperparameters.device.type,
)
|
|
|
|
|
|
|
def download(lang, lang_directory):
    """Ensure the TTS checkpoint folder for ``lang`` exists and return its path.

    When ``{lang_directory}/{lang}`` is missing, the archive
    ``{constants.language_download_web}/{lang}.tar.gz`` is fetched with
    ``wget`` and unpacked into ``lang_directory``. The archive is removed
    afterwards, also when a step fails.

    Args:
        lang: language code, used both as archive name and folder name.
        lang_directory: directory that holds the per-language folders.

    Returns:
        The path ``f"{lang_directory}/{lang}"``.

    Raises:
        subprocess.CalledProcessError: if the download or the extraction fails.
    """
    lang_path = f"{lang_directory}/{lang}"
    if os.path.exists(lang_path):
        return lang_path

    os.makedirs(lang_directory, exist_ok=True)
    archive = f"{lang_directory}/{lang}.tar.gz"
    try:
        # Argument lists (no shell) so values are never interpreted by a shell,
        # and each step is checked on its own: extraction only runs after a
        # successful download (the old ';'-joined command ran tar regardless).
        subprocess.check_output(
            ["wget", f"{constants.language_download_web}/{lang}.tar.gz", "-O", archive]
        )
        subprocess.check_output(["tar", "zxvf", archive, "-C", lang_directory])
    finally:
        # Best-effort cleanup so a truncated archive is not left behind.
        try:
            os.remove(archive)
        except OSError:
            pass
    return lang_path
|
|
|
def preprocess_char(text, lang=None):
    """Apply language-specific character substitutions.

    For ``lang == 'ron'`` every "ț" is replaced by "ţ"; text for any other
    language (or ``None``) is returned unchanged.
    """
    if lang != 'ron':
        return text
    return text.replace("ț", "ţ")
|
|
|
def preprocess_text(txt, text_mapper, hps, uroman_dir=None, lang=None):
    """Normalise ``txt`` into the character set of a TTS model.

    Steps, in order:
      1. language-specific character fixes (``preprocess_char``);
      2. romanization with ``{uroman_dir}/bin/uroman.pl`` when the last
         dot-separated component of ``hps.data.training_files`` is ``uroman``;
      3. lower-casing;
      4. dropping characters not in the model vocabulary (``text_mapper.filter_oov``).
    """
    cleaned = preprocess_char(txt, lang=lang)
    needs_uroman = hps.data.training_files.rsplit('.', 1)[-1] == 'uroman'
    if needs_uroman:
        cleaned = text_mapper.uromanize(cleaned, f'{uroman_dir}/bin/uroman.pl')
    return text_mapper.filter_oov(cleaned.lower())
|
|
|
def detect_language(text, LID):
    """Return the top language label predicted by ``LID`` for ``text``.

    ``LID.predict`` must return a sequence whose first item is the list of
    labels (presumably fastText-style, e.g. ``"__label__eng"``); the
    ``"__label__"`` marker is stripped from the first label.
    """
    top_labels = LID.predict(text)[0]
    best_label = top_labels[0]
    return best_label.replace("__label__", "")
|
|
|
|
|
class TextMapper(object):
    """Map text to the integer symbol ids of a TTS vocabulary.

    The vocabulary file holds one symbol per line; a symbol's id is its
    (0-based) line index.
    """

    def __init__(self, vocab_file):
        # Context manager so the vocabulary file handle is closed promptly.
        with open(vocab_file, encoding="utf-8") as fh:
            self.symbols = [line.replace("\n", "") for line in fh]
        # The vocabulary must contain a space symbol (ValueError otherwise).
        self.SPACE_ID = self.symbols.index(" ")
        self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
        self._id_to_symbol = {i: s for i, s in enumerate(self.symbols)}

    def text_to_sequence(self, text, cleaner_names):
        """Convert ``text`` to the list of symbol ids of its characters.

        Args:
            text: string to convert; leading/trailing whitespace is stripped.
            cleaner_names: accepted for interface compatibility; no cleaner
                functions are applied.

        Returns:
            List of integer ids, one per character.

        Raises:
            KeyError: if a character is not in the vocabulary (callers are
                expected to run ``filter_oov`` first).
        """
        return [self._symbol_to_id[symbol] for symbol in text.strip()]

    def uromanize(self, text, uroman_pl):
        """Romanize ``text`` with the uroman Perl script at ``uroman_pl``.

        Returns:
            The first output line with whitespace runs collapsed to single
            spaces and stripped; ``""`` if the script printed nothing.

        Raises:
            subprocess.CalledProcessError: if ``perl`` exits with a non-zero status.
        """
        # Pipe the text through stdin/stdout as explicit UTF-8 instead of
        # shell redirection into temp files written in the locale encoding,
        # and check the exit status instead of ignoring it.
        result = subprocess.run(
            ["perl", uroman_pl, "-l", "xxx"],
            input=text,
            stdout=subprocess.PIPE,
            encoding="utf-8",
            check=True,
        )
        first_line = result.stdout.split("\n", 1)[0]
        return re.sub(r"\s+", " ", first_line).strip()

    def get_text(self, text, hps):
        """Return the id sequence of ``text`` as a 1-D ``torch.LongTensor``.

        When ``hps.data.add_blank`` is set, the ids are passed through
        ``commons.intersperse(..., 0)`` (presumably inserting blank id 0
        between symbols, as in VITS — confirm in ``nnet.commons``).
        """
        ids = self.text_to_sequence(text, hps.data.text_cleaners)
        if hps.data.add_blank:
            ids = commons.intersperse(ids, 0)
        return torch.LongTensor(ids)

    def filter_oov(self, text):
        """Drop every character of ``text`` that is not in the vocabulary."""
        return "".join(ch for ch in text if ch in self._symbol_to_id)
|
|
|
def speech_to_text(audio_file):
    """Transcribe a recorded clip and pre-select its language in the UI.

    Args:
        audio_file: ``(sample_rate, samples)`` from the microphone component,
            or ``None`` when the recording is cleared.

    Returns:
        ``(transcript, dropdown_update)``. When the language cannot be matched,
        the dropdown is set to ``'language cant be determined, select manually'``.
        On any failure the transcript is ``''`` and the dropdown is reset to
        ``'English'``; the error is logged instead of being silently dropped.
    """
    import logging

    if audio_file is None:
        # The change event also fires when the recording is cleared.
        return '', gr.Dropdown.update(value='English')

    try:
        fs, audio = audio_file
        wavfile.write(constants.input_speech_file, fs, audio)

        # Run speech enhancement on the recording before transcription.
        audio0, _ = load_audio(constants.input_speech_file, sr=df_state.sr())
        enhanced = enhance(model, df_state, audio0)
        save_audio(constants.enhanced_speech_file, enhanced, df_state.sr())

        segments, info = stt_model.transcribe(constants.enhanced_speech_file)
        speech_text = ''.join(segment.text for segment in segments)

        # Match the detected language against the first two letters of each
        # TTS language code — a heuristic, presumably comparing a 2-letter
        # Whisper code with 3-letter TTS codes.
        source_lang_nllb = next(
            (
                name
                for name, tts_code in constants.flores_codes_to_tts_codes.items()
                if tts_code[:2] == info.language
            ),
            'language cant be determined, select manually',
        )
        return speech_text, gr.Dropdown.update(value=source_lang_nllb)
    except Exception:
        # UI boundary: keep the app responsive, but record what went wrong.
        logging.getLogger(__name__).exception("speech_to_text failed")
        return '', gr.Dropdown.update(value='English')
|
|
|
|
|
# Most recently used TTS model, keyed by TTS language code. Only one entry is
# kept so memory stays bounded, while repeated requests for the same target
# language skip rebuilding the network and reloading its checkpoint.
_TTS_MODEL_CACHE = {}


def _load_tts_model(lang_code):
    """Return ``(hps, text_mapper, net_g)`` for TTS language ``lang_code``, loading it on a cache miss."""
    cached = _TTS_MODEL_CACHE.get(lang_code)
    if cached is not None:
        return cached

    ckpt_dir = download(lang_code, lang_directory=constants.language_directory)
    vocab_file = f"{ckpt_dir}/{constants.language_vocab_text}"
    config_file = f"{ckpt_dir}/{constants.language_vocab_configuration}"

    hps = utils.get_hparams_from_file(config_file)
    text_mapper = TextMapper(vocab_file)
    net_g = vitsTRN(
        len(text_mapper.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model,
    )
    net_g.to(hyperparameters.device)
    net_g.eval()
    utils.load_checkpoint(f"{ckpt_dir}/{constants.language_vocab_model}", net_g, None)

    entry = (hps, text_mapper, net_g)
    _TTS_MODEL_CACHE.clear()  # evict the previous language's model
    _TTS_MODEL_CACHE[lang_code] = entry
    return entry


def text_to_speech(text, target_lang):
    """Synthesize ``text`` with the VITS voice for ``target_lang``.

    Args:
        text: text to speak (already in the target language).
        target_lang: language name as shown in the UI (a key of
            ``constants.flores_codes_to_tts_codes``).

    Returns:
        ``(sampling_rate, waveform)`` with the waveform as a float numpy array.
    """
    lang_code = constants.flores_codes_to_tts_codes[target_lang]
    hps, text_mapper, net_g = _load_tts_model(lang_code)

    txt = preprocess_text(text, text_mapper, hps, lang=lang_code,
                          uroman_dir=constants.uroman_directory)
    stn_tst = text_mapper.get_text(txt, hps)
    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0).to(hyperparameters.device)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(hyperparameters.device)
        hyp = net_g.infer(
            x_tst, x_tst_lengths, noise_scale=.667,
            noise_scale_w=0.8, length_scale=1.0,
        )[0][0, 0].cpu().float().numpy()

    return hps.data.sampling_rate, hyp
|
|
|
def translation(audio, text, source_lang_nllb, target_code_nllb, output_type, sentence_mode):
    """Translate ``text`` and synthesize the translation as speech.

    Args:
        audio: the recorded input clip ``(sample_rate, samples)`` or ``None``;
            it is the reference voice for voice transfer.
        text: text to translate.
        source_lang_nllb: source language name as shown in the UI (a key of
            ``constants.flores_codes``); a value that is not a key is passed
            to the pipeline unchanged.
        target_code_nllb: target language name (a key of ``constants.flores_codes``).
        output_type: ``'voice transfer'`` (or the legacy ``'own voice'``)
            converts the synthesized speech to the recorded speaker's voice;
            any other value uses the standard TTS speaker.
        sentence_mode: ``"Sentence-wise"`` translates sentence by sentence;
            any other value translates the whole text in one call.

    Returns:
        ``(translated_text, wav_path)``, or ``('', None)`` for blank input.
    """
    if not text or not text.strip():
        # Nothing to translate; synthesizing an empty string would fail.
        return '', None

    # The UI passes language *names*; the pipeline needs the codes stored as
    # values of constants.flores_codes — for the source as well as the target.
    source_code = constants.flores_codes.get(source_lang_nllb, source_lang_nllb)
    target_code = constants.flores_codes[target_code_nllb]
    translator = pipeline('translation', model=trans_model, tokenizer=trans_tokenizer,
                          src_lang=source_code, tgt_lang=target_code,
                          device=hyperparameters.device)

    if sentence_mode == "Sentence-wise":
        output = ' '.join(
            translator(sentence, max_length=400)[0]['translation_text']
            for sentence in sent_tokenize(text)
        )
    else:
        output = translator(text, max_length=1024)[0]['translation_text']

    fs_out, audio_out = text_to_speech(output, target_code_nllb)

    # The UI offers 'voice transfer'; the old check for 'own voice' never
    # matched, so voice transfer was unreachable. A reference recording is
    # required — without one, fall back to the standard speaker.
    if output_type in ('voice transfer', 'own voice') and audio is not None:
        out_file = modelConvertSpeech.convert((fs_out, audio_out), audio)
        return output, out_file

    wavfile.write(constants.text2speech_wavfile, fs_out, audio_out)
    return output, constants.text2speech_wavfile
|
|
|
# Gradio UI. Components are laid out in the order they are created inside the
# `with` blocks, so the statement order below is significant.
# NOTE(review): uses Gradio 3.x-style API (`source=` on gr.Audio,
# gr.Dropdown.update in the handlers) — confirm against the pinned Gradio version.
with gr.Blocks(title = "Octopus Translation App") as octopus_translator:
    # Microphone recording: transcribed by speech_to_text and passed to translation().
    with gr.Row():
        audio_file = gr.Audio(source="microphone")

    # Source text and source language; both are filled by speech_to_text.
    with gr.Row():
        input_text = gr.Textbox(label="Input text")
        source_language = gr.Dropdown(list(constants.flores_codes.keys()), value='English', label='Source (Autoselected)', interactive=True)

    # Translation result and target language.
    with gr.Row():
        output_text = gr.Textbox(label='Translated text')
        target_language = gr.Dropdown(list(constants.flores_codes.keys()), value='German', label='Target', interactive=True)

    with gr.Row():
        output_speech = gr.Audio(label='Translated speech')
        translate_button = gr.Button('Translate')

    # Options.
    with gr.Row():
        # NOTE(review): `enhance_audio` is not passed to any event handler
        # below, so this radio currently has no effect.
        enhance_audio = gr.Radio(['yes', 'no'], value='yes', label='Enhance input voice', interactive=True)
        # Passed to translation() as `sentence_mode`.
        input_type = gr.Radio(['Whole text', 'Sentence-wise'],value='Sentence-wise', label="Translation Mode", interactive=True)
        # Passed to translation() as `output_type`.
        output_audio_type = gr.Radio(['standard speaker', 'voice transfer'], value='voice transfer', label='Enhance output voice', interactive=True)

    # A new or changed recording is transcribed and its language pre-selected.
    audio_file.change(speech_to_text,
                      inputs=[audio_file],
                      outputs=[input_text, source_language])

    # Input order matches translation(audio, text, source_lang_nllb,
    # target_code_nllb, output_type, sentence_mode).
    translate_button.click(translation,
                           inputs=[audio_file, input_text,
                                   source_language, target_language,
                                   output_audio_type, input_type],
                           outputs=[output_text, output_speech])

# Starts the app; this runs at import time because it is not behind an
# `if __name__ == "__main__":` guard.
octopus_translator.launch(share=False)
|
|