Jofthomas HF staff commited on
Commit
7e60591
·
1 Parent(s): bde6ecc

Create coqui.py

Browse files
Files changed (1) hide show
  1. TextGen/coqui.py +399 -0
TextGen/coqui.py ADDED
@@ -0,0 +1,399 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import io, os, stat
3
+ import subprocess
4
+ import random
5
+ from zipfile import ZipFile
6
+ import uuid
7
+ import time
8
+ import torch
9
+ import torchaudio
10
+ import numpy as np
11
+
12
# Install the pinned Gradio builds before `gradio` is imported further down,
# so that the import picks up the build with faster streaming.
# NOTE(review): the requirement on the next line is not quoted, so the shell
# passes "gradio-client", "@" and the git URL as three separate arguments;
# pip will probably reject the bare "@" -- confirm, and quote it if the pin
# is really wanted. The command is left untouched here on purpose.
os.system("pip install gradio-client @ git+https://github.com/gradio-app/gradio@bed454c3d22cfacedc047eb3b0ba987b485ac3fd")
os.system("pip install git+https://github.com/gradio-app/gradio.git@5.0-dev")
# Download the unidic dictionary needed by mecab
os.system('python -m unidic download')

# By using XTTS you agree to CPML license https://coqui.ai/cpml
os.environ["COQUI_TOS_AGREED"] = "1"

# langid is used to detect language for longer text
# Most users expect text to be their own language, there is checkbox to disable it
import langid
import base64
import csv
from io import StringIO
import datetime
import re

import gradio as gr
from scipy.io.wavfile import write
from pydub import AudioSegment

from TTS.api import TTS
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.generic_utils import get_user_data_dir

HF_TOKEN = os.environ.get("HF_TOKEN")

from huggingface_hub import HfApi

# The Gradio 5.0-dev install is not repeated here: gradio has already been
# imported above, so reinstalling it at this point could not change the
# module that is loaded in this process.

# The API client is used to restart the Space on an unrecoverable error
api = HfApi(token=HF_TOKEN)
repo_id = "coqui/xtts"

# Use a newer ffmpeg binary for Ubuntu20 to use denoising for microphone input
print("Export newer ffmpeg binary for denoise filter")
with ZipFile("ffmpeg.zip") as ffmpeg_archive:
    # Context manager closes the archive handle once extraction is done
    ffmpeg_archive.extractall()
print("Make ffmpeg binary executable")
st = os.stat("ffmpeg")
os.chmod("ffmpeg", st.st_mode | stat.S_IEXEC)

# This will trigger downloading model
print("Downloading if not downloaded Coqui XTTS V2")
from TTS.utils.manage import ModelManager

model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
ModelManager().download_model(model_name)
# The model directory name is the model id with "/" replaced by "--"
model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
print("XTTS downloaded")

config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))

model = Xtts.init_from_config(config)
model.load_checkpoint(
    config,
    checkpoint_path=os.path.join(model_path, "model.pth"),
    vocab_path=os.path.join(model_path, "vocab.json"),
    eval=True,
    use_deepspeed=True,
)
# NOTE(review): unconditional .cuda() -- this module assumes a GPU is present
model.cuda()

# This is for debugging purposes only: remembers the first prompt/language
# that triggered a CUDA device-side assert.
DEVICE_ASSERT_DETECTED = 0
DEVICE_ASSERT_PROMPT = None
DEVICE_ASSERT_LANG = None

supported_languages = config.languages
82
def numpy_to_mp3(audio_array, sampling_rate):
    """Encode a single-channel audio array as MP3 and return the bytes.

    Floating-point input is peak-normalised to the signed 16-bit range and
    converted to ``int16``. Input of any other dtype is handed to pydub
    unchanged, with the sample width taken from the dtype's item size.

    Args:
        audio_array: NumPy array of audio samples, exported as one channel.
        sampling_rate: Frame rate passed to ``AudioSegment``.

    Returns:
        bytes: The MP3 data, exported at a 320k bitrate.
    """
    if np.issubdtype(audio_array.dtype, np.floating):
        # An empty or all-zero (silent) chunk has no usable peak; dividing by
        # it would raise or fill the chunk with NaNs, so leave it unscaled.
        peak = np.max(np.abs(audio_array)) if audio_array.size else 0.0
        if peak > 0:
            audio_array = (audio_array / peak) * 32767  # Normalize to 16-bit range
        audio_array = audio_array.astype(np.int16)

    # Create an audio segment from the raw sample bytes
    audio_segment = AudioSegment(
        audio_array.tobytes(),
        frame_rate=sampling_rate,
        sample_width=audio_array.dtype.itemsize,
        channels=1,
    )

    # Export to an in-memory buffer - use a high bitrate to maximise quality
    with io.BytesIO() as mp3_io:
        audio_segment.export(mp3_io, format="mp3", bitrate="320k")
        return mp3_io.getvalue()
