# Source snapshot: commit 5b6add3 "fix wavfile name" (author: KevinGeng).
"""
TODO:
+ [x] Load Configuration
+ [ ] Checking
+ [ ] Better saving directory
"""
import numpy as np
from pathlib import Path
import torch.nn as nn
import torch
import torchaudio
from transformers import pipeline
from pathlib import Path
import datetime
import pdb
# local import
import sys
from espnet2.bin.tts_inference import Text2Speech
from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC
# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Make project-local modules under ./src importable.
sys.path.append("src")
import gradio as gr
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
# Fine-tuned speech-to-text checkpoint used for recognition.
# (The "whipser" spelling is the actual Hub repo id; do not "fix" it.)
ASR_MODEL_ID = "KevinGeng/whipser_medium_en_PAL300_step25"
processor = AutoProcessor.from_pretrained(ASR_MODEL_ID)
model = AutoModelForSpeechSeq2Seq.from_pretrained(ASR_MODEL_ID)
# Reuse the model/processor loaded above instead of letting pipeline() load a
# second copy of the same checkpoint, which would double memory use.
# NOTE(review): no `device` is passed, so ASR runs on the pipeline's default
# device, as before. Consider device=device if GPU memory allows.
transcriber = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
)
# Text2Mel models
# @title English multi-speaker pretrained model { run: "auto" }
lang = "English"
vits_tag = "kan-bayashi/libritts_xvector_vits"
ft2_tag = "kan-bayashi/libritts_xvector_conformer_fastspeech2"
transformer_tag = "kan-bayashi/libritts_xvector_transformer"
# !!! vits needs no vocoder !!!
# Local Text2Mel models
vits_config_local = "TTS_models/libritts_xvector_vits/config.yaml"
vits_model_local = "TTS_models/libritts_xvector_vits/train.total_count.ave_10best.pth"
# TODO: fill in local checkpoints for FastSpeech2 and Transformer.
ft2_config_local = ""
ft2_model_local = ""
transformer_config_local = ""
# Model-file path counterpart of transformer_config_local (previously the
# config variable was mistakenly assigned twice instead).
transformer_model_local = ""
# Vocoders
vocoder_tag = "parallel_wavegan/vctk_parallel_wavegan.v1.long" # @param ["none", "parallel_wavegan/vctk_parallel_wavegan.v1.long", "parallel_wavegan/vctk_multi_band_melgan.v2", "parallel_wavegan/vctk_style_melgan.v1", "parallel_wavegan/vctk_hifigan.v1", "parallel_wavegan/libritts_parallel_wavegan.v1.long", "parallel_wavegan/libritts_multi_band_melgan.v2", "parallel_wavegan/libritts_hifigan.v1", "parallel_wavegan/libritts_style_melgan.v1"] {type:"string"}
# Single "parallel_wavegan/" prefix, matching the entries in the @param list.
hifigan_vocoder_tag = "parallel_wavegan/libritts_hifigan.v1" # @param ["none", "parallel_wavegan/vctk_parallel_wavegan.v1.long", "parallel_wavegan/vctk_multi_band_melgan.v2", "parallel_wavegan/vctk_style_melgan.v1", "parallel_wavegan/vctk_hifigan.v1", "parallel_wavegan/libritts_parallel_wavegan.v1.long", "parallel_wavegan/libritts_multi_band_melgan.v2", "parallel_wavegan/libritts_hifigan.v1", "parallel_wavegan/libritts_style_melgan.v1"] {type:"string"}
# Local Vocoders
## Make sure to use "parallel_wavegan/" as the prefix (PWG feature)
vocoder_tag_local = "parallel_wavegan/vctk_parallel_wavegan.v1.long"
hifigan_vocoder_tag_local = "parallel_wavegan/libritts_hifigan.v1"
from espnet2.bin.tts_inference import Text2Speech
from espnet2.utils.types import str_or_none
# local import
# VITS text-to-speech model loaded from the local config/checkpoint paths.
# VITS needs no separate vocoder, so none is configured here.
text2speech = Text2Speech.from_pretrained(
    train_config=vits_config_local,
    model_file=vits_model_local,
    # Follow the device detected at start-up ("cuda" or "cpu") instead of
    # hard-coding "cuda", so the app can also start on CPU-only machines.
    device=device.type,
    use_att_constraint=False,
    backward_window=1,
    forward_window=3,
    speed_control_alpha=1.0,
)
# # Fastspeech2
# ft2_text2speech = Text2Speech.from_pretrained(
# model_tag=ft2_tag,
# vocoder_tag=str_or_none(vocoder_tag_local),
# device="cuda",
# use_att_constraint=False,
# backward_window=1,
# forward_window=3,
# speed_control_alpha=1.0,
# )
# # Fastspeech2 + hifigan
# ft2_text2speech_hifi = Text2Speech.from_pretrained(
# model_tag=ft2_tag,
# vocoder_tag=str_or_none(hifigan_vocoder_tag_local),
# device="cuda",
# use_att_constraint=False,
# backward_window=1,
# forward_window=3,
# speed_control_alpha=1.0,
# )
# # transformer tag
# transformer_text2speech = Text2Speech.from_pretrained(
# model_tag=transformer_tag,
# vocoder_tag=str_or_none(vocoder_tag_local),
# device="cuda",
# use_att_constraint=False,
# backward_window=1,
# forward_window=3,
# speed_control_alpha=1.0,
# )
import glob
import os
import numpy as np
import kaldiio
# Get model directory path
# from espnet_model_zoo.downloader import ModelDownloader
# d = ModelDownloader()
# model_dir = os.path.dirname(d.download_and_unpack(tag)["train_config"])
# Speaker x-vector selection
# Speaker x-vector selection: load precomputed speaker embeddings from a Kaldi
# ark file. Entries are looked up by the speaker ids stored in `spks`.
xvector_ark = "xvector/test-clean/spk_xvector.ark"
if not os.path.isfile(xvector_ark):
    # Fail with an actionable message rather than an IndexError on an empty
    # glob result.
    raise FileNotFoundError(
        f"Speaker x-vector ark not found at {xvector_ark!r}; "
        "make sure the xvector/ directory is present."
    )
