Update app.py

app.py
CHANGED
@@ -14,13 +14,6 @@ from dotenv import load_dotenv
 # Load environment variables
 load_dotenv()

-# Global variables for preloaded resources
-en_base_speaker_tts = None
-zh_base_speaker_tts = None
-tone_color_converter = None
-target_se = None
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
 # Function to download and extract checkpoints
 def download_and_extract_checkpoints():
     zip_url = "https://huggingface.co/camenduru/OpenVoice/resolve/main/checkpoints_1226.zip"
@@ -39,52 +32,56 @@ def download_and_extract_checkpoints():
     os.remove(zip_path)
     print("Checkpoints are ready.")

-#
-
-    global en_base_speaker_tts, zh_base_speaker_tts, tone_color_converter, target_se
-    print("Initializing resources...")
-
-    # Download and extract checkpoints
-    download_and_extract_checkpoints()
-
-    # Define paths to checkpoints
-    en_ckpt_base = 'checkpoints/base_speakers/EN'
-    zh_ckpt_base = 'checkpoints/base_speakers/ZH'
-    ckpt_converter = 'checkpoints/converter'
-
-    # Load TTS models
-    en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
-    en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
-    zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
-    zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')
+# Call the function to ensure checkpoints are available
+download_and_extract_checkpoints()

-[3 removed lines not recoverable from the page render]
+# Initialize OpenAI API key
+openai.api_key = os.getenv("OPENAI_API_KEY")
+if not openai.api_key:
+    raise ValueError("Please set the OPENAI_API_KEY environment variable.")

-[3 removed lines not recoverable from the page render]
+parser = argparse.ArgumentParser()
+parser.add_argument("--share", action='store_true', default=False, help="make link public")
+args = parser.parse_args()

-[14 removed lines not recoverable from the page render]
+# Define paths to checkpoints
+en_ckpt_base = 'checkpoints/base_speakers/EN'
+zh_ckpt_base = 'checkpoints/base_speakers/ZH'
+ckpt_converter = 'checkpoints/converter'
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+output_dir = 'outputs'
+os.makedirs(output_dir, exist_ok=True)
+
+# Load TTS models
+en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
+en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
+zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
+zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')
+
+tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
+tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
+
+# Load speaker embeddings
+en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth').to(device)
+en_source_style_se = torch.load(f'{en_ckpt_base}/en_style_se.pth').to(device)
+zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth').to(device)
+
+# Extract speaker embedding from the default Mickey Mouse audio
+default_speaker_audio = "resources/output.wav"
+try:
+    target_se, _ = se_extractor.get_se(
+        default_speaker_audio,
+        tone_color_converter,
+        target_dir='processed',
+        vad=True
+    )
+    print("Speaker embedding extracted successfully.")
+except Exception as e:
+    raise RuntimeError(f"Failed to extract speaker embedding from {default_speaker_audio}: {str(e)}")

 # Supported languages
 supported_languages = ['zh', 'en']

-# Predict function
 def predict(audio_file_pth, agree):
     text_hint = ''
     synthesized_audio_path = None
@@ -98,7 +95,7 @@ def predict(audio_file_pth, agree):
     if audio_file_pth is not None:
         speaker_wav = audio_file_pth
     else:
-        text_hint += "[ERROR] Please
+        text_hint += "[ERROR] Please record your voice using the Microphone.\n"
         return (text_hint, None)

     # Transcribe audio to text using OpenAI Whisper
@@ -124,66 +121,105 @@ def predict(audio_file_pth, agree):
     print(f"Detected language: {language_predicted}")

     if language_predicted not in supported_languages:
-        text_hint += f"[ERROR]
+        text_hint += f"[ERROR] The detected language '{language_predicted}' is not supported. Supported languages are: {supported_languages}\n"
         return (text_hint, None)

-    # Select TTS model
-    [2 removed lines not recoverable from the page render]
+    # Select TTS model based on language
+    if language_predicted == "zh":
+        tts_model = zh_base_speaker_tts
+        language = 'Chinese'
+        speaker_style = 'default'
+    else:
+        tts_model = en_base_speaker_tts
+        language = 'English'
+        speaker_style = 'default'

+    # Generate response using OpenAI GPT-4
     # Generate response using OpenAI GPT-4
     try:
         response = openai.chat.completions.create(
             model="gpt-4o-mini",
             messages=[
-                {"role": "system", "content": "You are Mickey Mouse, a cheerful character who responds to children's queries."},
+                {"role": "system", "content": "You are Mickey Mouse, a friendly and cheerful character who responds to children's queries in a simple and engaging manner. Please keep your response up to 200 characters."},
                 {"role": "user", "content": input_text}
-            ]
+            ],
+            max_tokens=200,
+            n=1,
+            stop=None,
+            temperature=0.7,
         )
-        [1 removed line not recoverable from the page render]
+        # Correctly access the response content
+        reply_text = response.choices[0].message.content.strip()
         print(f"GPT-4 Reply: {reply_text}")
     except Exception as e:
-        text_hint += f"[ERROR] GPT-4
+        text_hint += f"[ERROR] Failed to get response from OpenAI GPT-4: {str(e)}\n"
         return (text_hint, None)

     # Synthesize reply text to audio
     try:
         src_path = os.path.join(output_dir, 'tmp_reply.wav')
-        [1 removed line not recoverable from the page render]
+
+        tts_model.tts(reply_text, src_path, speaker=speaker_style, language=language)
+        print(f"Audio synthesized and saved to {src_path}")

         save_path = os.path.join(output_dir, 'output_reply.wav')
+
         tone_color_converter.convert(
-            audio_src_path=src_path,
-            src_se=
+            audio_src_path=src_path,
+            src_se=en_source_default_se if language == 'English' else zh_source_se,
             tgt_se=target_se,
-            output_path=save_path
+            output_path=save_path,
+            message="@MickeyMouse"
         )
+        print(f"Tone color conversion completed and saved to {save_path}")

         text_hint += "Response generated successfully.\n"
         synthesized_audio_path = save_path
+
     except Exception as e:
-        text_hint += f"[ERROR]
+        text_hint += f"[ERROR] Failed to synthesize audio: {str(e)}\n"
         traceback.print_exc()
         return (text_hint, None)

     return (text_hint, synthesized_audio_path)

-# Gradio UI
 with gr.Blocks(analytics_enabled=False) as demo:
     gr.Markdown("# Mickey Mouse Voice Assistant")

     with gr.Row():
         with gr.Column():
-            audio_input = gr.Audio(
-            [1 removed line not recoverable from the page render]
+            audio_input = gr.Audio(
+                source="microphone",
+                type="filepath",
+                label="Record Your Voice",
+                info="Click the microphone button to record your voice."
+            )
+            tos_checkbox = gr.Checkbox(
+                label="Agree to Terms & Conditions",
+                value=False,
+                info="I agree to the terms of service."
+            )
             submit_button = gr.Button("Send")

         with gr.Column():
-            info_output = gr.Textbox(
-            [1 removed line not recoverable from the page render]
+            info_output = gr.Textbox(
+                label="Info",
+                interactive=False,
+                lines=4,
+            )
+            audio_output = gr.Audio(
+                label="Mickey's Response",
+                interactive=False,
+                autoplay=True,
+            )

-    submit_button.click(
+    submit_button.click(
+        predict,
+        inputs=[audio_input, tos_checkbox],
+        outputs=[info_output, audio_output]
+    )

+# Launch the Gradio app
 demo.queue()
 demo.launch(
     server_name="0.0.0.0",
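The commit also introduces a `--share` CLI flag (help text: "make link public"). Assuming `demo.launch()` forwards it as Gradio's standard `share=` argument (the launch call is shown only up to `server_name`, so this is not visible in the hunk above), a public link would be requested with:

    python app.py --share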
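Since the commit moves all model loading to module scope, `predict` can be smoke-tested directly. A minimal sketch, illustrative only: the `predict(audio_file_pth, agree)` signature and its `(text_hint, synthesized_audio_path)` return shape come from the diff above, while the recording path is hypothetical; note that `app.py` as committed also calls `demo.launch()` at the end of the file.

    # Illustrative smoke test (not part of the commit); easiest to run by
    # appending it to app.py just above demo.queue(), since the script
    # starts the Gradio server once it reaches demo.launch().
    text_hint, audio_path = predict("sample_recording.wav", True)
    print(text_hint)   # accumulated status / "[ERROR] ..." messages
    print(audio_path)  # outputs/output_reply.wav on success, None on failure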