ygauravyy commited on
Commit
17043fb
·
verified ·
1 Parent(s): cabb3d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -66
app.py CHANGED
@@ -14,13 +14,6 @@ from dotenv import load_dotenv
14
  # Load environment variables
15
  load_dotenv()
16
 
17
- # Global variables for preloaded resources
18
- en_base_speaker_tts = None
19
- zh_base_speaker_tts = None
20
- tone_color_converter = None
21
- target_se = None
22
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
23
-
24
  # Function to download and extract checkpoints
25
  def download_and_extract_checkpoints():
26
  zip_url = "https://huggingface.co/camenduru/OpenVoice/resolve/main/checkpoints_1226.zip"
@@ -39,52 +32,56 @@ def download_and_extract_checkpoints():
39
  os.remove(zip_path)
40
  print("Checkpoints are ready.")
41
 
42
- # Initialize models and resources
43
- def initialize_resources():
44
- global en_base_speaker_tts, zh_base_speaker_tts, tone_color_converter, target_se
45
- print("Initializing resources...")
46
-
47
- # Download and extract checkpoints
48
- download_and_extract_checkpoints()
49
-
50
- # Define paths to checkpoints
51
- en_ckpt_base = 'checkpoints/base_speakers/EN'
52
- zh_ckpt_base = 'checkpoints/base_speakers/ZH'
53
- ckpt_converter = 'checkpoints/converter'
54
-
55
- # Load TTS models
56
- en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
57
- en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
58
- zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
59
- zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')
60
 
61
- # Load tone color converter
62
- tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
63
- tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
 
64
 
65
- # Load speaker embeddings
66
- en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth').to(device)
67
- zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth').to(device)
68
 
69
- # Extract speaker embedding from the default Mickey Mouse audio
70
- default_speaker_audio = "resources/output.wav"
71
- try:
72
- target_se, _ = se_extractor.get_se(
73
- default_speaker_audio,
74
- tone_color_converter,
75
- target_dir='processed',
76
- vad=True
77
- )
78
- print("Speaker embedding extracted successfully.")
79
- except Exception as e:
80
- raise RuntimeError(f"Failed to extract speaker embedding from {default_speaker_audio}: {str(e)}")
81
-
82
- initialize_resources()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  # Supported languages
85
  supported_languages = ['zh', 'en']
86
 
87
- # Predict function
88
  def predict(audio_file_pth, agree):
89
  text_hint = ''
90
  synthesized_audio_path = None
@@ -98,7 +95,7 @@ def predict(audio_file_pth, agree):
98
  if audio_file_pth is not None:
99
  speaker_wav = audio_file_pth
100
  else:
101
- text_hint += "[ERROR] Please provide an audio file.\n"
102
  return (text_hint, None)
103
 
104
  # Transcribe audio to text using OpenAI Whisper
@@ -124,66 +121,105 @@ def predict(audio_file_pth, agree):
124
  print(f"Detected language: {language_predicted}")
125
 
126
  if language_predicted not in supported_languages:
127
- text_hint += f"[ERROR] Unsupported language: {language_predicted}\n"
128
  return (text_hint, None)
129
 
130
- # Select TTS model
131
- tts_model = zh_base_speaker_tts if language_predicted == "zh" else en_base_speaker_tts
132
- language = 'Chinese' if language_predicted == "zh" else 'English'
 
 
 
 
 
 
133
 
 
134
  # Generate response using OpenAI GPT-4
135
  try:
136
  response = openai.chat.completions.create(
137
  model="gpt-4o-mini",
138
  messages=[
139
- {"role": "system", "content": "You are Mickey Mouse, a cheerful character who responds to children's queries."},
140
  {"role": "user", "content": input_text}
141
- ]
 
 
 
 
142
  )
143
- reply_text = response['choices'][0]['message']['content'].strip()
 
144
  print(f"GPT-4 Reply: {reply_text}")
145
  except Exception as e:
146
- text_hint += f"[ERROR] GPT-4 response failed: {str(e)}\n"
147
  return (text_hint, None)
148
 
149
  # Synthesize reply text to audio
150
  try:
151
  src_path = os.path.join(output_dir, 'tmp_reply.wav')