xvectors = dict(kaldiio.load_ark(xvector_ark))
# Voice-profile display name -> key into `xvectors`.
male_spks = {
    "Male1": "260_123286",
    "Male2": "1320_122612",
    "Male3": "672_122797",
}
female_spks = {
    "Female1": "5683_32865",
    "Female2": "121_121726",
    "Female3": "8463_287645",
}
# Combined table (male entries first), plus the display names in sorted order
# for the speaker selector.
spks = {**male_spks, **female_spks}
spk_names = sorted(spks)
# Folder that converted audio is written to; created on demand.
_WAV_DIR = "./wav"
# Sample rate (Hz) written into every output wav file.
# NOTE(review): hard-coded; confirm it matches the TTS model's output rate
# (e.g. `text2speech.fs`).
_TTS_SAMPLE_RATE = 22050


def _asr_tts(tts_model, audio_file, spk_name, make_save_id, ref_text=""):
    """Re-synthesize the speech in ``audio_file`` in the voice of ``spk_name``.

    This is the shared pipeline behind every ``*ASRTTS*`` entry point:
    1. Look up the target speaker's x-vector.
    2. Transcribe the input with ``transcriber``, unless ``ref_text`` is given.
    3. Synthesize that text with ``tts_model``.
    4. Write the result to a wav file.

    Args:
        tts_model: ESPnet ``Text2Speech``-style callable whose result has a
            ``"wav"`` tensor.
        audio_file: Path to the input recording.
        spk_name: Display name of the target voice; must be a key of ``spks``.
        make_save_id: Zero-argument callable returning the output path. It is
            called after synthesis, so time-based names reflect save time.
        ref_text: Text to synthesize instead of the ASR transcript. Empty,
            whitespace-only, or ``None`` means "use the ASR transcript".

    Returns:
        ``(save_id, text)``: the written wav path and the synthesized text.

    Raises:
        ValueError: if ``audio_file`` is ``None`` (no recording/upload given).
        KeyError: if ``spk_name`` is not a known speaker.
    """
    if audio_file is None:
        raise ValueError("No input audio provided; record or upload a file first.")
    spembs = xvectors[spks[spk_name]]
    if ref_text and ref_text.strip():
        text = ref_text
    else:
        text = transcriber(audio_file)["text"]
    # NOTE(review): the input waveform is forwarded as `speech`; whether the
    # TTS model actually consumes it is model-dependent. Confirm.
    speech, _ = torchaudio.load(audio_file, channels_first=True)
    wav = tts_model(text=text, speech=speech, spembs=spembs)["wav"]
    save_id = make_save_id()
    # torchaudio.save cannot create missing directories.
    Path(_WAV_DIR).mkdir(parents=True, exist_ok=True)
    # Add a leading channel axis for the (channels, time) layout expected by
    # torchaudio.save; assumes `wav` is 1-D. TODO confirm.
    torchaudio.save(
        save_id,
        src=wav.unsqueeze(0).to("cpu"),
        sample_rate=_TTS_SAMPLE_RATE,
    )
    return save_id, text


