Spaces:
Build error
Build error
""" | |
TODO: | |
+ [x] Load Configuration | |
+ [ ] Checking | |
+ [ ] Better saving directory | |
""" | |
import numpy as np | |
from pathlib import Path | |
import torch.nn as nn | |
import torch | |
import torchaudio | |
from transformers import pipeline | |
from pathlib import Path | |
import pdb | |
# local import | |
import sys | |
from espnet2.bin.tts_inference import Text2Speech | |
from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC | |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
sys.path.append("src") | |
import gradio as gr | |
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq | |
# Fine-tuned Whisper checkpoint used for speech recognition.
ASR_MODEL_ID = "KevinGeng/whipser_medium_en_PAL300_step25"

# Load the processor and model once and hand the same objects to the pipeline.
# Passing the repo id to `pipeline` again would instantiate a second copy of
# the weights in memory.
processor = AutoProcessor.from_pretrained(ASR_MODEL_ID)
model = AutoModelForSpeechSeq2Seq.from_pretrained(ASR_MODEL_ID)
transcriber = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
)
# Text2Mel models
# @title English multi-speaker pretrained model { run: "auto" }
lang = "English"
vits_tag = "kan-bayashi/libritts_xvector_vits"
ft2_tag = "kan-bayashi/libritts_xvector_conformer_fastspeech2"
transformer_tag = "kan-bayashi/libritts_xvector_transformer"
# !!! vits needs no vocoder !!!

# Local Text2Mel models
vits_config_local = "TTS_models/libritts_xvector_vits/config.yaml"
vits_model_local = "TTS_models/libritts_xvector_vits/train.total_count.ave_10best.pth"
# TODO: fill in local paths for the FastSpeech2 and Transformer models
ft2_config_local = ""
ft2_model_local = ""
transformer_config_local = ""
# The original assigned `transformer_config_local` twice, leaving the model
# path variable undefined; the second assignment is the model path.
transformer_model_local = ""

# Vocoders
vocoder_tag = "parallel_wavegan/vctk_parallel_wavegan.v1.long"  # @param ["none", "parallel_wavegan/vctk_parallel_wavegan.v1.long", "parallel_wavegan/vctk_multi_band_melgan.v2", "parallel_wavegan/vctk_style_melgan.v1", "parallel_wavegan/vctk_hifigan.v1", "parallel_wavegan/libritts_parallel_wavegan.v1.long", "parallel_wavegan/libritts_multi_band_melgan.v2", "parallel_wavegan/libritts_hifigan.v1", "parallel_wavegan/libritts_style_melgan.v1"] {type:"string"}
# NOTE(review): the "parallel_wavegan/" prefix appears twice in this tag, which
# matches none of the @param options; left unchanged here - confirm intent.
hifigan_vocoder_tag = "parallel_wavegan/parallel_wavegan/libritts_hifigan.v1"  # @param ["none", "parallel_wavegan/vctk_parallel_wavegan.v1.long", "parallel_wavegan/vctk_multi_band_melgan.v2", "parallel_wavegan/vctk_style_melgan.v1", "parallel_wavegan/vctk_hifigan.v1", "parallel_wavegan/libritts_parallel_wavegan.v1.long", "parallel_wavegan/libritts_multi_band_melgan.v2", "parallel_wavegan/libritts_hifigan.v1", "parallel_wavegan/libritts_style_melgan.v1"] {type:"string"}

# Local Vocoders
## Make sure to use "parallel_wavegan" as the tag prefix (PWG feature)
vocoder_tag_local = "parallel_wavegan/vctk_parallel_wavegan.v1.long"
hifigan_vocoder_tag_local = "parallel_wavegan/libritts_hifigan.v1"
from espnet2.bin.tts_inference import Text2Speech | |
from espnet2.utils.types import str_or_none | |
# local import | |
# Inference options shared by every Text2Speech instance below.
# The device follows the module-level `device` (CUDA when available, else CPU)
# instead of a hard-coded "cuda", so the app can also start on CPU-only hosts.
_tts_common_kwargs = dict(
    device=str(device),
    use_att_constraint=False,
    backward_window=1,
    forward_window=3,
    speed_control_alpha=1.0,
)

# VITS, loaded from the local config/checkpoint (no vocoder is passed).
text2speech = Text2Speech.from_pretrained(
    train_config=vits_config_local,
    model_file=vits_model_local,
    **_tts_common_kwargs,
)