106
+
107
def _restart_space():
    """Restart the Hugging Face Space unless it is currently building."""
    space = api.get_space_runtime(repo_id=repo_id)
    if space.stage != "BUILDING":
        api.restart_space(repo_id=repo_id)
    else:
        print("TRIED TO RESTART but space is building")


def _filter_reference_audio(speaker_wav):
    """Run the bundled ffmpeg over the reference clip and return the path to use.

    Applies a low-pass/high-pass filter and trims silence from both ends.
    If ffmpeg fails or cannot be started, the original path is returned.
    """
    # This is fast filtering, not perfect. The chain is joined with commas so
    # there is no trailing comma (i.e. no empty filter) at the end of it.
    filter_chain = ",".join(
        [
            "lowpass=8000",
            "highpass=75",
            # remove silence at the end (reverse, trim start), then at the beginning
            "areverse",
            "silenceremove=start_periods=1:start_silence=0:start_threshold=0.02",
            "areverse",
            "silenceremove=start_periods=1:start_silence=0:start_threshold=0.02",
        ]
    )
    # the ".wav" suffix lets ffmpeg know the output format
    out_filename = speaker_wav + str(uuid.uuid4()) + ".wav"

    # Pass arguments as a list so paths containing spaces stay one argument
    command = [
        "./ffmpeg",
        "-y",
        "-i",
        speaker_wav,
        "-af",
        filter_chain,
        out_filename,
    ]
    try:
        subprocess.run(command, capture_output=False, text=True, check=True)
    except (subprocess.CalledProcessError, OSError):
        # Non-zero exit code, or the ffmpeg binary could not be executed
        print("Error: failed filtering, use original microphone input")
        return speaker_wav

    print("Filtered microphone input")
    return out_filename


def _upload_error_report(error_row, speaker_wav):
    """Upload the failing request and its reference audio to the flagged dataset."""
    error_time = error_row[0]
    write_io = StringIO()
    csv.writer(write_io).writerows([error_row])
    csv_upload = write_io.getvalue().encode()

    error_api = HfApi()

    print("Writing error csv")
    error_api.upload_file(
        path_or_fileobj=csv_upload,
        path_in_repo=error_time + "_" + str(uuid.uuid4()) + ".csv",
        repo_id="coqui/xtts-flagged-dataset",
        repo_type="dataset",
    )

    print("Writing error reference audio")
    error_api.upload_file(
        path_or_fileobj=speaker_wav,
        path_in_repo=error_time + "_reference_" + str(uuid.uuid4()) + ".wav",
        repo_id="coqui/xtts-flagged-dataset",
        repo_type="dataset",
    )