def _stem_save_id(audio_file, tag, spk_name):
    """Return ``./wav/<input file stem><tag><spk_name>_spkembs.wav``."""
    return f"{_WAV_DIR}/{Path(audio_file).stem}{tag}{spk_name}_spkembs.wav"


def _timestamp_save_id():
    """Return ``./wav/YYYYMMDD_HHMMSS.wav`` for the current local time."""
    # NOTE(review): second resolution, so two conversions saved within the
    # same second overwrite each other's file.
    return f"{_WAV_DIR}/{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.wav"


def ASRTTS(audio_file, spk_name, ref_text=""):
    """Convert ``audio_file`` to ``spk_name``'s voice with the VITS model.

    ``ref_text``, when non-blank, replaces the ASR transcript.
    Returns ``(wav_path, synthesized_text)``; the output is named after the
    input file stem and speaker.
    """
    return _asr_tts(
        text2speech,
        audio_file,
        spk_name,
        lambda: _stem_save_id(audio_file, "_", spk_name),
        ref_text=ref_text,
    )


def ASRTTS_clean(audio_file, spk_name):
    """Convert ``audio_file`` to ``spk_name``'s voice with the VITS model.

    Returns the path of the output wav, named by the current timestamp.
    """
    save_id, _ = _asr_tts(text2speech, audio_file, spk_name, _timestamp_save_id)
    return save_id


def ft2_ASRTTS_clean(audio_file, spk_name):
    """Like :func:`ASRTTS_clean`, but uses the module-level ``ft2_text2speech``.

    The output is named ``<stem>_fs2_<spk_name>_spkembs.wav``. ``ft2_text2speech``
    must be defined; its loader is currently commented out, so calling this
    raises NameError until it is enabled.
    """
    save_id, _ = _asr_tts(
        ft2_text2speech,
        audio_file,
        spk_name,
        lambda: _stem_save_id(audio_file, "_fs2_", spk_name),
    )
    return save_id


def ft2_ASRTTS_clean_hifi(audio_file, spk_name):
    """Like :func:`ASRTTS_clean`, but uses ``ft2_text2speech_hifi``.

    The output is named ``<stem>_fs2_hifi_<spk_name>_spkembs.wav``.
    ``ft2_text2speech_hifi`` must be defined; its loader is currently
    commented out, so calling this raises NameError until it is enabled.
    """
    save_id, _ = _asr_tts(
        ft2_text2speech_hifi,
        audio_file,
        spk_name,
        lambda: _stem_save_id(audio_file, "_fs2_hifi_", spk_name),
    )
    return save_id


def transformer_ASRTTS_clean(audio_file, spk_name):
    """Like :func:`ASRTTS_clean`, but uses ``transformer_text2speech``.

    The output is named ``<stem>_transformer_<spk_name>_spkembs.wav``.
    ``transformer_text2speech`` must be defined; its loader is currently
    commented out, so calling this raises NameError until it is enabled.
    """
    save_id, _ = _asr_tts(
        transformer_text2speech,
        audio_file,
        spk_name,
        lambda: _stem_save_id(audio_file, "_transformer_", spk_name),
    )
    return save_id