152
- tts_model.tts(reply_text, src_path, speaker='default', language=language)
 
 
153
 
154
  save_path = os.path.join(output_dir, 'output_reply.wav')
 
155
  tone_color_converter.convert(
156
- audio_src_path=src_path,
157
- src_se=target_se,
158
  tgt_se=target_se,
159
- output_path=save_path
 
160
  )
 
161
 
162
  text_hint += "Response generated successfully.\n"
163
  synthesized_audio_path = save_path
 
164
  except Exception as e:
165
- text_hint += f"[ERROR] Synthesis failed: {str(e)}\n"
166
  traceback.print_exc()
167
  return (text_hint, None)
168
 
169
  return (text_hint, synthesized_audio_path)
170
 
171
- # Gradio UI
172
  with gr.Blocks(analytics_enabled=False) as demo:
173
  gr.Markdown("# Mickey Mouse Voice Assistant")
174
 
175
  with gr.Row():
176
  with gr.Column():
177
- audio_input = gr.Audio(source="microphone", type="filepath", label="Record Your Voice")
178
- tos_checkbox = gr.Checkbox(label="Agree to Terms & Conditions", value=False)
 
 
 
 
 
 
 
 
 
179
  submit_button = gr.Button("Send")
180
 
181
  with gr.Column():
182
- info_output = gr.Textbox(label="Info", interactive=False, lines=4)
183
- audio_output = gr.Audio(label="Mickey's Response", interactive=False, autoplay=True)
 
 
 
 
 
 
 
 
184
 
185
- submit_button.click(predict, inputs=[audio_input, tos_checkbox], outputs=[info_output, audio_output])
 
 
 
 
186
 
 
187
  demo.queue()
