sagar007 commited on
Commit
7137466
·
verified ·
1 Parent(s): 6b1a045

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -24
app.py CHANGED
@@ -79,6 +79,8 @@ async def generate_speech(text, tts_model, tts_tokenizer):
79
 
80
  return audio_generation.cpu().numpy().squeeze()
81
 
 
 
82
  @spaces.GPU(timeout=300)
83
  def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20, use_tts=True):
84
  try:
@@ -102,7 +104,8 @@ def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_t
102
  top_p=top_p,
103
  top_k=top_k,
104
  temperature=temperature,
105
- eos_token_id=[128001, 128008, 128009],
 
106
  streamer=streamer,
107
  )
108
 
@@ -110,35 +113,41 @@ def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_t
110
  thread.start()
111
 
112
  buffer = ""
113
- audio_buffer = np.array([0.0]) # Initialize with a single zero
114
-
115
  for new_text in streamer:
116
  buffer += new_text
117
- yield history + [[message, buffer]], (tts_model.config.sampling_rate, audio_buffer)
118
-
119
- # Generate speech after text generation is complete
120
- if use_tts and buffer: # Only generate speech if there's text
121
- audio_buffer = generate_speech_sync(buffer, tts_model, tts_tokenizer)
122
- if audio_buffer.size == 0: # If audio_buffer is empty
123
- audio_buffer = np.array([0.0]) # Use a single zero instead
124
-
125
- # Final yield with complete text and audio
126
- yield history + [[message, buffer]], (tts_model.config.sampling_rate, audio_buffer)
127
-
 
 
 
 
128
  except Exception as e:
129
  print(f"An error occurred: {str(e)}")
130
- yield history + [[message, f"An error occurred: {str(e)}"]], (tts_model.config.sampling_rate, np.array([0.0]))
131
 
132
  def generate_speech_sync(text, tts_model, tts_tokenizer):
133
- tts_input_ids = tts_tokenizer(text, return_tensors="pt").input_ids.to(device)
134
- tts_description = "A clear and natural voice reads the text with moderate speed and expression."
135
- tts_description_ids = tts_tokenizer(tts_description, return_tensors="pt").input_ids.to(device)
136
-
137
- with torch.no_grad():
138
- audio_generation = tts_model.generate(input_ids=tts_description_ids, prompt_input_ids=tts_input_ids)
139
-
140
- audio_buffer = audio_generation.cpu().numpy().squeeze()
141
- return audio_buffer if audio_buffer.size > 0 else np.array([0.0])
 
 
 
 
142
 
143
  @spaces.GPU(timeout=300) # Increase timeout to 5 minutes
144
  def process_vision_query(image, text_input):
 
79
 
80
  return audio_generation.cpu().numpy().squeeze()
81
 
82
+ from gradio import Error as GradioError
83
+
84
  @spaces.GPU(timeout=300)
85
  def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20, use_tts=True):
86
  try:
 
104
  top_p=top_p,
105
  top_k=top_k,
106
  temperature=temperature,
107
+ eos_token_id=text_tokenizer.eos_token_id,
108
+ pad_token_id=text_tokenizer.pad_token_id,
109
  streamer=streamer,
110
  )
111
 
 
113
  thread.start()
114
 
115
  buffer = ""
 
 
116
  for new_text in streamer:
117
  buffer += new_text
118
+ yield history + [[message, buffer]], None # Yield None for audio initially
119
+
120
+ # Only attempt TTS if it's enabled and we have a response
121
+ if use_tts and buffer:
122
+ try:
123
+ audio = generate_speech_sync(buffer, tts_model, tts_tokenizer)
124
+ yield history + [[message, buffer]], (tts_model.config.sampling_rate, audio)
125
+ except Exception as e:
126
+ print(f"TTS failed: {str(e)}")
127
+ yield history + [[message, buffer]], None
128
+ else:
129
+ yield history + [[message, buffer]], None
130
+
131
+ except GradioError:
132
+ yield history + [[message, "GPU task aborted. Please try again."]], None
133
  except Exception as e:
134
  print(f"An error occurred: {str(e)}")
135
+ yield history + [[message, f"An error occurred: {str(e)}"]], None
136
 
137
  def generate_speech_sync(text, tts_model, tts_tokenizer):
138
+ try:
139
+ tts_input_ids = tts_tokenizer(text, return_tensors="pt").input_ids.to(device)
140
+ tts_description = "A clear and natural voice reads the text with moderate speed and expression."
141
+ tts_description_ids = tts_tokenizer(tts_description, return_tensors="pt").input_ids.to(device)
142
+
143
+ with torch.no_grad():
144
+ audio_generation = tts_model.generate(input_ids=tts_description_ids, prompt_input_ids=tts_input_ids)
145
+
146
+ audio_buffer = audio_generation.cpu().numpy().squeeze()
147
+ return audio_buffer if audio_buffer.size > 0 else np.array([0.0])
148
+ except Exception as e:
149
+ print(f"Speech generation failed: {str(e)}")
150
+ return np.array([0.0])
151
 
152
  @spaces.GPU(timeout=300) # Increase timeout to 5 minutes
153
  def process_vision_query(image, text_input):