def predict(
    prompt,
    language,
    audio_file_pth,
    mic_file_path,
    use_mic,
    voice_cleanup,
    no_lang_auto_detect,
    agree,
):
    """Stream synthesized speech for ``prompt`` in the voice of a reference clip.

    This is a generator: every audio chunk produced by
    ``model.inference_stream`` is yielded as ``(24000, chunk_as_numpy)``
    (24000 is presumably the sample rate in Hz -- confirm against the model
    config). When a request is rejected, a ``gr.Warning`` is raised for the
    user and the generator ends without yielding; the ``(None,)`` return
    values only terminate the generator.

    Args:
        prompt: Text to speak; must be 2 to 1000 characters long.
        language: Language code; must be in ``supported_languages``.
        audio_file_pth: Path of the reference audio, used when ``use_mic``
            is false.
        mic_file_path: Path of the microphone recording, used when
            ``use_mic`` is true.
        use_mic: Whether to use the microphone recording as reference.
        voice_cleanup: Whether to filter the reference audio with ffmpeg first.
        no_lang_auto_detect: Disables the detected-vs-chosen language check.
        agree: Whether the user accepted the terms; nothing runs without it.

    Yields:
        tuple: ``(24000, numpy.ndarray)`` for each generated audio chunk.
    """
    global DEVICE_ASSERT_DETECTED, DEVICE_ASSERT_PROMPT, DEVICE_ASSERT_LANG

    max_prompt_chars = 1000

    if not agree:
        gr.Warning("Please accept the Terms & Condition!")
        return (None,)

    if language not in supported_languages:
        gr.Warning(
            f"Language you put in ({language}) is not in our Supported Languages, please choose from dropdown"
        )
        return (None,)

    # strip is needed as there is a space at the end of the detected code
    language_predicted = langid.classify(prompt)[0].strip()

    # tts expects chinese as zh-cn
    if language_predicted == "zh":
        language_predicted = "zh-cn"

    print(f"Detected language:{language_predicted}, Chosen language:{language}")

    # Language detection is only enforced above 15 characters: short text may
    # be common to several languages. The checkbox disables the check.
    if (
        len(prompt) > 15
        and language_predicted != language
        and not no_lang_auto_detect
    ):
        # The auto-detector can fail on pretty short or mixed text
        gr.Warning(
            "It looks like your text isn’t the language you chose , if you’re sure the text is the same language you chose, please check disable language auto-detection checkbox"
        )
        return (None,)

    if use_mic:
        if mic_file_path is None:
            gr.Warning(
                "Please record your voice with Microphone, or uncheck Use Microphone to use reference audios"
            )
            return (None,)
        speaker_wav = mic_file_path
    else:
        speaker_wav = audio_file_pth

    # Filtering for microphone input, as it has BG noise, maybe silence in
    # beginning and end. A missing reference is left for the speaker-encoding
    # step below to report, instead of failing here while building a path.
    if voice_cleanup and speaker_wav is not None:
        speaker_wav = _filter_reference_audio(speaker_wav)

    if len(prompt) < 2:
        gr.Warning("Please give a longer prompt text")
        return (None,)
    if len(prompt) > max_prompt_chars:
        gr.Warning(
            f"Text length limited to {max_prompt_chars} characters for this demo, please try shorter text. You can clone this space and edit code for your own usage"
        )
        return (None,)

    if DEVICE_ASSERT_DETECTED:
        # It will likely never come here as we restart space on first unrecoverable error now
        print(
            f"Unrecoverable exception caused by language:{DEVICE_ASSERT_LANG} prompt:{DEVICE_ASSERT_PROMPT}"
        )
        # HF Space specific.. This error is unrecoverable need to restart space
        _restart_space()

    try:
        try:
            (
                gpt_cond_latent,
                speaker_embedding,
            ) = model.get_conditioning_latents(
                audio_path=speaker_wav,
                gpt_cond_len=30,
                gpt_cond_chunk_len=4,
                max_ref_length=60,
            )
        except Exception as e:
            print("Speaker encoding error", str(e))
            gr.Warning(
                "It appears something wrong with reference, did you unmute your microphone?"
            )
            return (None,)

        # temporary comma fix: put a space before sentence-ending punctuation
        # and double it (same pattern as before, written as a raw string)
        prompt = re.sub(r"([^\x00-\x7F]|\w)([.。?])", r"\1 \2\2", prompt)

        print("I: Generating new audio in streaming mode...")
        t0 = time.time()
        chunks = model.inference_stream(
            prompt,
            language,
            gpt_cond_latent,
            speaker_embedding,
            repetition_penalty=7.0,
            temperature=0.85,
        )

        for i, chunk in enumerate(chunks):
            if i == 0:
                first_chunk_time = time.time() - t0
                print(
                    f"Latency to first audio chunk: {round(first_chunk_time*1000)} milliseconds"
                )

            # Convert chunk to numpy array and hand it to the caller; chunks
            # are not accumulated, so tensors are not kept alive after use.
            chunk_np = chunk.cpu().numpy()
            print('chunk', i)
            yield (24000, chunk_np)

            print(f"Received chunk {i} of audio length {chunk.shape[-1]}")

        inference_time = time.time() - t0
        print(
            f"I: Time to generate audio: {round(inference_time*1000)} milliseconds"
        )

    except RuntimeError as e:
        if "device-side assert" in str(e):
            # cannot do anything on cuda device side error, need to restart
            print(
                f"Exit due to: Unrecoverable exception caused by language:{language} prompt:{prompt}",
                flush=True,
            )
            gr.Warning("Unhandled Exception encounter, please retry in a minute")
            print("Cuda device-assert Runtime encountered need restart")
            if not DEVICE_ASSERT_DETECTED:
                DEVICE_ASSERT_DETECTED = 1
                DEVICE_ASSERT_PROMPT = prompt
                DEVICE_ASSERT_LANG = language

            # just before restarting save what caused the issue so we can handle it in future
            # Uploading error data only happens for an unrecoverable error
            error_time = datetime.datetime.now().strftime("%d-%m-%Y-%H:%M:%S")
            error_data = [
                str(field)
                for field in (
                    error_time,
                    prompt,
                    language,
                    audio_file_pth,
                    mic_file_path,
                    use_mic,
                    voice_cleanup,
                    no_lang_auto_detect,
                    agree,
                )
            ]
            print(error_data)
            print(speaker_wav)
            try:
                _upload_error_report(error_data, speaker_wav)
            except Exception as upload_error:
                # Reporting is best effort: a failed upload must not prevent
                # the restart below, which is the actual recovery step.
                print("Failed to upload error report:", str(upload_error))

            # HF Space specific.. This error is unrecoverable need to restart space
            _restart_space()

        else:
            if "Failed to decode" in str(e):
                print("Speaker encoding error", str(e))
                gr.Warning(
                    "It appears something wrong with reference, did you unmute your microphone?"
                )
            else:
                print("RuntimeError: non device-side assert error:", str(e))
                gr.Warning("Something unexpected happened please retry again.")
            return (None,)