# Fastspeech2 with the default local vocoder tag
ft2_text2speech = Text2Speech.from_pretrained(
    model_tag=ft2_tag,
    vocoder_tag=str_or_none(vocoder_tag_local),
    **_tts_common_kwargs,
)

# Fastspeech2 + hifigan
ft2_text2speech_hifi = Text2Speech.from_pretrained(
    model_tag=ft2_tag,
    vocoder_tag=str_or_none(hifigan_vocoder_tag_local),
    **_tts_common_kwargs,
)

# Transformer TTS with the default local vocoder tag
transformer_text2speech = Text2Speech.from_pretrained(
    model_tag=transformer_tag,
    vocoder_tag=str_or_none(vocoder_tag_local),
    **_tts_common_kwargs,
)
import glob | |
import os | |
import numpy as np | |
import kaldiio | |
# Get model directory path | |
# from espnet_model_zoo.downloader import ModelDownloader | |
# d = ModelDownloader() | |
# model_dir = os.path.dirname(d.download_and_unpack(tag)["train_config"]) | |
# Speaker x-vector selection | |
# Path of the Kaldi ark archive holding one x-vector per speaker id.
xvector_ark = "xvector/test-clean/spk_xvector.ark"
# Fail with an explicit message when the archive is missing; the original
# glob-and-index lookup raised a bare IndexError in that case.
if not Path(xvector_ark).is_file():
    raise FileNotFoundError(
        f"Speaker x-vector archive not found: {xvector_ark}"
    )

# Speaker id -> x-vector, as yielded by kaldiio.
xvectors = dict(kaldiio.load_ark(xvector_ark))

# Display name -> speaker id for the voices offered in the UI.
male_spks = {
    "Male1": "260_123286",
    "Male2": "1320_122612",
    "Male3": "672_122797",
}
female_spks = {
    "Female1": "5683_32865",
    "Female2": "121_121726",
    "Female3": "8463_287645",
}
# Combined lookup table (the original first bound `spks` to the list of all
# archive keys and immediately overwrote it; that dead assignment is dropped).
spks = {**male_spks, **female_spks}
spk_names = sorted(spks)
# Directory (kept as a string prefix so returned paths stay "./wav/...") and
# sample rate used when writing synthesized audio.
_WAV_OUTPUT_DIR = "./wav/"
_WAV_SAMPLE_RATE = 22050


