srttosubedge / app.py
leetuan023's picture
Create app.py
52dff6c verified
raw
history blame
7.96 kB
#!/usr/bin/env python3
import os
import shutil
import subprocess
import tempfile
import asyncio
import edge_tts
import pysrt
import logging
import random
import gradio as gr
# Logging setup
# Every record carries timestamp, source file, function, line number and level.
FORMAT = (
    "[%(asctime)s %(filename)s->%(funcName)s():%(lineno)s]"
    "%(levelname)s: %(message)s"
)
logging.basicConfig(format=FORMAT)
logger = logging.getLogger(__name__)
# Function for dependency check (ffmpeg, ffprobe)
def dep_check():
    """Ensure the external ffmpeg tools this script shells out to are on PATH.

    Raises:
        RuntimeError: If ``ffmpeg`` or ``ffprobe`` cannot be located; ffmpeg
            is checked first.
    """
    required_tools = (
        ("ffmpeg", "ffmpeg is not installed"),
        ("ffprobe", "ffprobe (part of ffmpeg) is not installed"),
    )
    for executable, complaint in required_tools:
        if shutil.which(executable):
            continue
        raise RuntimeError(complaint)
# Function to convert SRT time to seconds
def pysrttime_to_seconds(t):
    """Convert a subtitle timestamp to a number of seconds.

    Args:
        t: Object exposing ``hours``, ``minutes``, ``seconds`` and
            ``milliseconds`` attributes.

    Returns:
        The timestamp expressed in seconds, including the millisecond
        fraction.
    """
    total_minutes = t.hours * 60 + t.minutes
    whole_seconds = total_minutes * 60 + t.seconds
    return whole_seconds + t.milliseconds / 1000
# Get the duration of an audio/video file
def get_duration(in_file):
    """Ask ffprobe for the container duration of a media file.

    Args:
        in_file: Path of the audio/video file to inspect.

    Returns:
        The duration printed by ffprobe, parsed as a float.

    Raises:
        subprocess.CalledProcessError: If ffprobe exits with a non-zero status.
        ValueError: If ffprobe's output cannot be parsed as a float.
    """
    probe_cmd = [
        "ffprobe",
        "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1",
        in_file,
    ]
    raw_output = subprocess.check_output(probe_cmd)
    return float(raw_output.decode("utf-8"))
