djkesu committed on
Commit
d04d508
1 Parent(s): 1e5b6b6

Update scripts/app.py

Browse files
Files changed (1) hide show
  1. scripts/app.py +0 -306
scripts/app.py CHANGED
@@ -1,306 +0,0 @@
1
- # AGPL: a notification must be added stating that changes have been made to that file.
2
-
3
- import os
4
- import shutil
5
- from pathlib import Path
6
-
7
- import streamlit as st
8
- from random import randint
9
-
10
- from tortoise.api import MODELS_DIR
11
- from tortoise.inference import (
12
- infer_on_texts,
13
- run_and_save_tts,
14
- split_and_recombine_text,
15
- )
16
- from tortoise.utils.diffusion import SAMPLERS
17
- from app_utils.filepicker import st_file_selector
18
- from app_utils.conf import TortoiseConfig
19
-
20
- from app_utils.funcs import (
21
- timeit,
22
- load_model,
23
- list_voices,
24
- load_voice_conditionings,
25
- )
26
-
27
-
28
# Labels offered by the "Latent averaging mode" radio in main(). The list
# position of the chosen label (LATENT_MODES.index(...)) is what main() passes
# to tts_with_preset as latent_averaging_mode, so the order is significant.
- LATENT_MODES = [
29
- "Tortoise original (bad)",
30
- "average per 4.27s (broken on small files)",
31
- "average per voice file (broken on small files)",
32
- ]
33
-
34
def _voice_dir_name(raw_name: str) -> str:
    """Turn a user-entered voice name into a single, safe directory name.

    Surrounding whitespace is stripped and inner spaces become underscores.
    The result is used as a directory under the voices folder that may be
    deleted and recreated, so anything that could point outside that folder
    is rejected.

    Raises:
        ValueError: if the cleaned name is empty, is "." or "..", or contains
            a path separator or NUL character.
    """
    name = raw_name.strip().replace(" ", "_")
    if name in ("", ".", "..") or any(ch in name for ch in ("/", "\\", "\0")):
        raise ValueError(f"Invalid voice name: {raw_name!r}")
    return name


