djkesu committed on
Commit
4408097
1 Parent(s): 145a5f9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +306 -0
app.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AGPL: a notification must be added stating that changes have been made to that file.
2
+
3
+ import os
4
+ import shutil
5
+ from pathlib import Path
6
+
7
+ import streamlit as st
8
+ from random import randint
9
+
10
+ from tortoise.api import MODELS_DIR
11
+ from tortoise.inference import (
12
+ infer_on_texts,
13
+ run_and_save_tts,
14
+ split_and_recombine_text,
15
+ )
16
+ from tortoise.utils.diffusion import SAMPLERS
17
+ from app_utils.filepicker import st_file_selector
18
+ from app_utils.conf import TortoiseConfig
19
+
20
+ from app_utils.funcs import (
21
+ timeit,
22
+ load_model,
23
+ list_voices,
24
+ load_voice_conditionings,
25
+ )
26
+
27
+
28
# Labels offered by the "Latent averaging mode" radio button in main().
# Order is significant: main() sends the *list index* of the selected label
# to the model as `latent_averaging_mode`, so never reorder these entries.
LATENT_MODES = [
    "Tortoise original (bad)",                         # index 0
    "average per 4.27s (broken on small files)",       # index 1
    "average per voice file (broken on small files)",  # index 2
]
33
+
34
def _sanitize_voice_name(raw_name: str) -> str:
    """Return a directory-safe version of a user-typed voice name.

    The name ends up as a folder under the voices directory and that folder is
    removed with ``shutil.rmtree`` before being recreated, so path separators
    and dots must never survive (``../..`` would otherwise escape the voices
    directory).  ``&`` and ``,`` are also replaced because main() treats them
    as voice separators.  Every character that is not alphanumeric, ``_`` or
    ``-`` (including spaces) becomes ``_``.
    """
    return "".join(
        ch if ch.isalnum() or ch in "_-" else "_" for ch in raw_name.strip()
    )


def _available_voices(voices_root: str) -> list:
    """List selectable voice folders under *voices_root*, sorted by name.

    Only directories are returned (stray files such as ``.DS_Store`` are
    skipped), ``cond_latent_example`` is excluded, and a missing root yields an
    empty list instead of raising.
    """
    if not os.path.isdir(voices_root):
        return []
    return sorted(
        entry
        for entry in os.listdir(voices_root)
        if entry != "cond_latent_example"
        and os.path.isdir(os.path.join(voices_root, entry))
    )


def _save_voice_samples(voice_dir: str, uploaded_files) -> None:
    """Recreate *voice_dir* and write each upload as ``voice_sample<N>.wav``.

    An existing directory with the same name is deleted first, so re-creating
    a voice replaces all of its previous samples.
    """
    if os.path.exists(voice_dir):
        shutil.rmtree(voice_dir)
    os.makedirs(voice_dir)

    for index, uploaded_file in enumerate(uploaded_files):
        sample_path = os.path.join(voice_dir, f"voice_sample{index}.wav")
        with open(sample_path, "wb") as wav_file:
            wav_file.write(uploaded_file.read())


def _show_generation(fp, filename: str) -> None:
    """Render an audio player and a download button for the file at *fp*.

    The download button is given the file's bytes; passing the path string
    would make the browser download a text file containing only the path.
    """
    st.audio(str(fp), format="audio/wav")
    st.download_button(
        "Download sample",
        Path(fp).read_bytes(),
        file_name=filename,
        mime="audio/wav",
    )


