Spaces:
Runtime error
Runtime error
Commit
·
8fde97d
1
Parent(s):
39097f7
Update app.py (#1)
Browse files- Update app.py (e44a7c2616a6874a8d182dd22a85c0dbe0843f58)
app.py
CHANGED
@@ -1,37 +1,61 @@
|
|
1 |
import streamlit as st
|
2 |
import soundfile as sf
|
3 |
-
import
|
4 |
-
import uuid
|
5 |
-
|
6 |
-
import os
|
7 |
-
|
8 |
import torch
|
9 |
-
|
10 |
from datautils import *
|
11 |
from model import Generator as Glow_model
|
12 |
-
from utils import scan_checkpoint, plot_mel, plot_alignment
|
13 |
from Hmodel import Generator as GAN_model
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
torch.cuda.manual_seed(1234)
|
18 |
-
name = '1038_eunsik_01'
|
19 |
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
def init_session_state():
|
24 |
# Model
|
25 |
if "init_model" not in st.session_state:
|
26 |
st.session_state.init_model = True
|
27 |
-
st.session_state.model_variant = "
|
28 |
-
st.session_state.TTS =
|
29 |
|
30 |
def update_model():
|
31 |
if st.session_state.model_variant == "KSS":
|
32 |
-
st.session_state.TTS =
|
33 |
elif st.session_state.model_variant == "μμ":
|
34 |
-
st.session_state.TTS =
|
35 |
|
36 |
def update_session_state(state_id, state_value):
|
37 |
st.session_state[f"{state_id}"] = state_value
|
@@ -40,19 +64,19 @@ def centered_text(input_text, mode = "h1",):
|
|
40 |
st.markdown(
|
41 |
f"<{mode} style='text-align: center;'>{input_text}</{mode}>", unsafe_allow_html = True)
|
42 |
|
43 |
-
def generate_voice(input_text
|
44 |
# TTS Inference
|
45 |
-
|
46 |
-
voice = st.session_state.TTS.vocalize(c, c_length)
|
47 |
|
48 |
# Save audio (bug in Streamlit, can't play numpy array directly)
|
49 |
-
sf.write(f"cache_sound/{input_text}.wav", voice
|
50 |
|
51 |
# Play audio
|
52 |
st.audio(f"cache_sound/{input_text}.wav", format = "audio/wav")
|
53 |
os.remove(f"cache_sound/{input_text}.wav")
|
54 |
st.caption("Generated Voice")
|
55 |
-
|
|
|
56 |
st.set_page_config(
|
57 |
page_title = "μμ Team Demo",
|
58 |
page_icon = "π",
|
@@ -92,44 +116,3 @@ if button_gen == True:
|
|
92 |
generate_voice(input_text)
|
93 |
|
94 |
|
95 |
-
class TTS:
|
96 |
-
def __init__(self, model_variant):
|
97 |
-
self.flowgenerator = Glow_model(n_vocab = 70, h_c= 192, f_c = 768, f_c_dp = 256, out_c = 80, k_s = 3, k_s_dec = 5, heads=2, layers_enc = 6)
|
98 |
-
self.voicegenerator = GAN_model()
|
99 |
-
if model_variant == 'μμ':
|
100 |
-
last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
|
101 |
-
check_point = torch.load(last_chpt1)
|
102 |
-
self.flowgenerator.load_state_dict(check_point['generator'])
|
103 |
-
self.flowgenerator.decoder.skip()
|
104 |
-
self.flowgenerator.eval()
|
105 |
-
if model_variant == 'μμ':
|
106 |
-
last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
|
107 |
-
check_point = torch.load(last_chpt2)
|
108 |
-
self.voicegenerator.load_state_dict(check_point['gen_model'])
|
109 |
-
self.voicegenerator.eval()
|
110 |
-
self.voicegenerator.remove_weight_norm()
|
111 |
-
|
112 |
-
def inference(self, input_text):
|
113 |
-
x = text_to_sequence(sentence)
|
114 |
-
filters = '([.,!?])'
|
115 |
-
sentence = re.sub(re.compile(filters), '', text)
|
116 |
-
x = torch.autograd.Variable(torch.tensor(x).unsqueeze(0)).to(device).long()
|
117 |
-
x_length = torch.tensor(x.shape[1]).unsqueeze(0).to(device)
|
118 |
-
|
119 |
-
with torch.no_grad():
|
120 |
-
noise_scale = .667
|
121 |
-
length_scale = 1.0
|
122 |
-
(y_gen_tst, *_), *_, (attn_gen, *_) = flowgenerator(x, x_length, gen = True, noise_scale = noise_scale, length_scale = length_scale)
|
123 |
-
y = voicegenerator(y_gen_tst)
|
124 |
-
audio = y.squeeze() * MAX_WAV_VALUE
|
125 |
-
audio = audio.cpu().numpy().astype('int16')
|
126 |
-
|
127 |
-
output_file = os.path.join(out_dir, 'gen_'+text[:3]+'.wav')
|
128 |
-
write(output_file, 22050, audio)
|
129 |
-
print(f'{text} is stored in {out_dir}')
|
130 |
-
|
131 |
-
return voice
|
132 |
-
plot_mel(y_gen_tst[0].data.cpu().numpy())
|
133 |
-
plot_alignment(attn_gen[0,0].data.cpu().numpy(), sequence_to_text(x[0].data.cpu().numpy()))
|
134 |
-
ipd.display(fig1,fig2)
|
135 |
-
ipd.Audio(filename=output_file)
|
|
|
1 |
import streamlit as st
|
2 |
import soundfile as sf
|
3 |
+
import os, re
|
|
|
|
|
|
|
|
|
4 |
import torch
|
|
|
5 |
from datautils import *
|
6 |
from model import Generator as Glow_model
|
|
|
7 |
from Hmodel import Generator as GAN_model
|
8 |
|
9 |
+
# Run inference on the first GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'

# Seed the CUDA RNG only when CUDA is actually present.
# (The previous spelling `torch.duda` raised AttributeError at import time.)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1234)
|
|
|
|
|
11 |
|
12 |
+
class TTS:
    """Two-stage text-to-speech pipeline.

    ``flowgenerator`` turns a symbol sequence into an intermediate
    representation, and ``voicegenerator`` turns that into a waveform.
    Both networks are placed on the module-level ``device``.
    """

    # Punctuation removed from the input text before it is tokenised.
    _PUNCTUATION = re.compile(r'([.,!?])')

    def __init__(self, model_variant):
        """Build both networks and load checkpoints for ``model_variant``.

        Args:
            model_variant: Name of the voice to load. Only 'μμ' has
                checkpoints wired up here.
        """
        self.flowgenerator = Glow_model(
            n_vocab=70, h_c=192, f_c=768, f_c_dp=256, out_c=80,
            k_s=3, k_s_dec=5, heads=2, layers_enc=6,
        )
        self.voicegenerator = GAN_model()

        if model_variant == 'μμ':
            # map_location lets checkpoints saved on a GPU load on a CPU-only host.
            check_point = torch.load(
                './log/1038_eunsik_01/Glow_TTS_00289602.pt', map_location=device)
            self.flowgenerator.load_state_dict(check_point['generator'])
            self.flowgenerator.decoder.skip()
            self.flowgenerator.eval()

            check_point = torch.load(
                './log/1038_eunsik_01/HiFI_GAN_00257000.pt', map_location=device)
            self.voicegenerator.load_state_dict(check_point['gen_model'])
            self.voicegenerator.eval()
            self.voicegenerator.remove_weight_norm()
        # NOTE(review): any other variant (e.g. "KSS") loads no checkpoint and
        # keeps freshly initialised weights - confirm whether that is intended.

        # inference() puts its inputs on `device`, so the networks must live
        # there too. Assumes both generators are torch.nn.Module subclasses
        # (they already expose load_state_dict/eval) - TODO confirm.
        self.flowgenerator.to(device)
        self.voicegenerator.to(device)

    def inference(self, input_text):
        """Synthesise speech for ``input_text``.

        Args:
            input_text: Text to speak; '.', ',', '!' and '?' are stripped first.

        Returns:
            The waveform as a NumPy ``int16`` array.
        """
        sentence = self._PUNCTUATION.sub('', input_text)
        x = torch.tensor(
            text_to_sequence(sentence), dtype=torch.long, device=device,
        ).unsqueeze(0)
        x_length = torch.tensor([x.shape[1]], device=device)

        with torch.no_grad():
            noise_scale = .667
            length_scale = 1.0
            (y_gen_tst, *_), *_, (_attn_gen, *_) = self.flowgenerator(
                x, x_length, gen=True,
                noise_scale=noise_scale, length_scale=length_scale)
            y = self.voicegenerator(y_gen_tst)
            # Scale to the 16-bit range and clamp: without the clamp a sample
            # of exactly +1.0 becomes 32768 and wraps to -32768 in the cast.
            audio = (y.squeeze() * 32768.0).clamp(-32768.0, 32767.0)
            voice = audio.cpu().numpy().astype('int16')
        return voice
|
46 |
|
47 |
def init_session_state():
    """Populate Streamlit session state with the default TTS model once per session.

    Does nothing when the ``init_model`` flag is already present.
    """
    if "init_model" in st.session_state:
        return

    # Model
    st.session_state.model_variant = "μμ"
    st.session_state.TTS = TTS("μμ")
    # Set the flag last: if constructing TTS raises, the next rerun retries
    # instead of leaving the session marked initialised with no TTS object.
    st.session_state.init_model = True
|
53 |
|
54 |
def update_model():
    """Rebuild the session's TTS object for the currently selected model variant.

    Variants other than "KSS" and "μμ" leave the existing TTS object untouched.
    """
    selected = st.session_state.model_variant
    for known_variant in ("KSS", "μμ"):
        if selected == known_variant:
            st.session_state.TTS = TTS(known_variant)
            break
|
59 |
|
60 |
def update_session_state(state_id, state_value):
    """Store ``state_value`` in Streamlit session state under the formatted ``state_id``."""
    key = f"{state_id}"
    st.session_state[key] = state_value
|
|
|
64 |
st.markdown(
|
65 |
f"<{mode} style='text-align: center;'>{input_text}</{mode}>", unsafe_allow_html = True)
|
66 |
|
67 |
+
def generate_voice(input_text):
    """Synthesise ``input_text`` with the session's TTS model and render an audio player.

    The waveform is encoded as WAV in memory rather than written to
    ``cache_sound/<input_text>.wav``. User text therefore never ends up in a
    file path (path separators, over-long names, clashes between sessions)
    and no cache directory has to exist.

    Args:
        input_text: Text to speak.
    """
    import io  # local import so this block does not depend on the file's import list

    # TTS Inference
    voice = st.session_state.TTS.inference(input_text)

    # Encode as WAV bytes; soundfile needs an explicit format when it is
    # given a file-like object instead of a file name.
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, voice, 22050, format="WAV")

    # Play audio
    st.audio(wav_buffer.getvalue(), format="audio/wav")
    st.caption("Generated Voice")
|
78 |
+
|
79 |
+
|
80 |
st.set_page_config(
|
81 |
page_title = "μμ Team Demo",
|
82 |
page_icon = "π",
|
|
|
116 |
generate_voice(input_text)
|
117 |
|
118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|