ggoknar committed
Commit a38b58d
1 Parent(s): da4b074

stream voice with combined wav at end, optional direct stream

Files changed (1):
  1. app.py +454 -0
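In short: instead of yielding one audio file per sentence, the new generate_speech generator accumulates every synthesized PCM chunk and, at the end, emits a single combined wav (one header from wave_header_chunk followed by all accumulated PCM bytes). Setting the DIRECT_STREAM environment variable to 1 additionally autoplays each chunk as it arrives. A simplified sketch of that flow, reusing the helpers defined in the diff below (illustrative only; error handling and chunk pacing omitted):

    def generate_speech_sketch(history):
        # Sketch of the flow added in this commit; the real logic lives in generate_speech below.
        language = "en"
        wav_bytestream = b""  # accumulates raw 16-bit PCM for the final combined wav
        for sentence, history in get_sentence(history):
            for chunk in get_voice_streaming(sentence, language, latent_map["Female_Voice"]):
                wav_bytestream += chunk  # always collect the raw PCM bytes
                if DIRECT_STREAM:
                    # Optional direct stream: prepend a fresh header to each chunk so it autoplays;
                    # recreating the audio each time can produce audible artifacts.
                    yield (gr.Audio.update(value=wave_header_chunk() + chunk, autoplay=True), history)
        # At the end, emit one combined wav: a single header followed by all PCM bytes.
        yield (gr.Audio.update(value=wave_header_chunk() + wav_bytestream, autoplay=False), history)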
app.py CHANGED
@@ -5,6 +5,8 @@ import os
5
  # By using XTTS you agree to CPML license https://coqui.ai/cpml
6
  os.environ["COQUI_TOS_AGREED"] = "1"
7
 
8
+ from scipy.io.wavfile import write
9
+ from pydub import AudioSegment
10
  import gradio as gr
11
  import numpy as np
12
  import torch
@@ -32,6 +34,9 @@ from TTS.utils.generic_utils import get_user_data_dir
34
  # Could not make playing the next audio work seamlessly on current Gradio with autoplay, so this is a workaround
35
  AUDIO_WAIT_MODIFIER = float(os.environ.get("AUDIO_WAIT_MODIFIER", 1))
36
 
37
+ # If set, will try to stream audio while receiving audio chunks; beware that recreating the audio each time produces artifacts
38
+ DIRECT_STREAM = int(os.environ.get("DIRECT_STREAM", 0))
39
+
40
  # This will trigger downloading the model
41
  print("Downloading if not downloaded Coqui XTTS V1")
42
  tts = TTS("tts_models/multilingual/multi-dataset/xtts_v1")
@@ -106,3 +111,452 @@ text_client = InferenceClient(
111
  "mistralai/Mistral-7B-Instruct-v0.1",
112
  timeout=WHISPER_TIMEOUT,
113
  )
114
+
115
+
116
+ ###### COQUI TTS FUNCTIONS ######
117
+ def get_latents(speaker_wav):
118
+ # Created as a function so we can later add voice cleanup/filtering here
119
+ (
120
+ gpt_cond_latent,
121
+ diffusion_conditioning,
122
+ speaker_embedding,
123
+ ) = model.get_conditioning_latents(audio_path=speaker_wav)
124
+ return gpt_cond_latent, diffusion_conditioning, speaker_embedding
125
+
126
+
127
+ def format_prompt(message, history):
128
+ prompt = (
129
+ "<s>[INST]"
130
+ + system_message
131
+ + "[/INST] I understand, I am a Mistral chatbot with speech by Coqui team.</s>"
132
+ )
133
+ for user_prompt, bot_response in history:
134
+ prompt += f"[INST] {user_prompt} [/INST]"
135
+ prompt += f" {bot_response}</s> "
136
+ prompt += f"[INST] {message} [/INST]"
137
+ return prompt
138
+
139
+
140
+ def generate(
141
+ prompt,
142
+ history,
143
+ temperature=0.9,
144
+ max_new_tokens=256,
145
+ top_p=0.95,
146
+ repetition_penalty=1.0,
147
+ ):
148
+ temperature = float(temperature)
149
+ if temperature < 1e-2:
150
+ temperature = 1e-2
151
+ top_p = float(top_p)
152
+
153
+ generate_kwargs = dict(
154
+ temperature=temperature,
155
+ max_new_tokens=max_new_tokens,
156
+ top_p=top_p,
157
+ repetition_penalty=repetition_penalty,
158
+ do_sample=True,
159
+ seed=42,
160
+ )
161
+
162
+ formatted_prompt = format_prompt(prompt, history)
163
+
164
+ try:
165
+ stream = text_client.text_generation(
166
+ formatted_prompt,
167
+ **generate_kwargs,
168
+ stream=True,
169
+ details=True,
170
+ return_full_text=False,
171
+ )
172
+ output = ""
173
+ for response in stream:
174
+ output += response.token.text
175
+ yield output
176
+
177
+ except Exception as e:
178
+ if "Too Many Requests" in str(e):
179
+ print("ERROR: Too many requests on mistral client")
180
+ gr.Warning("Unfortunately Mistral is unable to process")
181
+ output = "Unfortunately I am not able to process your request right now!"
182
+ else:
183
+ print("Unhandled Exception: ", str(e))
184
+ gr.Warning("Unfortunately Mistral is unable to process")
185
+ output = "I do not know what happened, but I could not understand you."
186
+
187
+ return output
188
+
189
+
190
+ def transcribe(wav_path):
191
+ try:
192
+ # Get the first element from whisper_jax and strip it to remove leading and trailing spaces
193
+ return whisper_client.predict(
194
+ wav_path, # str (filepath or URL to file) in 'inputs' Audio component
195
+ "transcribe", # str in 'Task' Radio component
196
+ False, # return_timestamps=False for whisper-jax https://gist.github.com/sanchit-gandhi/781dd7003c5b201bfe16d28634c8d4cf#file-whisper_jax_endpoint-py
197
+ api_name="/predict",
198
+ )[0].strip()
199
+ except:
200
+ gr.Warning("There was a problem with the Whisper endpoint; telling a joke for you instead.")
201
+ return "There was a problem with my voice, tell me a joke"
202
+
203
+
204
+ # Chatbot demo with multimodal input (text, markdown, LaTeX, code blocks, image, audio, & video). Plus shows support for streaming text.
205
+
206
+
207
+ def add_text(history, text):
208
+ history = [] if history is None else history
209
+ history = history + [(text, None)]
210
+ return history, gr.update(value="", interactive=False)
211
+
212
+
213
+ def add_file(history, file):
214
+ history = [] if history is None else history
215
+
216
+ try:
217
+ text = transcribe(file)
218
+ print("Transcribed text:", text)
219
+ except Exception as e:
220
+ print(str(e))
221
+ gr.Warning("There was an issue with transcription, please try writing for now")
222
+ # Fall back to a default text on error
223
+ text = "Transcription seems to have failed, please tell me a joke about chickens"
224
+
225
+ history = history + [(text, None)]
226
+ return history, gr.update(value="", interactive=False)
227
+
228
+
229
+ ##NOTE: not using this as it yields a character each time, while we need to feed the history to TTS
230
+ def bot(history, system_prompt=""):
231
+ history = [] if history is None else history
232
+
233
+ if system_prompt == "":
234
+ system_prompt = system_message
235
+
236
+ history[-1][1] = ""
237
+ for character in generate(history[-1][0], history[:-1]):
238
+ history[-1][1] = character
239
+ yield history
240
+
241
+
242
+ def get_latents(speaker_wav):
243
+ # Generate speaker embedding and latents for TTS
244
+ (
245
+ gpt_cond_latent,
246
+ diffusion_conditioning,
247
+ speaker_embedding,
248
+ ) = model.get_conditioning_latents(audio_path=speaker_wav)
249
+ return gpt_cond_latent, diffusion_conditioning, speaker_embedding
250
+
251
+
252
+ latent_map = {}
253
+ latent_map["Female_Voice"] = get_latents("examples/female.wav")
254
+
255
+
256
+ def get_voice(prompt, language, latent_tuple, suffix="0"):
257
+ gpt_cond_latent, diffusion_conditioning, speaker_embedding = latent_tuple
258
+ # Direct version
259
+ t0 = time.time()
260
+ out = model.inference(
261
+ prompt, language, gpt_cond_latent, speaker_embedding, diffusion_conditioning
262
+ )
263
+ inference_time = time.time() - t0
264
+ print(f"I: Time to generate audio: {round(inference_time*1000)} milliseconds")
265
+ real_time_factor = (time.time() - t0) / out["wav"].shape[-1] * 24000
266
+ print(f"Real-time factor (RTF): {real_time_factor}")
267
+ wav_filename = f"output_{suffix}.wav"
268
+ torchaudio.save(wav_filename, torch.tensor(out["wav"]).unsqueeze(0), 24000)
269
+ return wav_filename
270
+
271
+
272
+ def wave_header_chunk(frame_input=b"", channels=1, sample_width=2, sample_rate=24000):
273
+ # This will create a wave header and then append the frame input
274
+ # It should come first in a streaming wav file
275
+ # Subsequent frames should not include it (else you will hear artifacts at the start of each chunk)
276
+ wav_buf = io.BytesIO()
277
+ with wave.open(wav_buf, "wb") as vfout:
278
+ vfout.setnchannels(channels)
279
+ vfout.setsampwidth(sample_width)
280
+ vfout.setframerate(sample_rate)
281
+ vfout.writeframes(frame_input)
282
+
283
+ wav_buf.seek(0)
284
+ return wav_buf.read()
285
+
286
+
287
+ def get_voice_streaming(prompt, language, latent_tuple, suffix="0"):
288
+ gpt_cond_latent, diffusion_conditioning, speaker_embedding = latent_tuple
289
+ try:
290
+ t0 = time.time()
291
+ chunks = model.inference_stream(
292
+ prompt,
293
+ language,
294
+ gpt_cond_latent,
295
+ speaker_embedding,
296
+ )
297
+
298
+ first_chunk = True
299
+ for i, chunk in enumerate(chunks):
300
+ if first_chunk:
301
+ first_chunk_time = time.time() - t0
302
+ metrics_text = f"Latency to first audio chunk: {round(first_chunk_time*1000)} milliseconds\n"
303
+ first_chunk = False
304
+ print(f"Received chunk {i} of audio length {chunk.shape[-1]}")
305
+
306
+ # In case output is required to be multiple voice files
307
+ # out_file = f'{char}_{i}.wav'
308
+ # write(out_file, 24000, chunk.detach().cpu().numpy().squeeze())
309
+ # audio = AudioSegment.from_file(out_file)
310
+ # audio.export(out_file, format='wav')
311
+ # return out_file
312
+ # directly return chunk as bytes for streaming
313
+ chunk = chunk.detach().cpu().numpy().squeeze()
314
+ chunk = (chunk * 32767).astype(np.int16)
315
+
316
+ yield chunk.tobytes()
317
+
318
+ except RuntimeError as e:
319
+ if "device-side assert" in str(e):
320
+ # Cannot do anything about a CUDA device-side assert error; need to restart
321
+ print(
322
+ f"Exit due to: Unrecoverable exception caused by prompt:{prompt}",
323
+ flush=True,
324
+ )
325
+ gr.Warning("Unhandled exception encountered, please retry in a minute")
326
+ print("CUDA device-side assert encountered; need to restart")
327
+
328
+ # HF Space specific: this error is unrecoverable, need to restart the Space
329
+ api.restart_space(repo_id=repo_id)
330
+ else:
331
+ print("RuntimeError: non device-side assert error:", str(e))
332
+ gr.Warning("Unhandled exception encountered, please retry in a minute")
333
+ return None
334
+ return None
335
+ except:
336
+ return None
337
+
338
+
339
+ def get_sentence(history, system_prompt=""):
340
+ history = [] if history is None else history
341
+
342
+ if system_prompt == "":
343
+ system_prompt = system_message
344
+
345
+ history[-1][1] = ""
346
+
347
+ mistral_start = time.time()
348
+ print("Mistral start")
349
+ sentence_list = []
350
+ sentence_hash_list = []
351
+
352
+ text_to_generate = ""
353
+ for character in generate(history[-1][0], history[:-1]):
354
+ history[-1][1] = character
355
+ # Output arrives word by word
356
+
357
+ text_to_generate = nltk.sent_tokenize(history[-1][1].replace("\n", " ").strip())
358
+
359
+ if len(text_to_generate) > 1:
360
+ dif = len(text_to_generate) - len(sentence_list)
361
+
362
+ if dif == 1 and len(sentence_list) != 0:
363
+ continue
364
+
365
+ sentence = text_to_generate[len(sentence_list)]
366
+ # This is expensive; replace with hashing!
367
+ sentence_hash = hash(sentence)
368
+
369
+ if sentence_hash not in sentence_hash_list:
370
+ sentence_hash_list.append(sentence_hash)
371
+ sentence_list.append(sentence)
372
+ print("New Sentence: ", sentence)
373
+ yield (sentence, history)
374
+
375
+ # Return the final sentence
376
+ # TODO: needs a counter; this one may be a replica of a previous sentence
377
+ last_sentence = nltk.sent_tokenize(history[-1][1].replace("\n", " ").strip())[-1]
378
+ sentence_hash = hash(last_sentence)
379
+ if sentence_hash not in sentence_hash_list:
380
+ sentence_hash_list.append(sentence_hash)
381
+ sentence_list.append(last_sentence)
382
+ print("New Sentence: ", last_sentence)
383
+
384
+ yield (last_sentence, history)
385
+
386
+
387
+ def generate_speech(history):
388
+ language = "en"
389
+
390
+ wav_bytestream = b""
391
+ for sentence, history in get_sentence(history):
392
+ print(sentence)
393
+ # Sometimes the </s> token from the prompt appears in the output; remove it
394
+ sentence = sentence.replace("</s>", "")
395
+ # A quick fix for the last character; may produce weird sounds if it is attached to text
396
+ if sentence[-1] in ["!", "?", ".", ","]:
397
+ # just add a space
398
+ sentence = sentence[:-1] + " " + sentence[-1]
399
+ print("Sentence for speech:", sentence)
400
+
401
+ try:
402
+ # generate speech using precomputed latents
403
+ # This is not streaming but it will be fast
404
+ # wav = get_voice(sentence,language, latent_map["Female_Voice"], suffix=len(wav_list))
405
+ audio_stream = get_voice_streaming(
406
+ sentence, language, latent_map["Female_Voice"]
407
+ )
408
+ wav_chunks = wave_header_chunk()
409
+ frame_length = 0
410
+ for chunk in audio_stream:
411
+ try:
412
+ wav_bytestream += chunk
413
+ if DIRECT_STREAM:
414
+ yield (
415
+ gr.Audio.update(
416
+ value=wave_header_chunk() + chunk, autoplay=True
417
+ ),
418
+ history,
419
+ )
420
+ wait_time = len(chunk) / 2 / 24000
421
+ wait_time = AUDIO_WAIT_MODIFIER * wait_time
422
+ print("Sleeping till chunk end")
423
+ time.sleep(wait_time)
424
+
425
+ else:
426
+ wav_chunks += chunk
427
+ frame_length += len(chunk)
428
+ except:
429
+ # Hack to continue playing; sometimes the last chunk is empty, will be fixed in the next TTS release
430
+ continue
431
+
432
+ if not DIRECT_STREAM:
433
+ yield (gr.Audio.update(value=wav_chunks, autoplay=True), history)
434
+ # Streaming wait time calculation
435
+ # audio_length = frame_length / sample_width/ frame_rate
436
+ wait_time = frame_length / 2 / 24000
437
+
438
+ # for non streaming
439
+ # wait_time= librosa.get_duration(path=wav)
440
+
441
+ wait_time = AUDIO_WAIT_MODIFIER * wait_time
442
+ print("Sleeping till audio end")
443
+ time.sleep(wait_time)
444
+
445
+ except RuntimeError as e:
446
+ if "device-side assert" in str(e):
447
+ # Cannot do anything about a CUDA device-side assert error; need to restart
448
+ print(
449
+ f"Exit due to: Unrecoverable exception caused by prompt:{sentence}",
450
+ flush=True,
451
+ )
452
+ gr.Warning("Unhandled exception encountered, please retry in a minute")
453
+ print("CUDA device-side assert encountered; need to restart")
454
+
455
+ # HF Space specific: this error is unrecoverable, need to restart the Space
456
+ api.restart_space(repo_id=repo_id)
457
+ else:
458
+ print("RuntimeError: non device-side assert error:", str(e))
459
+ raise e
460
+
461
+ # Each sentence was already spoken on autoplay; now produce one concatenated wav at the end
462
+ # requires pip install ffmpeg-python
463
+
464
+ # files_to_concat= [ffmpeg.input(w) for w in wav_list]
465
+ # combined_file_name="combined.wav"
466
+ # ffmpeg.concat(*files_to_concat,v=0, a=1).output(combined_file_name).run(overwrite_output=True)
467
+ # final_audio.update(value=combined_file_name, visible=True)
468
+ # yield (combined_file_name, history
469
+
470
+ wav_bytestream = wave_header_chunk() + wav_bytestream
471
+ time.sleep(0.3)
472
+ yield (gr.Audio.update(value=None, autoplay=False), history)
473
+ yield (gr.Audio.update(value=wav_bytestream, autoplay=False), history)
474
+
475
+
476
+ css = """
477
+ .bot .chatbot p {
478
+ overflow: hidden; /* Ensures the content is not revealed until the animation */
479
+ //border-right: .15em solid orange; /* The typewriter cursor */
480
+ white-space: nowrap; /* Keeps the content on a single line */
481
+ margin: 0 auto; /* Gives that scrolling effect as the typing happens */
482
+ letter-spacing: .15em; /* Adjust as needed */
483
+ animation:
484
+ typing 3.5s steps(40, end),
485
+ blink-caret .75s step-end infinite;
486
+ }
487
+
488
+ /* The typing effect */
489
+ @keyframes typing {
490
+ from { width: 0 }
491
+ to { width: 100% }
492
+ }
493
+
494
+ /* The typewriter cursor effect */
495
+ @keyframes blink-caret {
496
+ from, to { border-color: transparent }
497
+ 50% { border-color: orange; }
498
+ }
499
+ """
500
+
501
+ with gr.Blocks(title=title) as demo:
502
+ gr.Markdown(DESCRIPTION)
503
+
504
+ chatbot = gr.Chatbot(
505
+ [],
506
+ elem_id="chatbot",
507
+ avatar_images=("examples/lama.jpeg", "examples/lama2.jpeg"),
508
+ bubble_full_width=False,
509
+ )
510
+
511
+ with gr.Row():
512
+ txt = gr.Textbox(
513
+ scale=3,
514
+ show_label=False,
515
+ placeholder="Enter text and press enter, or speak to your microphone",
516
+ container=False,
517
+ )
518
+ txt_btn = gr.Button(value="Submit text", scale=1)
519
+ btn = gr.Audio(source="microphone", type="filepath", scale=4)
520
+
521
+ with gr.Row():
522
+ audio = gr.Audio(
523
+ label="Generated audio response",
524
+ streaming=False,
525
+ autoplay=False,
526
+ interactive=True,
527
+ show_label=True,
528
+ )
529
+ # TODO add a second audio that plays whole sentences (for mobile especially)
530
+ # final_audio = gr.Audio(label="Final audio response", streaming=False, autoplay=False, interactive=False,show_label=True, visible=False)
531
+
532
+ clear_btn = gr.ClearButton([chatbot, audio])
533
+
534
+ txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
535
+ generate_speech, chatbot, [audio, chatbot]
536
+ )
537
+
538
+ txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)
539
+
540
+ txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
541
+ generate_speech, chatbot, [audio, chatbot]
542
+ )
543
+
544
+ txt_msg.then(lambda: gr.update(interactive=True), None, [txt], queue=False)
545
+
546
+ file_msg = btn.stop_recording(
547
+ add_file, [chatbot, btn], [chatbot, txt], queue=False
548
+ ).then(generate_speech, chatbot, [audio, chatbot])
549
+
550
+ gr.Markdown(
551
+ """
552
+ This Space demonstrates how to speak to a chatbot, based solely on open-source models.
553
+ It relies on 3 models:
554
+ 1. [Whisper-large-v2](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax) as an ASR model, to transcribe recorded audio to text. It is called through a [gradio client](https://www.gradio.app/docs/client).
555
+ 2. [Mistral-7b-instruct](https://huggingface.co/spaces/osanseviero/mistral-super-fast) as the chat model. It is called from [huggingface_hub](https://huggingface.co/docs/huggingface_hub/guides/inference).
556
+ 3. [Coqui's XTTS](https://huggingface.co/spaces/coqui/xtts) as a TTS model, to voice the chatbot answers. This time, the model is hosted locally.
557
+
558
+ Note:
559
+ - By using this demo you agree to the terms of the Coqui Public Model License at https://coqui.ai/cpml"""
560
+ )
561
+ demo.queue()
562
+ demo.launch(debug=True)
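
For reference, the combined wav yielded at the end is just the mono, 16-bit, 24 kHz header produced by wave_header_chunk followed by the raw PCM chunks accumulated in wav_bytestream; the code relies on the player reading to end of file even though the empty header declares a zero-length data chunk. A standalone, assumed example (not part of the commit) showing the same assembly on synthetic data:

    import io
    import wave

    import numpy as np

    SAMPLE_RATE = 24000  # XTTS output rate used throughout app.py

    def make_header() -> bytes:
        # Same idea as wave_header_chunk(): a mono, 16-bit WAV header with no frames yet.
        buf = io.BytesIO()
        with wave.open(buf, "wb") as w:
            w.setnchannels(1)
            w.setsampwidth(2)
            w.setframerate(SAMPLE_RATE)
            w.writeframes(b"")
        buf.seek(0)
        return buf.read()

    # Two synthetic chunks standing in for streamed XTTS output (float audio scaled to int16 PCM).
    t = np.linspace(0, 0.25, SAMPLE_RATE // 4, endpoint=False)
    chunk_a = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16).tobytes()
    chunk_b = (np.sin(2 * np.pi * 660 * t) * 32767).astype(np.int16).tobytes()

    # The header goes in front exactly once; later chunks are appended as raw PCM,
    # which is why the streaming code warns against re-adding the header per chunk.
    combined = make_header() + chunk_a + chunk_b
    with open("combined_preview.wav", "wb") as f:
        f.write(combined)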