sagar007 committed
Commit 1f0b302 · verified · 1 parent: 1b825cc

Update app.py

Files changed (1): app.py (+24, -18)
app.py CHANGED
@@ -81,7 +81,7 @@ async def generate_speech(text, tts_model, tts_tokenizer):
 
 # Helper functions
 @spaces.GPU(timeout=300)
-async def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20, use_tts=True):
+def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20, use_tts=True):
     try:
         conversation = [{"role": "system", "content": system_prompt}]
         for prompt, answer in history:
@@ -112,22 +112,14 @@ async def stream_text_chat(message, history, system_prompt, temperature=0.8, max
 
         buffer = ""
         audio_buffer = np.array([])
-        tts_future = None
 
         for new_text in streamer:
             buffer += new_text
-
-            if use_tts and len(buffer) > 50:  # Start TTS generation when buffer has enough content
-                if tts_future is None:
-                    tts_future = asyncio.get_event_loop().run_in_executor(
-                        executor, generate_speech, buffer, tts_model, tts_tokenizer
-                    )
-
             yield history + [[message, buffer]], (tts_model.config.sampling_rate, audio_buffer)
 
-        # Wait for TTS to complete if it's still running
-        if use_tts and tts_future is not None:
-            audio_buffer = await tts_future
+        # Generate speech after text generation is complete
+        if use_tts:
+            audio_buffer = generate_speech_sync(buffer, tts_model, tts_tokenizer)
 
         # Final yield with complete text and audio
         yield history + [[message, buffer]], (tts_model.config.sampling_rate, audio_buffer)
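Note on the streaming loop above: until the final yield, the audio slot always receives (tts_model.config.sampling_rate, np.array([])), i.e. a zero-length clip. A hedged alternative, not taken from this commit, is to yield None for the audio output while text is still streaming, which leaves the gr.Audio player empty instead of handing it an empty array:

        for new_text in streamer:
            buffer += new_text
            # Intermediate updates carry text only; the audio player is filled on the final yield.
            yield history + [[message, buffer]], None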
@@ -136,6 +128,16 @@ async def stream_text_chat(message, history, system_prompt, temperature=0.8, max
         print(f"An error occurred: {str(e)}")
         yield history + [[message, f"An error occurred: {str(e)}"]], None
 
+def generate_speech_sync(text, tts_model, tts_tokenizer):
+    tts_input_ids = tts_tokenizer(text, return_tensors="pt").input_ids.to(device)
+    tts_description = "A clear and natural voice reads the text with moderate speed and expression."
+    tts_description_ids = tts_tokenizer(tts_description, return_tensors="pt").input_ids.to(device)
+
+    with torch.no_grad():
+        audio_generation = tts_model.generate(input_ids=tts_description_ids, prompt_input_ids=tts_input_ids)
+
+    return audio_generation.cpu().numpy().squeeze()
+
 @spaces.GPU(timeout=300)  # Increase timeout to 5 minutes
 def process_vision_query(image, text_input):
     try:
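The new generate_speech_sync helper is self-contained enough to smoke-test outside the Gradio app. A minimal sketch, assuming the TTS model is a Parler-TTS checkpoint (the description-plus-prompt_input_ids generate() call matches that library's API) and that generate_speech_sync plus the module-level device it reads are in scope; the checkpoint name and output path are placeholders, not taken from this repo:

import torch
import soundfile as sf
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = "parler-tts/parler-tts-mini-v1"  # placeholder; use whatever checkpoint the Space actually loads
tts_model = ParlerTTSForConditionalGeneration.from_pretrained(checkpoint).to(device)
tts_tokenizer = AutoTokenizer.from_pretrained(checkpoint)

audio = generate_speech_sync("Hello from the Phi 3.5 demo.", tts_model, tts_tokenizer)
sf.write("sample.wav", audio, tts_model.config.sampling_rate)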
@@ -210,7 +212,7 @@ custom_suggestions = """
 </div>
 """
 
-# Update the Gradio interface
+# Gradio interface
 with gr.Blocks(css=custom_css, theme=gr.themes.Base().set(
     body_background_fill="#0b0f19",
     body_text_color="#e2e8f0",
@@ -221,7 +223,6 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Base().set(
     block_label_text_color="#94a3b8",
 )) as demo:
     gr.HTML(custom_header)
-    gr.HTML(custom_suggestions)
 
     with gr.Tab("Text Model (Phi-3.5-mini)"):
         chatbot = gr.Chatbot(height=400)
@@ -238,8 +239,13 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Base().set(
         submit_btn = gr.Button("Submit", variant="primary")
         clear_btn = gr.Button("Clear Chat", variant="secondary")
 
-        submit_btn.click(stream_text_chat, [msg, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k, use_tts], [chatbot, audio_output])
-        clear_btn.click(lambda: None, None, chatbot, queue=False)
+        def clear_chat():
+            return None
+
+        submit_btn.click(stream_text_chat,
+                         inputs=[msg, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k, use_tts],
+                         outputs=[chatbot, audio_output])
+        clear_btn.click(clear_chat, outputs=chatbot)
 
     with gr.Tab("Vision Model (Phi-3.5-vision)"):
         with gr.Row():
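A small follow-up this hunk leaves open: clear_chat resets only the chatbot, so the last synthesized clip stays in the audio player. If clearing both is wanted, a hedged variant reusing the existing audio_output component would be:

        def clear_chat():
            # One return value per output component: chat history and audio player.
            return None, None

        clear_btn.click(clear_chat, outputs=[chatbot, audio_output])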
@@ -250,9 +256,9 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Base().set(
             with gr.Column(scale=1):
                 vision_output_text = gr.Textbox(label="AI Analysis", lines=10)
 
-        vision_submit_btn.click(process_vision_query, [vision_input_img, vision_text_input], [vision_output_text])
+        vision_submit_btn.click(process_vision_query, inputs=[vision_input_img, vision_text_input], outputs=vision_output_text)
 
     gr.HTML("<footer>Powered by Phi 3.5 Multimodal AI</footer>")
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(share=True)
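Because stream_text_chat is a generator, its partial results are delivered through Gradio's queue. Recent Gradio versions enable the queue by default; on versions where it is opt-in, a sketch of the launch line (not part of this commit) would be:

if __name__ == "__main__":
    demo.queue().launch(share=True)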
 