# Ensure the audio file matches the specified length
def ensure_audio_length(in_file, out_file, length):
    """Write *in_file* to *out_file*, sped up when it is longer than *length*.

    Audio that already fits is copied unchanged. Longer audio is passed
    through ffmpeg's ``atempo`` filter with a tempo of ``duration / length``,
    capped at 100.

    Args:
        in_file: Path of the source audio file.
        out_file: Path the result is written to (overwritten if present).
        length: Target length in seconds.

    Raises:
        subprocess.CalledProcessError: If ffprobe or ffmpeg fails.
    """
    max_atempo = 100  # cap on the tempo handed to ffmpeg's atempo filter

    duration = get_duration(in_file)
    # A zero-length cue has no finite tempo that fits it; use the fastest
    # allowed tempo instead of crashing with ZeroDivisionError.
    atempo = duration / length if length != 0 else max_atempo
    atempo = min(atempo, max_atempo)

    if atempo <= 1:
        # Already short enough (or the target is negative): keep it as is.
        shutil.copyfile(in_file, out_file)
        return

    retcode = subprocess.call(
        [
            "ffmpeg", "-y", "-i", in_file,
            "-filter:a", f"atempo={atempo}",
            out_file,
        ],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    if retcode != 0:
        raise subprocess.CalledProcessError(retcode, "ffmpeg")
# Function to generate silence
def silence_gen(out_file, duration):
    """Create a mono, 24 kHz silent audio file using ffmpeg's anullsrc source.

    Args:
        out_file: Destination path (overwritten if present).
        duration: Length of the silence, passed to ffmpeg's ``-t`` option.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    ffmpeg_cmd = [
        "ffmpeg", "-y",
        "-f", "lavfi",
        "-i", "anullsrc=cl=mono:r=24000",
        "-t", str(duration),
        out_file,
    ]
    status = subprocess.call(
        ffmpeg_cmd,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    if status == 0:
        return
    raise subprocess.CalledProcessError(status, "ffmpeg")
# Handle enhanced SRT parameters (rate, volume, voice)
def get_enhanced_srt_params(text, arg):
    """Apply a trailing ``edge_tts{...}`` directive of a subtitle to *arg*.

    When the last line of *text* has the form ``edge_tts{key:value,...}``,
    the pairs override the matching entries of *arg* and that line is removed
    from the returned text. Only the keys ``rate``, ``volume`` and ``voice``
    are accepted. Whitespace around keys and values is ignored, and a value
    may itself contain ``:``.

    Args:
        text: Subtitle text, newline-separated.
        arg: Dict of synthesis settings; updated in place, and only when the
            whole directive is valid.

    Returns:
        A ``(arg, text)`` tuple: the same dict object and the text without
        the directive line (or the unchanged text if there is no directive).

    Raises:
        ValueError: If the directive is malformed or names an unknown key.
    """
    prefix = "edge_tts{"
    lines = text.split("\n")
    directive = lines[-1]
    if not (directive.startswith(prefix) and directive.endswith("}")):
        return arg, text

    overrides = {}
    for entry in directive[len(prefix):-1].split(","):
        # Split on the first ":" only so values containing ":" survive.
        key, sep, value = entry.partition(":")
        key = key.strip()
        if not sep or key not in ("rate", "volume", "voice"):
            raise ValueError("edge_tts{} is invalid")
        overrides[key] = value.strip()

    # Applied only after every entry validated, so a bad directive never
    # leaves *arg* half-updated.
    arg.update(overrides)
    return arg, "\n".join(lines[:-1])
# Asynchronous audio generation
async def audio_gen(queue):
    """Synthesize one subtitle taken from *queue* into an MP3 file.

    The job dict read from the queue supplies the output path (``fname``),
    the subtitle ``text``, the cue ``duration`` and the edge-tts settings
    (``rate``, ``volume``, ``voice``). When ``enhanced_srt`` is set, a
    trailing ``edge_tts{...}`` line in the text may override those settings.
    Speech longer than the cue is sped up via ``ensure_audio_length``; if
    edge-tts returns no audio, silence of the cue's duration is written
    instead.

    Args:
        queue: ``asyncio.Queue`` holding job dicts; exactly one is consumed.

    Raises:
        Exception: If synthesis still fails after ``retry_limit`` retries.
        subprocess.CalledProcessError: If an ffmpeg/ffprobe step fails.
    """
    retry_count = 0
    retry_limit = 5
    arg = await queue.get()
    fname = arg["fname"]
    text = arg["text"]
    duration = arg["duration"]
    enhanced_srt = arg["enhanced_srt"]

    if enhanced_srt:
        arg, text = get_enhanced_srt_params(text, arg)
    # The cue is spoken as a single sentence, so line breaks become spaces.
    text = " ".join(text.split("\n"))

    while True:
        try:
            communicate = edge_tts.Communicate(
                text, rate=arg["rate"], volume=arg["volume"], voice=arg["voice"]
            )
            await communicate.save(fname)
        except edge_tts.exceptions.NoAudioReceived:
            # Nothing speakable: leave an empty file so the silence branch
            # below fills the cue.
            with open(fname, "wb") as fobj:
                fobj.write(b"")
        except Exception as e:
            # ">=" so that at most retry_limit retries are made.
            if retry_count >= retry_limit:
                raise Exception(f"Too many retries for {fname}") from e
            retry_count += 1
            logger.warning(
                "Retrying %s (%d/%d) after error: %s",
                fname, retry_count, retry_limit, e,
            )
            # Linear back-off plus jitter before the next attempt.
            await asyncio.sleep(retry_count + random.randint(1, 5))
            continue
        break

    if os.path.getsize(fname) > 0:
        temporary_file = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False)
        # Only the name is needed; ffmpeg (or copyfile) writes the content.
        temporary_file.close()
        try:
            ensure_audio_length(fname, temporary_file.name, duration)
        except BaseException:
            # Never replace the clip with a partial result, and do not leak
            # the scratch file.
            if os.path.exists(temporary_file.name):
                os.remove(temporary_file.name)
            raise
        shutil.move(temporary_file.name, fname)
    else:
        silence_gen(fname, duration)

    queue.task_done()
# Main async processing logic
async def _main(srt_data, voice, rate, volume, batch_size, enhanced_srt):
max_duration = pysrttime_to_seconds(srt_data[-1].end)
input_files = []
input_files_start_end = {}
with tempfile.TemporaryDirectory() as temp_dir:
args = []
queue = asyncio.Queue()
for i, j in enumerate(srt_data):
fname = os.path.join(temp_dir, f"{i}.mp3")
input_files.append(fname)
start = pysrttime_to_seconds(j.start)
end = pysrttime_to_seconds(j.end)
input_files_start_end[fname] = (start, end)
duration = pysrttime_to_seconds(j.duration)
args.append(
{
"fname": fname,
"text": j.text,
"rate": rate,
"volume": volume,
"voice": voice,
"duration": duration,
"enhanced_srt": enhanced_srt,
}
)
args_len = len(args)
for i in range(0, args_len, batch_size):
tasks = []
for j in range(i, min(i + batch_size, args_len)):
tasks.append(audio_gen(queue))
await queue.put(args[j])
for f in asyncio.as_completed(tasks):
await f
output_file = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name
f = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
try:
last_end = 0
for i, j in enumerate(input_files):
start = input_files_start_end[j][0]
needed = start - last_end
if needed > 0.0001:
sfname = os.path.join(temp_dir, f"silence_{i}.mp3")
silence_gen(sfname, needed)
f.write(f"file '{sfname}'\n")
last_end += get_duration(sfname)
f.write(f"file '{j}'\n")
last_end += get_duration(j)
f.flush()
f.close()
retcode = subprocess.call(
[
"ffmpeg",
"-y", "-f", "concat", "-safe", "0", "-i", f.name, "-c", "copy", output_file
],
stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
)
if retcode != 0:
raise subprocess.CalledProcessError(retcode, "ffmpeg")
finally:
f.close()
os.remove(f.name)
return output_file
# Gradio Interface
def _read_srt_text(srt_file):
    """Return the text of an uploaded SRT given a path or a file-like object."""
    if srt_file is None:
        raise ValueError("No SRT file was provided")
    if isinstance(srt_file, (str, os.PathLike)):
        with open(srt_file, "rb") as fobj:
            raw = fobj.read()
    else:
        raw = srt_file.read()
    if isinstance(raw, bytes):
        # utf-8-sig drops a leading BOM so it does not reach the SRT parser.
        return raw.decode("utf-8-sig")
    return raw


def process_srt_to_mp3(srt_file, voice, speed, volume, batch_size, enhanced_srt):
    """Gradio callback: turn an uploaded SRT file into a spoken MP3.

    Args:
        srt_file: The upload, either a filesystem path or an object whose
            ``read()`` returns the file content as bytes or str.
        voice: edge-tts voice name.
        speed: edge-tts rate string.
        volume: edge-tts volume string.
        batch_size: Number of subtitles synthesized concurrently; truncated
            to an integer and raised to at least 1 before use.
        enhanced_srt: Whether ``edge_tts{...}`` directives in cues are honoured.

    Returns:
        Path of the generated MP3 file.

    Raises:
        ValueError: If no file was provided.
    """
    srt_data = pysrt.from_string(_read_srt_text(srt_file))
    output_file = asyncio.run(
        _main(
            srt_data=srt_data,
            voice=voice,
            rate=speed,
            volume=volume,
            # The slider may hand over a float; range() needs an int >= 1.
            batch_size=max(1, int(batch_size)),
            enhanced_srt=enhanced_srt,
        )
    )
    return output_file
# Gradio UI elements
def create_ui():
    """Build the Gradio interface that wraps ``process_srt_to_mp3``.

    Returns:
        The configured ``gr.Interface``; the caller is expected to launch it.
    """
    voice_options = ["en-US-AriaNeural", "en-US-JennyNeural"]
    interface = gr.Interface(
        fn=process_srt_to_mp3,
        inputs=[
            gr.File(label="Upload SRT File"),
            gr.Dropdown(voice_options, label="Voice", value="en-US-AriaNeural"),
            gr.Textbox(value="+0%", label="Speech Rate (default +0%)"),
            gr.Textbox(value="+0%", label="Volume (default +0%)"),
            # step=1 keeps the batch size a whole number; it is later used
            # as the step of a range().
            gr.Slider(1, 100, value=50, step=1, label="Batch Size"),
            gr.Checkbox(value=True, label="Enable Enhanced SRT"),
        ],
        outputs=gr.File(label="Generated MP3 File"),
        title="SRT to MP3 Converter",
        description="Converts SRT files to MP3 using Edge TTS and FFmpeg",
    )
    return interface
# Launch Gradio interface
if __name__ == "__main__":
    # Fail fast on missing ffmpeg tooling before the web UI comes up.
    dep_check()
    app = create_ui()
    app.launch()