188
  demo.launch(
189
  server_name="0.0.0.0",
 
14
  # Load environment variables
15
  load_dotenv()
16
 
 
 
 
 
 
 
 
17
  # Function to download and extract checkpoints
18
  def download_and_extract_checkpoints():
19
  zip_url = "https://huggingface.co/camenduru/OpenVoice/resolve/main/checkpoints_1226.zip"
 
32
  os.remove(zip_path)
33
  print("Checkpoints are ready.")
34
 
35
# --- One-time application setup (runs at import time) ---

# Make sure the model checkpoints exist locally before anything tries to load them.
download_and_extract_checkpoints()

# Configure the OpenAI client; fail fast if the key is absent so the app
# does not come up in a half-working state.
openai.api_key = os.getenv("OPENAI_API_KEY")
if not openai.api_key:
    raise ValueError("Please set the OPENAI_API_KEY environment variable.")

# Command-line options for the Gradio launcher.
parser = argparse.ArgumentParser()
parser.add_argument("--share", action='store_true', default=False, help="make link public")
args = parser.parse_args()

# Checkpoint locations and runtime configuration.
en_ckpt_base = 'checkpoints/base_speakers/EN'
zh_ckpt_base = 'checkpoints/base_speakers/ZH'
ckpt_converter = 'checkpoints/converter'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
output_dir = 'outputs'
os.makedirs(output_dir, exist_ok=True)

# Base text-to-speech models, one per supported language.
en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')

# Tone-color converter that re-voices the base TTS output.
tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')

# Pre-trained source speaker embeddings for each base voice.
en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth').to(device)
en_source_style_se = torch.load(f'{en_ckpt_base}/en_style_se.pth').to(device)
zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth').to(device)

# Target speaker embedding extracted from the reference "Mickey Mouse" clip;
# this is the voice every reply is converted into.
default_speaker_audio = "resources/output.wav"
try:
    target_se, _ = se_extractor.get_se(
        default_speaker_audio,
        tone_color_converter,
        target_dir='processed',
        vad=True
    )
    print("Speaker embedding extracted successfully.")
except Exception as e:
    raise RuntimeError(f"Failed to extract speaker embedding from {default_speaker_audio}: {str(e)}")

# Languages the assistant can transcribe and synthesize.
supported_languages = ['zh', 'en']
84
 
 
85
  def predict(audio_file_pth, agree):
86
  text_hint = ''
87
  synthesized_audio_path = None
 
95
  if audio_file_pth is not None:
96
  speaker_wav = audio_file_pth
97
  else:
98
+ text_hint += "[ERROR] Please record your voice using the Microphone.\n"
99
  return (text_hint, None)
100
 
101
  # Transcribe audio to text using OpenAI Whisper
 
121
  print(f"Detected language: {language_predicted}")
122
 
123
  if language_predicted not in supported_languages:
124
+ text_hint += f"[ERROR] The detected language '{language_predicted}' is not supported. Supported languages are: {supported_languages}\n"
125
  return (text_hint, None)
126
 
127
+ # Select TTS model based on language
128
+ if language_predicted == "zh":
129
+ tts_model = zh_base_speaker_tts
130
+ language = 'Chinese'
131
+ speaker_style = 'default'
132
+ else:
133
+ tts_model = en_base_speaker_tts
134
+ language = 'English'
135
+ speaker_style = 'default'
136
 
137
+ # Generate response using OpenAI GPT-4
138
  # Generate response using OpenAI GPT-4
139
  try:
140
  response = openai.chat.completions.create(
141
  model="gpt-4o-mini",
142
  messages=[
143
+ {"role": "system", "content": "You are Mickey Mouse, a friendly and cheerful character who responds to children's queries in a simple and engaging manner. Please keep your response up to 200 characters."},
144
  {"role": "user", "content": input_text}
145
+ ],
146
+ max_tokens=200,
147
+ n=1,
148
+ stop=None,
149
+ temperature=0.7,
150
  )
151
+ # Correctly access the response content
152
+ reply_text = response.choices[0].message.content.strip()
153
  print(f"GPT-4 Reply: {reply_text}")
154
  except Exception as e:
155
+ text_hint += f"[ERROR] Failed to get response from OpenAI GPT-4: {str(e)}\n"
156
  return (text_hint, None)
157
 
158
  # Synthesize reply text to audio
159
  try:
160
  src_path = os.path.join(output_dir, 'tmp_reply.wav')
161
+
162
+ tts_model.tts(reply_text, src_path, speaker=speaker_style, language=language)
163
+ print(f"Audio synthesized and saved to {src_path}")
164
 
165
  save_path = os.path.join(output_dir, 'output_reply.wav')
166
+
167
  tone_color_converter.convert(
168
+ audio_src_path=src_path,
169
+ src_se=en_source_default_se if language == 'English' else zh_source_se,
170
  tgt_se=target_se,
171
+ output_path=save_path,
172
+ message="@MickeyMouse"
173
  )
174
+ print(f"Tone color conversion completed and saved to {save_path}")
175
 
176
  text_hint += "Response generated successfully.\n"
177
  synthesized_audio_path = save_path
178
+
179
  except Exception as e:
180
+ text_hint += f"[ERROR] Failed to synthesize audio: {str(e)}\n"
181
  traceback.print_exc()
182
  return (text_hint, None)
183
 
184
  return (text_hint, synthesized_audio_path)
185
 
 
186
  with gr.Blocks(analytics_enabled=False) as demo:
187
  gr.Markdown("# Mickey Mouse Voice Assistant")
188
 
189
  with gr.Row():
190
  with gr.Column():
191
+ audio_input = gr.Audio(
192
+ source="microphone",
193
+ type="filepath",
194
+ label="Record Your Voice",
195
+ info="Click the microphone button to record your voice."
196
+ )
197
+ tos_checkbox = gr.Checkbox(
198
+ label="Agree to Terms & Conditions",
199
+ value=False,
200
+ info="I agree to the terms of service."
201
+ )
202
  submit_button = gr.Button("Send")
203
 
204
  with gr.Column():
205
+ info_output = gr.Textbox(
206
+ label="Info",
207
+ interactive=False,
208
+ lines=4,
209
+ )
210
+ audio_output = gr.Audio(
211
+ label="Mickey's Response",
212
+ interactive=False,
213
+ autoplay=True,
214
+ )
215
 
216
+ submit_button.click(
217
+ predict,
218
+ inputs=[audio_input, tos_checkbox],
219
+ outputs=[info_output, audio_output]
220
+ )
221
 
222
+ # Launch the Gradio app
223
  demo.queue()
224
  demo.launch(
225
  server_name="0.0.0.0",