# def google_ASRTTS_clean(audio_file, spk_name):
# spk = spks[spk_name]
# spembs = xvectors[spk]
# reg_text = transcriber(audio_file)["text"]
# # pdb.set_trace()
# synthesis_input = texttospeech.SynthesisInput(text=reg_text)
# voice = texttospeech.VoiceSelectionParams(
# language_code="en-US", ssml_gender=texttospeech.SsmlVoiceGender.NEUTRAL
# )
# audio_config = texttospeech.AudioConfig(
# audio_encoding=texttospeech.AudioEncoding.MP3
# )
# response = Google_TTS_client.synthesize_speech(
# input=synthesis_input, voice=voice, audio_config=audio_config
# )
# save_id = (
# "./wav/" + Path(audio_file).stem + "_google_" + spk_name + "_spkembs.wav"
# )
# with open(save_id, "wb") as out_file:
# out_file.write(response.audio_content)
# return save_id
# Component instances created outside the Blocks layout.
# NOTE(review): these do not appear to be rendered; the layout below creates
# its own `speaker_option`/`input_audio`/`output_audio`. Confirm before
# removing.
reference_textbox = gr.Textbox(
    value="",
    placeholder="Input reference here",
    label="Reference",
)
recognization_textbox = gr.Textbox(
    value="",
    placeholder="Output recognization here",
    label="recognization_textbox",
)
speaker_option = gr.Radio(choices=spk_names, label="Speaker")
input_audio = gr.Audio(
    source="upload", type="filepath", label="Audio_to_Evaluate"
)
# `type` (not the unsupported `file` kwarg) selects the filepath format.
output_audio = gr.Audio(
    source="upload", type="filepath", label="Synthesized Audio"
)
# Example rows: [input wav path, speaker display name, reference text].
# Speaker names must be keys of `spks` ("Male1", "Female1", ...).
examples = [
    ["./samples/001.wav", "Male1", ""],
    ["./samples/002.wav", "Male2", ""],
    ["./samples/003.wav", "Female1", ""],
    ["./samples/004.wav", "Female2", ""],
]
def change_audiobox(choice):
    """Build the replacement audio widget for the selected input format.

    For "upload" or "microphone", return a visible ``gr.Audio`` whose
    ``source`` is that choice. For anything else, return a hidden one.
    """
    if choice not in ("upload", "microphone"):
        return gr.Audio(visible=False)
    return gr.Audio(source=choice, visible=True)
# Speaker display name -> icon image shown under the voice selector.
_SPEAKER_ICONS = {
    "Male1": "speaker_icons/male1.png",
    "Male2": "speaker_icons/male2.png",
    "Male3": "speaker_icons/male3.png",
    "Female1": "speaker_icons/female1.png",
    "Female2": "speaker_icons/female2.png",
    "Female3": "speaker_icons/female3.png",
}


def show_icon(choice):
    """Return a ``gr.Image`` update showing the icon for speaker ``choice``.

    Unknown choices (e.g. ``None`` when the selection is cleared) hide the
    image instead of failing.
    """
    icon_path = _SPEAKER_ICONS.get(choice)
    if icon_path is None:
        return gr.Image.update(visible=False)
    return gr.Image.update(value=icon_path, visible=True)
def get_download_file(audio_file=None):
    """Return a ``gr.File`` update that shows the widget only when ``audio_file`` is set.

    Args:
        audio_file: Path of the generated audio, or ``None`` if there is none.
    """
    return gr.File.update(visible=audio_file is not None)
def download_file(audio_file):
    """Wrap ``audio_file`` in a new ``gr.File`` component and return it."""
    file_component = gr.File(value=audio_file)
    return file_component
