Spaces:
Runtime error
Runtime error
File size: 4,876 Bytes
41989ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import streamlit as st
import soundfile as sf
import timeit
import uuid
import os
import torch
from datautils import *
from model import Generator as Glow_model
from utils import scan_checkpoint, plot_mel, plot_alignment
from Hmodel import Generator as GAN_model
# Full-scale magnitude of 16-bit PCM; the vocoder output is multiplied by this
# before being cast to int16.
MAX_WAV_VALUE = 32768.0
# Prefer the first GPU, but fall back to the CPU so that moving tensors/models
# to `device` does not fail on hosts without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Seed the random generators of every device (CPU and CUDA), so sampling is
# reproducible whichever device was selected above.
torch.manual_seed(1234)
# Run identifier; the same string appears in the checkpoint paths under ./log/.
name = '1038_eunsik_01'
# Nix
from nix.models.TTS import NixTTSInference
def init_session_state():
    """Populate Streamlit's session state the first time it is needed.

    When the ``init_model`` flag is already present nothing is changed.
    Otherwise the flag, the default variant name and a loaded
    ``NixTTSInference`` model are stored.
    """
    if "init_model" in st.session_state:
        return
    state = st.session_state
    state.init_model = True
    state.model_variant = "KSS"
    # NOTE(review): "KSS" is paired with the "sdp" checkpoint here, while
    # update_model() loads "assets/nix-ljspeech-v0.1" for "KSS" -- confirm
    # which mapping is intended.
    state.TTS = NixTTSInference("assets/nix-ljspeech-sdp-v0.1")
def update_model():
    """Reload ``st.session_state.TTS`` for the currently stored variant.

    Reads ``st.session_state.model_variant``; a value that matches neither
    known variant leaves the loaded model as it is.
    """
    variant = st.session_state.model_variant
    if variant == "KSS":
        model_dir = "assets/nix-ljspeech-v0.1"
    elif variant == "μμ":
        model_dir = "assets/nix-ljspeech-sdp-v0.1"
    else:
        # Unknown variant: keep whatever model is already loaded.
        return
    st.session_state.TTS = NixTTSInference(model_dir)
def update_session_state(state_id, state_value):
    """Store ``state_value`` in Streamlit's session state.

    The key is ``state_id`` formatted as a string.
    """
    key = f"{state_id}"
    st.session_state[key] = state_value
def centered_text(input_text, mode = "h1",):
    """Render ``input_text`` centered on the page.

    ``mode`` is the HTML tag name wrapped around the text (``"h1"`` by
    default). The text is inserted into the markup without escaping.
    """
    html = f"<{mode} style='text-align: center;'>{input_text}</{mode}>"
    # The inline style only takes effect when raw HTML is allowed.
    st.markdown(html, unsafe_allow_html = True)
def generate_voice(input_text,):
    """Synthesize ``input_text`` with the session's TTS model and play it.

    The waveform is written to a temporary WAV file under ``cache_sound/``,
    handed to ``st.audio`` and then deleted again.

    Args:
        input_text: Text to synthesize; passed to the model's tokenizer.
    """
    # TTS Inference
    c, c_length, _ = st.session_state.TTS.tokenize(input_text)
    voice = st.session_state.TTS.vocalize(c, c_length)
    # Save audio (bug in Streamlit, can't play numpy array directly).
    # The file name is a random UUID instead of the raw input text: user text
    # may contain path separators or be longer than a file name may be, and
    # two sessions synthesizing the same text would otherwise share one file.
    os.makedirs("cache_sound", exist_ok=True)
    wav_path = os.path.join("cache_sound", f"{uuid.uuid4().hex}.wav")
    try:
        sf.write(wav_path, voice[0, 0], 22050)
        # Play audio
        st.audio(wav_path, format = "audio/wav")
    finally:
        # Remove the temporary file even when writing or playback failed.
        if os.path.exists(wav_path):
            os.remove(wav_path)
    st.caption("Generated Voice")
# ---- Page layout (runs top to bottom as the Streamlit script) ----
# Browser tab title and icon.
st.set_page_config(
    page_title = "μμ Team Demo",
    page_icon = "π",
)
# Make sure the session has a variant name and a loaded model before any
# widget below reads st.session_state.
init_session_state()
# Page heading and sub-heading.
centered_text("π μμ Team Demo")
centered_text("mel generator : Glow-TTS, vocoder : HiFi-GAN", "h5")
st.write(" ")
# Short description paragraph, emitted as raw HTML inside a <p> tag.
mode = "p"
st.markdown(
    f"<{mode} style='text-align: left;'><small>This is a demo trained by our vocie. The voice \"KSS\" is traind 3 times \"μμ\" is finetuned from \"KSS\" for 3 times We got this deomoformat from Nix-TTS Interactive Demo</small></{mode}>",
    unsafe_allow_html = True
)
st.write(" ")
st.write(" ")
# Two columns: the text box goes in the first, the voice picker in the second.
col1, col2 = st.columns(2)
with col1:
    # NOTE(review): the label literal below is split across two lines in this
    # copy of the file (looks like an encoding/extraction artifact) -- confirm
    # against the original source before running.
    input_text = st.text_input(
        "νκΈλ‘λ§ μ
λ ₯ν΄μ£ΌμΈμ",
        value = "λ₯λ¬λμ μ λ§ μ¬λ°μ΄!",
    )