def main():
    """Build the Streamlit page: voice creation, settings, and generation.

    The page has three parts:

    1. "Create New Voice": stores uploaded WAV files as a new voice folder
       under ``tortoise/voices`` and reruns the script with cleared inputs.
    2. Text / voice / preset selection plus an "Advanced" expander with model,
       optimization, text-splitting and debug options.
    3. "Start": generates audio for the selected voice and shows each result
       with a player and a download button.
    """
    conf = TortoiseConfig()
    voices_root = "tortoise/voices"

    with st.expander("Create New Voice", expanded=True):
        # Both widgets use random keys that are regenerated after a voice is
        # created, so they come back empty on the rerun that follows.
        if "file_uploader_key" not in st.session_state:
            st.session_state["file_uploader_key"] = str(randint(1000, 100000000))
            st.session_state["text_input_key"] = str(randint(1000, 100000000))

        uploaded_files = st.file_uploader(
            "Upload Audio Samples for a New Voice",
            accept_multiple_files=True,
            type=["wav"],
            key=st.session_state["file_uploader_key"],
        )

        voice_name = st.text_input(
            "New Voice Name",
            help="Enter a name for your new voice.",
            value="",
            key=st.session_state["text_input_key"],
        )

        create_voice_button = st.button(
            "Create Voice",
            disabled=not voice_name.strip() or not uploaded_files,
        )
        if create_voice_button:
            with st.spinner(f"Creating new voice: {voice_name}"):
                new_voice_name = _sanitize_voice_name(voice_name)
                _save_voice_samples(
                    os.path.join(voices_root, new_voice_name), uploaded_files
                )

                st.session_state["text_input_key"] = str(randint(1000, 100000000))
                st.session_state["file_uploader_key"] = str(randint(1000, 100000000))
                # `st.experimental_rerun` is deprecated in newer Streamlit
                # releases; prefer `st.rerun` whenever it is available.
                rerun = getattr(st, "rerun", None) or st.experimental_rerun
                rerun()

    text = st.text_area(
        "Text",
        help="Text to speak.",
        value="The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them.",
    )

    voices = _available_voices(voices_root)

    voice = st.selectbox(
        "Voice",
        voices,
        help="Selects the voice to use for generation. See options in voices/ directory (and add your own!) "
        "Use the & character to join two voices together. Use a comma to perform inference on multiple voices.",
        index=0,
    )
    preset = st.selectbox(
        "Preset",
        (
            "single_sample",
            "ultra_fast",
            "very_fast",
            "ultra_fast_old",
            "fast",
            "standard",
            "high_quality",
        ),
        help="Which voice preset to use.",
        index=1,
    )
    with st.expander("Advanced"):
        col1, col2 = st.columns(2)
        with col1:
            st.markdown("#### Model parameters")
            candidates = st.number_input(
                "Candidates",
                help="How many output candidates to produce per-voice.",
                min_value=1,  # zero or negative candidate counts are meaningless
                value=1,
            )
            latent_averaging_mode = st.radio(
                "Latent averaging mode",
                LATENT_MODES,
                help="How voice samples should be averaged together.",
                index=0,
            )
            sampler = st.radio(
                "Sampler",
                ["dpm++2m", "p", "ddim"],
                help="Diffusion sampler. Note that dpm++2m is experimental and typically requires more steps.",
                index=1,
            )
            steps = st.number_input(
                "Steps",
                help="Override the steps used for diffusion (default depends on preset)",
                min_value=1,
                value=10,
            )
            seed = st.number_input(
                "Seed",
                help="Random seed which can be used to reproduce results.",
                value=-1,
            )
            # -1 is the UI's "no fixed seed" marker.
            if seed == -1:
                seed = None
            voice_fixer = st.checkbox(
                "Voice fixer",
                help="Use `voicefixer` to improve audio quality. This is a post-processing step which can be applied to any output.",
                value=True,
            )
            st.markdown("#### Directories")
            output_path = st.text_input(
                "Output Path", help="Where to store outputs.", value="results/"
            )

        with col2:
            st.markdown("#### Optimizations")
            high_vram = not st.checkbox(
                "Low VRAM",
                help="Re-enable default offloading behaviour of tortoise",
                value=True,
            )
            half = st.checkbox(
                "Half-Precision",
                help="Enable autocast to half precision for autoregressive model",
                value=False,
            )
            kv_cache = st.checkbox(
                "Key-Value Cache",
                help="Enable kv_cache usage, leading to drastic speedups but worse memory usage",
                value=True,
            )
            cond_free = st.checkbox(
                "Conditioning Free",
                help="Force conditioning free diffusion",
                value=True,
            )
            no_cond_free = st.checkbox(
                "Force Not Conditioning Free",
                help="Force disable conditioning free diffusion",
                value=False,
            )

            st.markdown("#### Text Splitting")
            min_chars_to_split = st.number_input(
                "Min Chars to Split",
                help="Minimum number of characters to split text on",
                min_value=50,
                value=200,
                step=1,
            )

            st.markdown("#### Debug")
            produce_debug_state = st.checkbox(
                "Produce Debug State",
                help="Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.",
                value=True,
            )

    if st.button("Update Basic Settings"):
        conf.update(
            # This page has no extra-voices input, so re-submit the value the
            # config already holds (empty string if it has none) rather than
            # referencing an undefined variable.
            EXTRA_VOICES_DIR=getattr(conf, "EXTRA_VOICES_DIR", ""),
            LOW_VRAM=not high_vram,
            AR_CHECKPOINT=".",
            DIFF_CHECKPOINT=".",
        )

    # No custom checkpoints are selectable here, so none are passed on.
    ar_checkpoint = None
    diff_checkpoint = None
    tts = load_model(MODELS_DIR, high_vram, kv_cache, ar_checkpoint, diff_checkpoint)

    if st.button("Start"):
        # Report missing selections in the UI; `assert` would be stripped
        # under `python -O` and only yields a bare traceback otherwise.
        missing = [
            label
            for label, value in (
                ("voice", voice),
                ("preset", preset),
                ("latent averaging mode", latent_averaging_mode),
            )
            if not value
        ]
        if missing:
            st.error(f"Cannot start: no {', '.join(missing)} selected.")
            return

        # "Force Not Conditioning Free" takes precedence over "Conditioning Free".
        effective_cond_free = False if no_cond_free else cond_free
        # Loop-invariant arguments, computed once for all selected voices.
        nullable_kwargs = {
            name: value
            for name, value in (
                ("sampler", sampler),
                ("diffusion_iterations", steps),
                ("cond_free", effective_cond_free),
            )
            if value is not None
        }
        latent_mode_index = LATENT_MODES.index(latent_averaging_mode)

        with st.spinner(
            f"Generating {candidates} candidates for voice {voice} (seed={seed}). You can see progress in the terminal"
        ):
            os.makedirs(output_path, exist_ok=True)

            for selected_voice in voice.split(","):
                # "a&b" blends several voices; a plain name yields a 1-item list.
                voice_sel = selected_voice.split("&")
                voice_samples, conditioning_latents = load_voice_conditionings(
                    voice_sel, []
                )

                voice_path = Path(os.path.join(output_path, selected_voice))

                with timeit(
                    f"Generating {candidates} candidates for voice {selected_voice} (seed={seed})"
                ):

                    def call_tts(text: str):
                        return tts.tts_with_preset(
                            text,
                            k=candidates,
                            voice_samples=voice_samples,
                            conditioning_latents=conditioning_latents,
                            preset=preset,
                            use_deterministic_seed=seed,
                            return_deterministic_state=True,
                            cvvp_amount=0.0,
                            half=half,
                            latent_averaging_mode=latent_mode_index,
                            **nullable_kwargs,
                        )

                    if len(text) < min_chars_to_split:
                        # Short text: synthesize it in a single pass.
                        filepaths = run_and_save_tts(
                            call_tts,
                            text,
                            voice_path,
                            return_deterministic_state=True,
                            return_filepaths=True,
                            voicefixer=voice_fixer,
                        )
                    else:
                        # Long text: split into chunks and synthesize each one.
                        desired_length = int(min_chars_to_split)
                        texts = split_and_recombine_text(
                            text, desired_length, desired_length + 100
                        )
                        filepaths = infer_on_texts(
                            call_tts,
                            texts,
                            voice_path,
                            return_deterministic_state=True,
                            return_filepaths=True,
                            lines_to_regen=set(range(len(texts))),
                            voicefixer=voice_fixer,
                        )

                    for i, fp in enumerate(filepaths):
                        _show_generation(fp, f"{selected_voice}-text-{i}.wav")
        if produce_debug_state:
            st.markdown("Debug states can be found in the output directory")
303
+
304
+
305
# Build the page only when this file is executed as a script (for example via
# `streamlit run app.py`); importing the module has no side effects.
if __name__ == "__main__":
    main()