sagar007 committed (verified)
Commit 1b825cc · 1 Parent(s): 78e7cbb

Update app.py

Files changed (1):
  1. app.py +38 -17
app.py CHANGED
@@ -10,6 +10,11 @@ import spaces
 from parler_tts import ParlerTTSForConditionalGeneration
 import soundfile as sf
 import tempfile
+import asyncio
+from concurrent.futures import ThreadPoolExecutor
+
+# Add this global variable after the imports
+executor = ThreadPoolExecutor(max_workers=2)
 
 # Install flash-attention
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
@@ -63,9 +68,20 @@ vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_c
 tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
 tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
 
+# Add the generate_speech function here
+async def generate_speech(text, tts_model, tts_tokenizer):
+    tts_input_ids = tts_tokenizer(text, return_tensors="pt").input_ids.to(device)
+    tts_description = "A clear and natural voice reads the text with moderate speed and expression."
+    tts_description_ids = tts_tokenizer(tts_description, return_tensors="pt").input_ids.to(device)
+
+    with torch.no_grad():
+        audio_generation = tts_model.generate(input_ids=tts_description_ids, prompt_input_ids=tts_input_ids)
+
+    return audio_generation.cpu().numpy().squeeze()
+
 # Helper functions
-@spaces.GPU(timeout=300)  # Increase timeout to 5 minutes
-def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20):
+@spaces.GPU(timeout=300)
+async def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20, use_tts=True):
     try:
         conversation = [{"role": "system", "content": system_prompt}]
         for prompt, answer in history:
@@ -91,27 +107,31 @@ def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20):
             streamer=streamer,
         )
 
-        with torch.no_grad():
-            thread = Thread(target=text_model.generate, kwargs=generate_kwargs)
-            thread.start()
+        thread = Thread(target=text_model.generate, kwargs=generate_kwargs)
+        thread.start()
 
         buffer = ""
         audio_buffer = np.array([])
+        tts_future = None
+
         for new_text in streamer:
             buffer += new_text
 
-            # Generate speech for the new text
-            tts_input_ids = tts_tokenizer(new_text, return_tensors="pt").input_ids.to(device)
-            tts_description = "A clear and natural voice reads the text with moderate speed and expression."
-            tts_description_ids = tts_tokenizer(tts_description, return_tensors="pt").input_ids.to(device)
-
-            with torch.no_grad():
-                audio_generation = tts_model.generate(input_ids=tts_description_ids, prompt_input_ids=tts_input_ids)
-
-            new_audio = audio_generation.cpu().numpy().squeeze()
-            audio_buffer = np.concatenate((audio_buffer, new_audio))
+            if use_tts and len(buffer) > 50:  # Start TTS generation when buffer has enough content
+                if tts_future is None:
+                    tts_future = asyncio.get_event_loop().run_in_executor(
+                        executor, generate_speech, buffer, tts_model, tts_tokenizer
+                    )
 
             yield history + [[message, buffer]], (tts_model.config.sampling_rate, audio_buffer)
+
+        # Wait for TTS to complete if it's still running
+        if use_tts and tts_future is not None:
+            audio_buffer = await tts_future
+
+        # Final yield with complete text and audio
+        yield history + [[message, buffer]], (tts_model.config.sampling_rate, audio_buffer)
+
     except Exception as e:
         print(f"An error occurred: {str(e)}")
         yield history + [[message, f"An error occurred: {str(e)}"]], None
@@ -190,7 +210,7 @@ custom_suggestions = """
 </div>
 """
 
-# Gradio interface
+# Update the Gradio interface
 with gr.Blocks(css=custom_css, theme=gr.themes.Base().set(
     body_background_fill="#0b0f19",
     body_text_color="#e2e8f0",
@@ -213,11 +233,12 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Base().set(
             max_new_tokens = gr.Slider(minimum=128, maximum=8192, step=1, value=1024, label="Max new tokens")
             top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="top_p")
             top_k = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_k")
+            use_tts = gr.Checkbox(label="Enable Text-to-Speech", value=True)
 
         submit_btn = gr.Button("Submit", variant="primary")
         clear_btn = gr.Button("Clear Chat", variant="secondary")
 
-        submit_btn.click(stream_text_chat, [msg, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k], [chatbot, audio_output])
+        submit_btn.click(stream_text_chat, [msg, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k, use_tts], [chatbot, audio_output])
         clear_btn.click(lambda: None, None, chatbot, queue=False)
 
     with gr.Tab("Vision Model (Phi-3.5-vision)"):