with col2:
    # index = 1 preselects the second option, which differs from the "KSS"
    # default stored by init_session_state(), so the comparison below is true
    # on a fresh session.
    model_variant = st.selectbox("λͺ©μ리 μ νν΄μ£ΌμΈμ", options = ["KSS", "μμ"], index = 1)
    if model_variant != st.session_state.model_variant:
        # Update variant choice
        update_session_state("model_variant", model_variant)
        # Re-load model
        update_model()
# Synthesize and play audio only on the run in which the button was pressed.
button_gen = st.button("Generate Voice")
if button_gen == True:
    generate_voice(input_text)
class TTS:
    """Two-stage text-to-speech pipeline.

    ``flowgenerator`` (the ``Generator`` imported as ``Glow_model``) turns a
    symbol sequence into a mel-spectrogram; ``voicegenerator`` (the
    ``Generator`` imported as ``GAN_model``) turns that mel-spectrogram into
    a waveform.
    """

    def __init__(self, model_variant):
        """Build both networks and load the checkpoints for ``model_variant``.

        Args:
            model_variant: Voice name. Checkpoints are only configured for
                'μμ'; any other value leaves the networks at their freshly
                constructed weights.
        """
        self.flowgenerator = Glow_model(n_vocab = 70, h_c= 192, f_c = 768, f_c_dp = 256, out_c = 80, k_s = 3, k_s_dec = 5, heads=2, layers_enc = 6)
        self.voicegenerator = GAN_model()

        if model_variant == 'μμ':
            last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
            # map_location keeps the load working when the checkpoint was
            # saved from a different device than the one in use now.
            check_point = torch.load(last_chpt1, map_location=device)
            self.flowgenerator.load_state_dict(check_point['generator'])
        # NOTE(review): the original nesting of the calls below relative to
        # the variant checks could not be recovered; they are applied
        # unconditionally here -- confirm.
        self.flowgenerator.decoder.skip()
        self.flowgenerator.eval()

        if model_variant == 'μμ':
            last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
            check_point = torch.load(last_chpt2, map_location=device)
            self.voicegenerator.load_state_dict(check_point['gen_model'])
        self.voicegenerator.eval()
        self.voicegenerator.remove_weight_norm()

        # inference() places its inputs on `device`; the networks must live
        # on the same device.
        self.flowgenerator.to(device)
        self.voicegenerator.to(device)

    def inference(self, input_text, out_dir=None):
        """Synthesize ``input_text`` and return the waveform.

        Args:
            input_text: Text to synthesize. The characters ``. , ! ?`` are
                removed before the text is converted to a symbol sequence.
            out_dir: Optional directory. When given, the audio is also
                written there as ``gen_<first three characters>.wav`` at
                22050 Hz.

        Returns:
            The generated audio as an int16 NumPy array.
        """
        import re

        # Strip punctuation first, then convert the cleaned text to ids.
        sentence = re.sub(r'[.,!?]', '', input_text)
        x = torch.tensor(text_to_sequence(sentence)).unsqueeze(0).to(device).long()
        x_length = torch.tensor(x.shape[1]).unsqueeze(0).to(device)

        with torch.no_grad():
            noise_scale = .667
            length_scale = 1.0
            (y_gen_tst, *_), *_ = self.flowgenerator(x, x_length, gen = True, noise_scale = noise_scale, length_scale = length_scale)
            y = self.voicegenerator(y_gen_tst)

        # Scale to the 16-bit PCM range before the integer cast.
        audio = y.squeeze() * MAX_WAV_VALUE
        audio = audio.cpu().numpy().astype('int16')

        if out_dir is not None:
            os.makedirs(out_dir, exist_ok=True)
            output_file = os.path.join(out_dir, 'gen_' + input_text[:3] + '.wav')
            sf.write(output_file, audio, 22050)
            print(f'{input_text} is stored in {out_dir}')

        # The mel/alignment plotting and notebook display that used to follow
        # the return statement could never run and relied on names that are
        # not defined in this file, so it is not reproduced here.
        return audio