# pdb.set_trace()
with gr.Blocks(
    analytics_enabled=False,
    css=".gradio-container {background-color: #78BD91}",
) as demo:
    # Public Version
    with gr.Tab("Open Version"):
        with gr.Column(elem_id="Column"):
            # Input-format selector. The audio widget below starts hidden and
            # is swapped in by change_audiobox once a format is chosen.
            input_format = gr.Radio(
                choices=["microphone", "upload"], label="Choose your input format", elem_id="input_format"
            )
            input_audio = gr.Audio(
                source="microphone",
                type="filepath",
                label="Input Audio",
                interactive=True,
                visible=False,
                elem_id="input_audio"
            )
            input_format.change(
                fn=change_audiobox, inputs=input_format, outputs=input_audio
            )
            # Target voice selector; its icon updates via show_icon.
            speaker_option = gr.Radio(choices=spk_names, value="Male1", label="Choose your voice profile")
            spk_icon = gr.Image(value="speaker_icons/male1.png",
                                type="filepath",
                                image_mode="RGB",
                                source="upload",
                                shape=[50, 50],
                                interactive=True,
                                visible=True)
            speaker_option.change(
                fn=show_icon, inputs=speaker_option, outputs=spk_icon
            )
            b = gr.Button("Convert")
            # `type` (not the unsupported `file` kwarg) selects the filepath format.
            output_audio = gr.Audio(
                source="upload", type="filepath", label="Converted Audio", interactive=False
            )
            # Convert: ASR followed by TTS in the selected voice.
            b.click(
                ASRTTS_clean,
                inputs=[input_audio, speaker_option],
                outputs=output_audio,
                api_name="convert"
            )
# # Tab selection:
# with gr.Tab("Test Version: Multi TTS model"):
# with gr.Column(elem_id="Column"):
# input_format = gr.Radio(
# choices=["microphone", "upload"], label="Choose your input format", elem_id="input_format"
# )
# input_audio = gr.Audio(
# source="microphone",
# type="filepath",
# label="Input Audio",
# interactive=True,
# visible=False,
# elem_id="input_audio"
# )
# input_format.change(
# fn=change_audiobox, inputs=input_format, outputs=input_audio
# )
# speaker_option = gr.Radio(choices=spk_names, value="Male1", label="Choose your voice profile")
# spk_icon = gr.Image(value="speaker_icons/male1.png",
# type="filepath",
# image_mode="RGB",
# source="upload",
# shape=[50, 50],
# interactive=True,
# visible=True)
# speaker_option.change(
# fn=show_icon, inputs=speaker_option, outputs=spk_icon
# )
# with gr.Column():
# with gr.Row():
# b2 = gr.Button("Convert")
# output_audio = gr.Audio(
# source="upload", file="filepath", label="Converted Audio", interactive=False
# )
# b2.click(
# ASRTTS_clean,
# inputs=[input_audio, speaker_option],
# outputs=output_audio,
# api_name="convert_"
# )
# with gr.Row():
# # Fastspeech2 + PWG [under construction]
# b_ft2 = gr.Button("Convert_fastspeech2")
# output_audio_ft2= gr.Audio(
# source="upload", file="filepath", label="Converted Audio", interactive=False
# )
# b_ft2.click(
# ft2_ASRTTS_clean,
# inputs=[input_audio, speaker_option],
# outputs=output_audio_ft2,
# api_name="convert_ft2"
# )
# with gr.Row():
# # Fastspeech2 + hifigan [under construction]
# b_ft2_hifi = gr.Button("Convert_fastspeech2+HifiGAN")
# output_audio_ft2_hifi= gr.Audio(
# source="upload", file="filepath", label="Converted Audio", interactive=False
# )
# b_ft2_hifi.click(
# ft2_ASRTTS_clean_hifi,
# inputs=[input_audio, speaker_option],
# outputs=output_audio_ft2_hifi,
# api_name="convert_ft2_hifi"
# )
# with gr.Row():
# # transformer [TODO]
# b_transformer = gr.Button("Convert_transformer")
# output_audio_transformer= gr.Audio(
# source="upload", file="filepath", label="Converted Audio", interactive=False
# )
# b_transformer.click(
# transformer_ASRTTS_clean,
# inputs=[input_audio, speaker_option],
# outputs=output_audio_transformer,
# api_name="convert_trans"
# )
# google tts [TODO]
# b_google = gr.Button("Convert_googleTTS")
# output_audio_google= gr.Audio(
# source="upload", file="filepath", label="Converted Audio", interactive=False
# )
# b_google.click(
# google_ASRTTS_clean,
# inputs=[input_audio, speaker_option],
# outputs=output_audio_google,
# api_name="convert"
# )
# Start the Gradio server for the UI defined above. share=False means no public
# share link is created.
demo.launch(share=False)