ygauravyy commited on
Commit
fff6648
1 Parent(s): 27386e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +224 -5
app.py CHANGED
@@ -1,7 +1,226 @@
1
- from fastapi import FastAPI
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- app = FastAPI()
 
4
 
5
- @app.get("/")
6
- def greet_json():
7
- return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import argparse
4
+ import gradio as gr
5
+ import openai
6
+ from zipfile import ZipFile
7
+ import requests
8
+ import se_extractor
9
+ from api import BaseSpeakerTTS, ToneColorConverter
10
+ import langid
11
+ import traceback
12
+ from dotenv import load_dotenv
13
 
14
+ # Load environment variables
15
+ load_dotenv()
16
 
17
+ # Function to download and extract checkpoints
18
+ def download_and_extract_checkpoints():
19
+ zip_url = "https://huggingface.co/camenduru/OpenVoice/resolve/main/checkpoints_1226.zip"
20
+ zip_path = "checkpoints.zip"
21
+
22
+ if not os.path.exists("checkpoints"):
23
+ print("Downloading checkpoints...")
24
+ response = requests.get(zip_url, stream=True)
25
+ with open(zip_path, "wb") as zip_file:
26
+ for chunk in response.iter_content(chunk_size=8192):
27
+ if chunk:
28
+ zip_file.write(chunk)
29
+ print("Extracting checkpoints...")
30
+ with ZipFile(zip_path, "r") as zip_ref:
31
+ zip_ref.extractall(".")
32
+ os.remove(zip_path)
33
+ print("Checkpoints are ready.")
34
+
35
+ # Call the function to ensure checkpoints are available
36
+ download_and_extract_checkpoints()
37
+
38
+ # Initialize OpenAI API key
39
+ openai.api_key = os.getenv("OPENAI_API_KEY")
40
+ if not openai.api_key:
41
+ raise ValueError("Please set the OPENAI_API_KEY environment variable.")
42
+
43
+ parser = argparse.ArgumentParser()
44
+ parser.add_argument("--share", action='store_true', default=False, help="make link public")
45
+ args = parser.parse_args()
46
+
47
+ # Define paths to checkpoints
48
+ en_ckpt_base = 'checkpoints/base_speakers/EN'
49
+ zh_ckpt_base = 'checkpoints/base_speakers/ZH'
50
+ ckpt_converter = 'checkpoints/converter'
51
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
52
+ output_dir = 'outputs'
53
+ os.makedirs(output_dir, exist_ok=True)
54
+
55
+ # Load TTS models
56
+ en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
57
+ en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
58
+ zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
59
+ zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')
60
+
61
+ tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
62
+ tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
63
+
64
+ # Load speaker embeddings
65
+ en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth').to(device)
66
+ en_source_style_se = torch.load(f'{en_ckpt_base}/en_style_se.pth').to(device)
67
+ zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth').to(device)
68
+
69
+ # Extract speaker embedding from the default Mickey Mouse audio
70
+ default_speaker_audio = "resources/output.wav"
71
+ try:
72
+ target_se, _ = se_extractor.get_se(
73
+ default_speaker_audio,
74
+ tone_color_converter,
75
+ target_dir='processed',
76
+ vad=True
77
+ )
78
+ print("Speaker embedding extracted successfully.")
79
+ except Exception as e:
80
+ raise RuntimeError(f"Failed to extract speaker embedding from {default_speaker_audio}: {str(e)}")
81
+
82
+ # Supported languages
83
+ supported_languages = ['zh', 'en']
84
+
85
+ def predict(audio_file_pth, agree):
86
+ text_hint = ''
87
+ synthesized_audio_path = None
88
+
89
+ # Agree with the terms
90
+ if not agree:
91
+ text_hint += '[ERROR] Please accept the Terms & Conditions!\n'
92
+ return (text_hint, None)
93
+
94
+ # Check if audio file is provided
95
+ if audio_file_pth is not None:
96
+ speaker_wav = audio_file_pth
97
+ else:
98
+ text_hint += "[ERROR] Please record your voice using the Microphone.\n"
99
+ return (text_hint, None)
100
+
101
+ # Transcribe audio to text using OpenAI Whisper
102
+ try:
103
+ with open(speaker_wav, 'rb') as audio_file:
104
+ transcription_response = openai.Audio.transcribe(
105
+ model="whisper-1",
106
+ file=audio_file,
107
+ response_format='text'
108
+ )
109
+ input_text = transcription_response.strip()
110
+ print(f"Transcribed Text: {input_text}")
111
+ except Exception as e:
112
+ text_hint += f"[ERROR] Transcription failed: {str(e)}\n"
113
+ return (text_hint, None)
114
+
115
+ if len(input_text) == 0:
116
+ text_hint += "[ERROR] No speech detected in the audio.\n"
117
+ return (text_hint, None)
118
+
119
+ # Detect language
120
+ language_predicted = langid.classify(input_text)[0].strip()
121
+ print(f"Detected language: {language_predicted}")
122
+
123
+ if language_predicted not in supported_languages:
124
+ text_hint += f"[ERROR] The detected language '{language_predicted}' is not supported. Supported languages are: {supported_languages}\n"
125
+ return (text_hint, None)
126
+
127
+ # Select TTS model based on language
128
+ if language_predicted == "zh":
129
+ tts_model = zh_base_speaker_tts
130
+ language = 'Chinese'
131
+ speaker_style = 'default'
132
+ else:
133
+ tts_model = en_base_speaker_tts
134
+ language = 'English'
135
+ speaker_style = 'default'
136
+
137
+ # Generate response using OpenAI GPT-4
138
+ try:
139
+ response = openai.ChatCompletion.create(
140
+ model="gpt-4o-mini",
141
+ messages=[
142
+ {"role": "system", "content": "You are Mickey Mouse, a friendly and cheerful character who responds to children's queries in a simple and engaging manner. Please keep your response up to 200 characters."},
143
+ {"role": "user", "content": input_text}
144
+ ],
145
+ max_tokens=200,
146
+ temperature=0.7,
147
+ )
148
+ reply_text = response['choices'][0]['message']['content'].strip()
149
+ print(f"GPT-4 Reply: {reply_text}")
150
+ except Exception as e:
151
+ text_hint += f"[ERROR] Failed to get response from OpenAI GPT-4: {str(e)}\n"
152
+ return (text_hint, None)
153
+
154
+ # Synthesize reply text to audio
155
+ try:
156
+ src_path = os.path.join(output_dir, 'tmp_reply.wav')
157
+
158
+ tts_model.tts(reply_text, src_path, speaker=speaker_style, language=language)
159
+ print(f"Audio synthesized and saved to {src_path}")
160
+
161
+ save_path = os.path.join(output_dir, 'output_reply.wav')
162
+
163
+ tone_color_converter.convert(
164
+ audio_src_path=src_path,
165
+ src_se=en_source_default_se if language == 'English' else zh_source_se,
166
+ tgt_se=target_se,
167
+ output_path=save_path,
168
+ message="@MickeyMouse"
169
+ )
170
+ print(f"Tone color conversion completed and saved to {save_path}")
171
+
172
+ text_hint += "Response generated successfully.\n"
173
+ synthesized_audio_path = save_path
174
+
175
+ except Exception as e:
176
+ text_hint += f"[ERROR] Failed to synthesize audio: {str(e)}\n"
177
+ traceback.print_exc()
178
+ return (text_hint, None)
179
+
180
+ return (text_hint, synthesized_audio_path)
181
+
182
+ with gr.Blocks(analytics_enabled=False) as demo:
183
+ gr.Markdown("# Mickey Mouse Voice Assistant")
184
+
185
+ with gr.Row():
186
+ with gr.Column():
187
+ audio_input = gr.Audio(
188
+ source="microphone",
189
+ type="filepath",
190
+ label="Record Your Voice",
191
+ info="Click the microphone button to record your voice."
192
+ )
193
+ tos_checkbox = gr.Checkbox(
194
+ label="Agree to Terms & Conditions",
195
+ value=False,
196
+ info="I agree to the terms of service."
197
+ )
198
+ submit_button = gr.Button("Send")
199
+
200
+ with gr.Column():
201
+ info_output = gr.Textbox(
202
+ label="Info",
203
+ interactive=False,
204
+ lines=4,
205
+ )
206
+ audio_output = gr.Audio(
207
+ label="Mickey's Response",
208
+ interactive=False,
209
+ autoplay=True,
210
+ )
211
+
212
+ submit_button.click(
213
+ predict,
214
+ inputs=[audio_input, tos_checkbox],
215
+ outputs=[info_output, audio_output]
216
+ )
217
+
218
+ # Launch the Gradio app
219
+ demo.queue()
220
+ demo.launch(
221
+ server_name="0.0.0.0",
222
+ server_port=int(os.environ.get("PORT", 7860)),
223
+ debug=True,
224
+ show_api=True,
225
+ share=False
226
+ )