Update app.py
app.py CHANGED
@@ -17,7 +17,6 @@ import uvicorn
 # Load environment variables
 load_dotenv()
 
-# Function to download and extract checkpoints
 def download_and_extract_checkpoints():
     zip_url = "https://huggingface.co/camenduru/OpenVoice/resolve/main/checkpoints_1226.zip"
     zip_path = "checkpoints.zip"
@@ -43,7 +42,6 @@ openai.api_key = os.getenv("OPENAI_API_KEY")
 if not openai.api_key:
     raise ValueError("Please set the OPENAI_API_KEY environment variable.")
 
-# Define paths to checkpoints
 en_ckpt_base = 'checkpoints/base_speakers/EN'
 zh_ckpt_base = 'checkpoints/base_speakers/ZH'
 ckpt_converter = 'checkpoints/converter'
@@ -51,7 +49,6 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 output_dir = 'outputs'
 os.makedirs(output_dir, exist_ok=True)
 
-# Load TTS models
 en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
 en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
 zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
@@ -60,12 +57,10 @@ zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')
 tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
 tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')
 
-# Load speaker embeddings
 en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth').to(device)
 en_source_style_se = torch.load(f'{en_ckpt_base}/en_style_se.pth').to(device)
 zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth').to(device)
 
-# Extract speaker embedding from the default Mickey Mouse audio
 default_speaker_audio = "resources/output.wav"
 try:
     target_se, _ = se_extractor.get_se(
@@ -78,19 +73,16 @@ try:
 except Exception as e:
     raise RuntimeError(f"Failed to extract speaker embedding from {default_speaker_audio}: {str(e)}")
 
-# Supported languages
 supported_languages = ['zh', 'en']
 
 def predict(audio_file_pth, agree):
     text_hint = ''
     synthesized_audio_path = None
 
-    # Agree with the terms
     if not agree:
         text_hint += '[ERROR] Please accept the Terms & Conditions!\n'
         return (text_hint, None)
 
-    # Check if audio file is provided
     if audio_file_pth is not None:
         speaker_wav = audio_file_pth
     else:
@@ -115,7 +107,6 @@ def predict(audio_file_pth, agree):
         text_hint += "[ERROR] No speech detected in the audio.\n"
         return (text_hint, None)
 
-    # Detect language
     language_predicted = langid.classify(input_text)[0].strip()
     print(f"Detected language: {language_predicted}")
 
@@ -123,7 +114,6 @@ def predict(audio_file_pth, agree):
         text_hint += f"[ERROR] The detected language '{language_predicted}' is not supported. Supported languages are: {supported_languages}\n"
         return (text_hint, None)
 
-    # Select TTS model based on language
     if language_predicted == "zh":
         tts_model = zh_base_speaker_tts
         language = 'Chinese'
@@ -133,7 +123,6 @@ def predict(audio_file_pth, agree):
         language = 'English'
         speaker_style = 'default'
 
-    # Generate response using OpenAI GPT-4
     try:
         response = openai.chat.completions.create(
             model="gpt-4o-mini",
@@ -152,10 +141,8 @@ def predict(audio_file_pth, agree):
         text_hint += f"[ERROR] Failed to get response from OpenAI GPT-4: {str(e)}\n"
         return (text_hint, None)
 
-    # Synthesize reply text to audio
     try:
         src_path = os.path.join(output_dir, 'tmp_reply.wav')
-
         tts_model.tts(reply_text, src_path, speaker=speaker_style, language=language)
         print(f"Audio synthesized and saved to {src_path}")
 
@@ -172,7 +159,6 @@ def predict(audio_file_pth, agree):
 
         text_hint += "Response generated successfully.\n"
         synthesized_audio_path = save_path
-
     except Exception as e:
         text_hint += f"[ERROR] Failed to synthesize audio: {str(e)}\n"
         traceback.print_exc()
@@ -184,7 +170,6 @@ app = FastAPI()
 
 @app.post("/predict")
 async def predict_endpoint(agree: bool = Form(...), audio_file: UploadFile = File(...)):
-    # Save the uploaded file locally
    temp_dir = "temp"
    os.makedirs(temp_dir, exist_ok=True)
    audio_path = os.path.join(temp_dir, audio_file.filename)
@@ -193,12 +178,9 @@ async def predict_endpoint(agree: bool = Form(...), audio_file: UploadFile = Fil
 
    info, audio_output_path = predict(audio_path, agree)
    if audio_output_path:
-        # Return a JSON response with info and a path to the audio file.
-        # You could return the file content as base64 if you prefer.
        return JSONResponse(content={"info": info, "audio_path": audio_output_path})
    else:
        return JSONResponse(content={"info": info, "audio_path": None}, status_code=400)
 
-
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860))
+    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
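
For reference, here is a minimal client sketch for the POST /predict endpoint this diff defines. This is an illustration, not part of the commit: the localhost URL, the port 7860 (taken from the PORT fallback in the fixed entrypoint), and the sample.wav filename are all assumptions. The form fields mirror the signature in the diff, agree: bool = Form(...) and audio_file: UploadFile = File(...).

# Hypothetical client for the /predict endpoint above.
# Assumes the app is running locally on the default port 7860
# and that "sample.wav" is a short speech recording on disk.
import requests

url = "http://localhost:7860/predict"

with open("sample.wav", "rb") as f:
    resp = requests.post(
        url,
        data={"agree": "true"},  # parsed into the `agree: bool` form field
        files={"audio_file": ("sample.wav", f, "audio/wav")},
    )

payload = resp.json()
print(payload["info"])        # status/error text accumulated in text_hint
print(payload["audio_path"])  # server-side path to the synthesized reply, or None

Per the endpoint's else branch, a 400 response with "audio_path": null means predict() failed; the "info" field then carries the accumulated [ERROR] text rather than a success message.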