ygauravyy committed
Commit cabb3d1 · verified · 1 Parent(s): 39f7a33

Update app.py

Files changed (1):
  1. app.py +64 -83
app.py CHANGED
@@ -10,15 +10,16 @@ from api import BaseSpeakerTTS, ToneColorConverter
 import langid
 import traceback
 from dotenv import load_dotenv
-from fastapi import FastAPI, UploadFile, Form
-from fastapi.responses import JSONResponse
-from gradio.routes import mount_gradio_app
 
 # Load environment variables
 load_dotenv()
 
-# Initialize FastAPI app
-app = FastAPI()
+# Global variables for preloaded resources
+en_base_speaker_tts = None
+zh_base_speaker_tts = None
+tone_color_converter = None
+target_se = None
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 # Function to download and extract checkpoints
 def download_and_extract_checkpoints():
@@ -38,53 +39,52 @@ def download_and_extract_checkpoints():
     os.remove(zip_path)
     print("Checkpoints are ready.")
 
-# Call the function to ensure checkpoints are available
-download_and_extract_checkpoints()
+# Initialize models and resources
+def initialize_resources():
+    global en_base_speaker_tts, zh_base_speaker_tts, tone_color_converter, target_se
+    print("Initializing resources...")
 
-# Initialize OpenAI API key
-openai.api_key = os.getenv("OPENAI_API_KEY")
-if not openai.api_key:
-    raise ValueError("Please set the OPENAI_API_KEY environment variable.")
+    # Download and extract checkpoints
+    download_and_extract_checkpoints()
 
-# Define paths to checkpoints
-en_ckpt_base = 'checkpoints/base_speakers/EN'
-zh_ckpt_base = 'checkpoints/base_speakers/ZH'
-ckpt_converter = 'checkpoints/converter'
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-output_dir = 'outputs'
-os.makedirs(output_dir, exist_ok=True)
-
-# Load TTS models
-en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
-en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
-zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
-zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')
-
-tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
-tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
-
-# Load speaker embeddings
-en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth').to(device)
-en_source_style_se = torch.load(f'{en_ckpt_base}/en_style_se.pth').to(device)
-zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth').to(device)
-
-# Extract speaker embedding from the default Mickey Mouse audio
-default_speaker_audio = "resources/output.wav"
-try:
-    target_se, _ = se_extractor.get_se(
-        default_speaker_audio,
-        tone_color_converter,
-        target_dir='processed',
-        vad=True
-    )
-    print("Speaker embedding extracted successfully.")
-except Exception as e:
-    raise RuntimeError(f"Failed to extract speaker embedding from {default_speaker_audio}: {str(e)}")
+    # Define paths to checkpoints
+    en_ckpt_base = 'checkpoints/base_speakers/EN'
+    zh_ckpt_base = 'checkpoints/base_speakers/ZH'
+    ckpt_converter = 'checkpoints/converter'
+
+    # Load TTS models
+    en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
+    en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
+    zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
+    zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')
+
+    # Load tone color converter
+    tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
+    tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
+
+    # Load speaker embeddings
+    en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth').to(device)
+    zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth').to(device)
+
+    # Extract speaker embedding from the default Mickey Mouse audio
+    default_speaker_audio = "resources/output.wav"
+    try:
+        target_se, _ = se_extractor.get_se(
+            default_speaker_audio,
+            tone_color_converter,
+            target_dir='processed',
+            vad=True
+        )
+        print("Speaker embedding extracted successfully.")
+    except Exception as e:
+        raise RuntimeError(f"Failed to extract speaker embedding from {default_speaker_audio}: {str(e)}")
+
+initialize_resources()
 
 # Supported languages
 supported_languages = ['zh', 'en']
 
-# Predict function (shared between FastAPI and Gradio)
+# Predict function
 def predict(audio_file_pth, agree):
     text_hint = ''
     synthesized_audio_path = None
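
Note on the hunk above: the commit removes the explicit OPENAI_API_KEY assignment and its fail-fast check without replacing them. With openai>=1.0, the module-level client reads OPENAI_API_KEY from the environment that load_dotenv() populates, so the chat call still works. If a fail-fast guard is still wanted, a minimal sketch (hypothetical, not part of this commit) would be:

    import os
    from dotenv import load_dotenv

    load_dotenv()  # populate os.environ from a local .env file
    if not os.getenv("OPENAI_API_KEY"):
        raise ValueError("Please set the OPENAI_API_KEY environment variable.")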
@@ -124,25 +124,19 @@ def predict(audio_file_pth, agree):
     print(f"Detected language: {language_predicted}")
 
     if language_predicted not in supported_languages:
-        text_hint += f"[ERROR] The detected language '{language_predicted}' is not supported.\n"
+        text_hint += f"[ERROR] Unsupported language: {language_predicted}\n"
         return (text_hint, None)
 
-    # Select TTS model based on language
-    if language_predicted == "zh":
-        tts_model = zh_base_speaker_tts
-        language = 'Chinese'
-        speaker_style = 'default'
-    else:
-        tts_model = en_base_speaker_tts
-        language = 'English'
-        speaker_style = 'default'
+    # Select TTS model
+    tts_model = zh_base_speaker_tts if language_predicted == "zh" else en_base_speaker_tts
+    language = 'Chinese' if language_predicted == "zh" else 'English'
 
     # Generate response using OpenAI GPT-4
     try:
         response = openai.chat.completions.create(
             model="gpt-4o-mini",
             messages=[
-                {"role": "system", "content": "You are Mickey Mouse, a friendly character."},
+                {"role": "system", "content": "You are Mickey Mouse, a cheerful character who responds to children's queries."},
                 {"role": "user", "content": input_text}
             ]
         )
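
The context lines stop at the chat.completions call, so the diff does not show how the assistant's reply is read back into reply_text. With the openai>=1.0 client used here, the standard access path is:

    # Standard accessor for an openai>=1.0 chat completion; shown for
    # orientation only, since the actual line sits outside the hunk context.
    reply_text = response.choices[0].message.content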
@@ -155,44 +149,25 @@
     # Synthesize reply text to audio
     try:
         src_path = os.path.join(output_dir, 'tmp_reply.wav')
-        tts_model.tts(reply_text, src_path, speaker=speaker_style, language=language)
+        tts_model.tts(reply_text, src_path, speaker='default', language=language)
 
         save_path = os.path.join(output_dir, 'output_reply.wav')
         tone_color_converter.convert(
-            audio_src_path=src_path,
-            src_se=en_source_default_se if language == 'English' else zh_source_se,
+            audio_src_path=src_path,
+            src_se=target_se,
             tgt_se=target_se,
             output_path=save_path
         )
 
-        text_hint += "Response generated successfully."
+        text_hint += "Response generated successfully.\n"
         synthesized_audio_path = save_path
-
     except Exception as e:
         text_hint += f"[ERROR] Synthesis failed: {str(e)}\n"
+        traceback.print_exc()
         return (text_hint, None)
 
     return (text_hint, synthesized_audio_path)
 
-
-# FastAPI endpoint for prediction
-@app.post("/predict")
-async def predict_endpoint(file: UploadFile, agree: bool = Form(...)):
-    # Save uploaded file
-    temp_file_path = f"temp_{file.filename}"
-    with open(temp_file_path, "wb") as temp_file:
-        temp_file.write(await file.read())
-
-    # Call predict
-    info, audio_path = predict(temp_file_path, agree)
-    os.remove(temp_file_path)
-
-    if audio_path:
-        return JSONResponse({"info": info, "audio": audio_path})
-    else:
-        return JSONResponse({"info": info}, status_code=400)
-
-
 # Gradio UI
 with gr.Blocks(analytics_enabled=False) as demo:
     gr.Markdown("# Mickey Mouse Voice Assistant")
@@ -209,5 +184,11 @@ with gr.Blocks(analytics_enabled=False) as demo:
 
     submit_button.click(predict, inputs=[audio_input, tos_checkbox], outputs=[info_output, audio_output])
 
-# Mount Gradio app to FastAPI
-mount_gradio_app(app, demo, path="/")
+demo.queue()
+demo.launch(
+    server_name="0.0.0.0",
+    server_port=int(os.environ.get("PORT", 7860)),
+    debug=True,
+    show_api=True,
+    share=False
+)
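
With the FastAPI /predict route gone, programmatic access now goes through the Gradio API that show_api=True exposes. A minimal client sketch, assuming a local server on port 7860, a gradio_client version that provides handle_file (older versions take a plain file path), and that Gradio auto-names the endpoint /predict after the handler; the argument order mirrors the submit_button.click inputs:

    from gradio_client import Client, handle_file

    client = Client("http://localhost:7860/")
    info, audio = client.predict(
        handle_file("question.wav"),  # audio_input: recorded question (hypothetical file)
        True,                         # tos_checkbox: terms agreement
        api_name="/predict",          # assumed auto-generated endpoint name
    )
    print(info, audio)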