def main():
    """Render the Streamlit page: voice creation, settings, and generation.

    The page has three parts, drawn top to bottom on every script run:
    an expander that saves uploaded .wav files as a new voice directory,
    the text/voice/preset inputs with an "Advanced" expander of options, and
    a "Start" button that runs TTS for the chosen voice and shows the results.

    Bare string expressions in the body (e.g. the "#### ..." headings) are
    intentional: they rely on Streamlit's "magic" to be rendered on the page.
    """
    conf = TortoiseConfig()
    voices_root = Path("tortoise/voices")

    with st.expander("Create New Voice", expanded=True):
        # The widget keys are random and are regenerated after a voice has
        # been created, so both widgets come back empty on the next run.
        if "file_uploader_key" not in st.session_state:
            st.session_state["file_uploader_key"] = str(randint(1000, 100000000))
            st.session_state["text_input_key"] = str(randint(1000, 100000000))

        uploaded_files = st.file_uploader(
            "Upload Audio Samples for a New Voice",
            accept_multiple_files=True,
            type=["wav"],
            key=st.session_state["file_uploader_key"],
        )

        voice_name = st.text_input(
            "New Voice Name",
            help="Enter a name for your new voice.",
            value="",
            key=st.session_state["text_input_key"],
        )

        create_voice_button = st.button(
            "Create Voice",
            disabled=not voice_name.strip() or not uploaded_files,
        )
        if create_voice_button:
            try:
                new_voice_name = _voice_dir_name(voice_name)
            except ValueError as err:
                # Refuse names such as "../x": the directory below is removed
                # with rmtree before being recreated.
                st.error(str(err))
            else:
                with st.spinner(f"Creating new voice: {voice_name}"):
                    voice_dir = voices_root / new_voice_name
                    # An existing voice with the same name is replaced.
                    if voice_dir.exists():
                        shutil.rmtree(voice_dir)
                    voice_dir.mkdir(parents=True)

                    for index, uploaded_file in enumerate(uploaded_files):
                        sample_path = voice_dir / f"voice_sample{index}.wav"
                        sample_path.write_bytes(uploaded_file.read())

                    st.session_state["text_input_key"] = str(randint(1000, 100000000))
                    st.session_state["file_uploader_key"] = str(randint(1000, 100000000))
                    # st.rerun replaces st.experimental_rerun in newer
                    # Streamlit releases; use whichever this install provides.
                    rerun = getattr(st, "rerun", None) or st.experimental_rerun
                    rerun()

    text = st.text_area(
        "Text",
        help="Text to speak.",
        value="The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them.",
    )

    voices = [v for v in os.listdir(voices_root) if v != "cond_latent_example"]

    voice = st.selectbox(
        "Voice",
        voices,
        help="Selects the voice to use for generation. See options in voices/ directory (and add your own!) "
        "Use the & character to join two voices together. Use a comma to perform inference on multiple voices.",
        index=0,
    )
    preset = st.selectbox(
        "Preset",
        (
            "single_sample",
            "ultra_fast",
            "very_fast",
            "ultra_fast_old",
            "fast",
            "standard",
            "high_quality",
        ),
        help="Which voice preset to use.",
        index=1,
    )
    with st.expander("Advanced"):
        col1, col2 = st.columns(2)
        with col1:
            """#### Model parameters"""
            candidates = st.number_input(
                "Candidates",
                help="How many output candidates to produce per-voice.",
                value=1,
            )
            latent_averaging_mode = st.radio(
                "Latent averaging mode",
                LATENT_MODES,
                help="How voice samples should be averaged together.",
                index=0,
            )
            sampler = st.radio(
                "Sampler",
                ["dpm++2m", "p", "ddim"],
                help="Diffusion sampler. Note that dpm++2m is experimental and typically requires more steps.",
                index=1,
            )
            steps = st.number_input(
                "Steps",
                help="Override the steps used for diffusion (default depends on preset)",
                value=10,
            )
            seed = st.number_input(
                "Seed",
                help="Random seed which can be used to reproduce results.",
                value=-1,
            )
            # -1 is the UI's way of saying "no fixed seed".
            if seed == -1:
                seed = None
            voice_fixer = st.checkbox(
                "Voice fixer",
                help="Use `voicefixer` to improve audio quality. This is a post-processing step which can be applied to any output.",
                value=True,
            )
            """#### Directories"""
            output_path = st.text_input(
                "Output Path", help="Where to store outputs.", value="results/"
            )

        with col2:
            """#### Optimizations"""
            # The checkbox is phrased as "Low VRAM"; the code works with the
            # inverse flag.
            high_vram = not st.checkbox(
                "Low VRAM",
                help="Re-enable default offloading behaviour of tortoise",
                value=True,
            )
            half = st.checkbox(
                "Half-Precision",
                help="Enable autocast to half precision for autoregressive model",
                value=False,
            )
            kv_cache = st.checkbox(
                "Key-Value Cache",
                help="Enable kv_cache usage, leading to drastic speedups but worse memory usage",
                value=True,
            )
            cond_free = st.checkbox(
                "Conditioning Free",
                help="Force conditioning free diffusion",
                value=True,
            )
            # NOTE(review): this checkbox is shown but its value is never
            # consulted; only `cond_free` reaches the TTS call. Confirm intent.
            no_cond_free = st.checkbox(
                "Force Not Conditioning Free",
                help="Force disable conditioning free diffusion",
                value=False,
            )

            """#### Text Splitting"""
            min_chars_to_split = st.number_input(
                "Min Chars to Split",
                help="Minimum number of characters to split text on",
                min_value=50,
                value=200,
                step=1,
            )

            """#### Debug"""
            produce_debug_state = st.checkbox(
                "Produce Debug State",
                help="Whether or not to produce debug_state.pth, which can aid in reproducing problems. Defaults to true.",
                value=True,
            )

    if st.button("Update Basic Settings"):
        # No widget on this page collects an extra-voices directory, so that
        # setting is not included here (it used to reference an undefined
        # name and crash with NameError).
        conf.update(
            LOW_VRAM=not high_vram,
            AR_CHECKPOINT=".",
            DIFF_CHECKPOINT=".",
        )

    # The model itself is loaded without explicit checkpoints.
    ar_checkpoint = None
    diff_checkpoint = None
    tts = load_model(MODELS_DIR, high_vram, kv_cache, ar_checkpoint, diff_checkpoint)

    if st.button("Start"):
        # `voice` is None when the voices directory lists nothing selectable.
        if not (latent_averaging_mode and preset and voice):
            st.error("Select a voice, a preset and a latent averaging mode before starting.")
            return

        def show_generation(fp, filename: str):
            """Show an audio player and a download button for one output file.

            Args:
                fp: path of the generated .wav file (str or Path-like).
                filename: file name suggested to the browser for the download.
            """
            st.audio(str(fp), format="audio/wav")
            st.download_button(
                "Download sample",
                # Send the audio bytes; passing the path string would make
                # the download contain the path text instead of the audio.
                Path(fp).read_bytes(),
                file_name=filename,
                mime="audio/wav",
            )

        # Optional overrides are forwarded only when they are set. They do
        # not depend on the voice, so they are built once.
        nullable_kwargs = {
            name: value
            for name, value in {
                "sampler": sampler,
                "diffusion_iterations": steps,
                "cond_free": cond_free,
            }.items()
            if value is not None
        }

        with st.spinner(
            f"Generating {candidates} candidates for voice {voice} (seed={seed}). You can see progress in the terminal"
        ):
            os.makedirs(output_path, exist_ok=True)

            # "a,b" runs each voice in turn; "a&b" combines voices in one run.
            for selected_voice in voice.split(","):
                if "&" in selected_voice:
                    voice_sel = selected_voice.split("&")
                else:
                    voice_sel = [selected_voice]
                voice_samples, conditioning_latents = load_voice_conditionings(
                    voice_sel, []
                )

                voice_path = Path(output_path) / selected_voice

                with timeit(
                    f"Generating {candidates} candidates for voice {selected_voice} (seed={seed})"
                ):

                    def call_tts(text: str):
                        """Run TTS on `text` with the settings chosen in the UI."""
                        return tts.tts_with_preset(
                            text,
                            k=candidates,
                            voice_samples=voice_samples,
                            conditioning_latents=conditioning_latents,
                            preset=preset,
                            use_deterministic_seed=seed,
                            return_deterministic_state=True,
                            cvvp_amount=0.0,
                            half=half,
                            latent_averaging_mode=LATENT_MODES.index(
                                latent_averaging_mode
                            ),
                            **nullable_kwargs,
                        )

                    if len(text) < min_chars_to_split:
                        # Short text: a single TTS call.
                        filepaths = run_and_save_tts(
                            call_tts,
                            text,
                            voice_path,
                            return_deterministic_state=True,
                            return_filepaths=True,
                            voicefixer=voice_fixer,
                        )
                    else:
                        # Long text: split into chunks and run each chunk.
                        desired_length = int(min_chars_to_split)
                        texts = split_and_recombine_text(
                            text, desired_length, desired_length + 100
                        )
                        filepaths = infer_on_texts(
                            call_tts,
                            texts,
                            voice_path,
                            return_deterministic_state=True,
                            return_filepaths=True,
                            lines_to_regen=set(range(len(texts))),
                            voicefixer=voice_fixer,
                        )

                    for i, fp in enumerate(filepaths):
                        show_generation(fp, f"{selected_voice}-text-{i}.wav")

        if produce_debug_state:
            """Debug states can be found in the output directory"""
303
-
304
-
305
# Build the page only when this file is executed as the main script, not when
# it is imported as a module.
- if __name__ == "__main__":
306
- main()