def _asr_then_tts(tts, audio_file, spk_name, infix, ref_text=""):
    """Recognize `audio_file`, re-synthesize it with `tts`, and save the result.

    Args:
        tts: Callable Text2Speech instance; called with `text`, `speech` and
            `spembs` keyword arguments and indexed with "wav".
        audio_file: Path of the input audio file.
        spk_name: Key into the module-level `spks` table.
        infix: Text placed between the input file stem and the speaker name
            in the output file name (e.g. "_" or "_fs2_").
        ref_text: Text to synthesize; when empty, the transcription produced
            by the module-level `transcriber` is used instead.

    Returns:
        Tuple `(save_id, reg_text)`: the path of the written wav file and the
        text that was synthesized.
    """
    spembs = xvectors[spks[spk_name]]
    reg_text = ref_text if ref_text else transcriber(audio_file)["text"]
    # The loaded sample rate is not used; only the waveform is forwarded.
    speech, _ = torchaudio.load(audio_file, channels_first=True)
    wav = tts(text=reg_text, speech=speech, spembs=spembs)["wav"]

    # Make sure the output directory exists before saving.
    Path(_WAV_OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
    save_id = (
        _WAV_OUTPUT_DIR + Path(audio_file).stem + infix + spk_name + "_spkembs.wav"
    )
    torchaudio.save(
        save_id,
        src=wav.unsqueeze(0).to("cpu"),
        sample_rate=_WAV_SAMPLE_RATE,
    )
    return save_id, reg_text


def ASRTTS(audio_file, spk_name, ref_text=""):
    """Convert `audio_file` to the chosen voice with the VITS model.

    Uses `ref_text` as the text to synthesize when given, otherwise the ASR
    transcription. Returns `(saved_wav_path, synthesized_text)`.
    """
    return _asr_then_tts(text2speech, audio_file, spk_name, "_", ref_text)


def ASRTTS_clean(audio_file, spk_name):
    """Convert `audio_file` with the VITS model; return the saved wav path."""
    save_id, _ = _asr_then_tts(text2speech, audio_file, spk_name, "_")
    return save_id


def ft2_ASRTTS_clean(audio_file, spk_name):
    """Convert `audio_file` with `ft2_text2speech`; return the saved wav path."""
    save_id, _ = _asr_then_tts(ft2_text2speech, audio_file, spk_name, "_fs2_")
    return save_id


def ft2_ASRTTS_clean_hifi(audio_file, spk_name):
    """Convert `audio_file` with `ft2_text2speech_hifi`; return the saved wav path."""
    save_id, _ = _asr_then_tts(
        ft2_text2speech_hifi, audio_file, spk_name, "_fs2_hifi_"
    )
    return save_id


def transformer_ASRTTS_clean(audio_file, spk_name):
    """Convert `audio_file` with `transformer_text2speech`; return the saved wav path."""
    save_id, _ = _asr_then_tts(
        transformer_text2speech, audio_file, spk_name, "_transformer_"
    )
    return save_id
# def google_ASRTTS_clean(audio_file, spk_name): | |
# spk = spks[spk_name] | |
# spembs = xvectors[spk] | |
# reg_text = transcriber(audio_file)["text"] | |
# # pdb.set_trace() | |
# synthesis_input = texttospeech.SynthesisInput(text=reg_text) | |
# voice = texttospeech.VoiceSelectionParams( | |
# language_code="en-US", ssml_gender=texttospeech.SsmlVoiceGender.NEUTRAL | |
# ) | |
# audio_config = texttospeech.AudioConfig( | |
# audio_encoding=texttospeech.AudioEncoding.MP3 | |
# ) | |
# response = Google_TTS_client.synthesize_speech( | |
# input=synthesis_input, voice=voice, audio_config=audio_config | |
# ) | |
# save_id = ( | |
# "./wav/" + Path(audio_file).stem + "_google_" + spk_name + "_spkembs.wav" | |
# ) | |
# with open(save_id, "wb") as out_file: | |
# out_file.write(response.audio_content) | |
# return save_id | |
# Stand-alone Gradio components and example rows defined before the Blocks UI.
reference_textbox = gr.Textbox(
    value="",
    placeholder="Input reference here",
    label="Reference",
)
recognization_textbox = gr.Textbox(
    value="",
    placeholder="Output recognization here",
    label="recognization_textbox",
)
speaker_option = gr.Radio(choices=spk_names, label="Speaker")
input_audio = gr.Audio(
    source="upload", type="filepath", label="Audio_to_Evaluate"
)
# `type="filepath"` replaces the original `file="filepath"`, which is not a
# gr.Audio parameter.
output_audio = gr.Audio(
    source="upload", type="filepath", label="Synthesized Audio"
)
# Example rows: [audio path, speaker name, reference text]. Speaker names must
# be keys of `spks` ("Male1", "Female1", ...); the original "M1"/"F1" were not.
examples = [
    ["./samples/001.wav", "Male1", ""],
    ["./samples/002.wav", "Male2", ""],
    ["./samples/003.wav", "Female1", ""],
    ["./samples/004.wav", "Female2", ""],
]
def change_audiobox(choice):
    """Return the gr.Audio update matching the selected input format.

    "upload" and "microphone" show the audio box with that source; any other
    value hides it.
    """
    if choice not in ("upload", "microphone"):
        return gr.Audio.update(visible=False)
    return gr.Audio.update(source=choice, visible=True)
def show_icon(choice):
    """Return the gr.Image update showing the icon of the selected speaker.

    Args:
        choice: Speaker display name, e.g. "Male1" or "Female3".

    Returns:
        An update that displays the matching file under "speaker_icons/", or
        one that hides the image when `choice` is not a known speaker (the
        original raised UnboundLocalError in that case).
    """
    icon_by_speaker = {
        "Male1": "speaker_icons/male1.png",
        "Male2": "speaker_icons/male2.png",
        "Male3": "speaker_icons/male3.png",
        "Female1": "speaker_icons/female1.png",
        "Female2": "speaker_icons/female2.png",
        "Female3": "speaker_icons/female3.png",
    }
    icon_path = icon_by_speaker.get(choice)
    if icon_path is None:
        return gr.Image.update(visible=False)
    return gr.Image.update(value=icon_path, visible=True)
def get_download_file(audio_file=None):
    """Return a gr.File update that is visible only when `audio_file` is given.

    Args:
        audio_file: Any value, or None when there is nothing to download.
    """
    # Identity check instead of `== None`, and a single call for both cases.
    return gr.File.update(visible=audio_file is not None)
def download_file(audio_file):
    """Wrap `audio_file` in a new gr.File component and return it."""
    file_component = gr.File(value=audio_file)
    return file_component
# pdb.set_trace() | |
def _build_input_panel():
    """Create the input widgets shared by both tabs inside the current layout.

    Builds the input-format radio, the (initially hidden) input audio box, the
    speaker radio and the speaker icon, and wires their change handlers.

    Returns:
        Tuple `(input_format, input_audio, speaker_option, spk_icon)`.
    """
    input_format = gr.Radio(
        choices=["microphone", "upload"],
        label="Choose your input format",
        elem_id="input_format",
    )
    input_audio = gr.Audio(
        source="microphone",
        type="filepath",
        label="Input Audio",
        interactive=True,
        visible=False,
        elem_id="input_audio",
    )
    input_format.change(
        fn=change_audiobox, inputs=input_format, outputs=input_audio
    )
    speaker_option = gr.Radio(
        choices=spk_names, value="Male1", label="Choose your voice profile"
    )
    spk_icon = gr.Image(
        value="speaker_icons/male1.png",
        type="filepath",
        image_mode="RGB",
        source="upload",
        shape=[50, 50],
        interactive=True,
        visible=True,
    )
    speaker_option.change(
        fn=show_icon, inputs=speaker_option, outputs=spk_icon
    )
    return input_format, input_audio, speaker_option, spk_icon


def _add_converter(button_label, convert_fn, inputs, api_name):
    """Add a button plus output audio box that runs `convert_fn` on click.

    Returns:
        Tuple `(button, converted_audio)`.
    """
    button = gr.Button(button_label)
    # `type="filepath"` replaces the original `file="filepath"`, which is not
    # a gr.Audio parameter.
    converted_audio = gr.Audio(
        source="upload",
        type="filepath",
        label="Converted Audio",
        interactive=False,
    )
    button.click(
        convert_fn,
        inputs=inputs,
        outputs=converted_audio,
        api_name=api_name,
    )
    return button, converted_audio


# NOTE(review): the source lost its indentation, so the layout nesting below
# is reconstructed from statement order - confirm against the intended UI.
with gr.Blocks(
    analytics_enabled=False,
    css=".gradio-container {background-color: #78BD91}",
) as demo:
    # Public Version
    with gr.Tab("Open Version"):
        with gr.Column(elem_id="Column"):
            input_format, input_audio, speaker_option, spk_icon = (
                _build_input_panel()
            )
            b, output_audio = _add_converter(
                "Convert",
                ASRTTS_clean,
                [input_audio, speaker_option],
                "convert",
            )
    # Tab selection:
    with gr.Tab("Test Version: Multi TTS model"):
        with gr.Column(elem_id="Column"):
            input_format, input_audio, speaker_option, spk_icon = (
                _build_input_panel()
            )
            with gr.Column():
                with gr.Row():
                    b2, output_audio = _add_converter(
                        "Convert",
                        ASRTTS_clean,
                        [input_audio, speaker_option],
                        "convert_",
                    )
                with gr.Row():
                    # Fastspeech2 + PWG [under construction]
                    b_ft2, output_audio_ft2 = _add_converter(
                        "Convert_fastspeech2",
                        ft2_ASRTTS_clean,
                        [input_audio, speaker_option],
                        "convert_ft2",
                    )
                with gr.Row():
                    # Fastspeech2 + hifigan [under construction]
                    b_ft2_hifi, output_audio_ft2_hifi = _add_converter(
                        "Convert_fastspeech2+HifiGAN",
                        ft2_ASRTTS_clean_hifi,
                        [input_audio, speaker_option],
                        "convert_ft2_hifi",
                    )
                with gr.Row():
                    # transformer [TODO]
                    b_transformer, output_audio_transformer = _add_converter(
                        "Convert_transformer",
                        transformer_ASRTTS_clean,
                        [input_audio, speaker_option],
                        "convert_trans",
                    )
                # google tts [TODO]
                # b_google = gr.Button("Convert_googleTTS")
                # output_audio_google= gr.Audio(
                # source="upload", file="filepath", label="Converted Audio", interactive=False
                # )
                # b_google.click(
                # google_ASRTTS_clean,
                # inputs=[input_audio, speaker_option],
                # outputs=output_audio_google,
                # api_name="convert"
                # )
demo.launch(share=False)