# Spaces: Running on Zero
from matplotlib import pyplot as plt
from accelerate import Accelerator
from zero import zero
import gradio as gr
from typing import Tuple
import os
from os import path
from utils import plot_spec
import librosa
from hashlib import md5
from demucs.separate import main as demucs
from pyannote.audio import Pipeline
from json import dumps, loads
import shutil
import zipfile
accelerator = Accelerator()
device = accelerator.device
print(f"Running on {device}")
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token=os.environ["HF_TOKEN"]
)
pipeline.to(device)
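# Note: pyannote/speaker-diarization-3.1 is a gated model on the Hugging Face
# Hub, so HF_TOKEN must belong to an account that has accepted the model's
# usage conditions; otherwise Pipeline.from_pretrained will fail.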
# Restore the list of existing task directories from disk.
tasks = []
os.makedirs("task", exist_ok=True)
for task in os.listdir("task"):
    if path.isdir(path.join("task", task)):
        tasks.append(task)
def gen_task_id(location: str) -> str:
    # Use the MD5 hash of the uploaded file as the task id, so re-uploading
    # the same video reuses any cached intermediate results on disk.
    with open(location, "rb") as f:
        return md5(f.read()).hexdigest()
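# For very large uploads, hashing in fixed-size chunks keeps memory usage
# flat. A minimal sketch of an alternative (hypothetical helper, not used by
# the app above):
def gen_task_id_chunked(location: str, chunk_size: int = 1 << 20) -> str:
    h = md5()
    with open(location, "rb") as f:
        # iter() keeps calling f.read(chunk_size) until it returns b"" at EOF.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()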
def extract_audio(video: str) -> Tuple[str, str]:
    task_id = gen_task_id(video)
    os.makedirs(path.join("task", task_id), exist_ok=True)
    videodest = path.join("task", task_id, "video.mp4")
    if not path.exists(videodest):
        shutil.copy(video, videodest)
    # Strip the video stream (-vn) and resample the audio to 48 kHz.
    wav48k = path.join("task", task_id, "extracted_48k.wav")
    if not path.exists(wav48k):
        os.system(f"ffmpeg -i {videodest} -vn -ar 48000 {wav48k}")
    return (task_id, wav48k)
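# os.system interpolates paths straight into a shell command. The task-id
# paths here are MD5 hex digests, so they are shell-safe, but passing an
# argument list to subprocess.run avoids quoting issues entirely. A sketch of
# the equivalent call (assuming ffmpeg is on PATH):
#
#   import subprocess
#   subprocess.run(
#       ["ffmpeg", "-i", videodest, "-vn", "-ar", "48000", wav48k], check=True
#   )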
def extract_audio_post(task_id: str) -> str:
    wav48k = path.join("task", task_id, "extracted_48k.wav")
    if not path.exists(wav48k):
        raise gr.Error("Audio file not found")
    spec = path.join("task", task_id, "extracted_48k.png")
    if not path.exists(spec):
        # Downsample to 16 kHz for a cheaper spectrogram plot.
        y, sr = librosa.load(wav48k, sr=16000)
        fig = plot_spec(y, sr)
        fig.savefig(spec)
        plt.close(fig)
    return spec
def extract_vocals(task_id: str) -> str:
    audio = path.join("task", task_id, "extracted_48k.wav")
    if not path.exists(audio):
        raise gr.Error("Audio file not found")
    # Demucs writes stems to <out>/<model>/<track name>/, hence this path.
    vocals = path.join("task", task_id, "htdemucs", "extracted_48k", "vocals.wav")
    if not path.exists(vocals):
        # Two-stem mode separates the track into vocals and accompaniment only.
        demucs(
            [
                "-d",
                str(device),
                "-n",
                "htdemucs",
                "--two-stems",
                "vocals",
                "-o",
                path.join("task", task_id),
                audio,
            ]
        )
    return vocals
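# Calling demucs.separate.main with an argv-style list mirrors the demucs CLI;
# the invocation above is equivalent to running, e.g.:
#
#   demucs -d cuda -n htdemucs --two-stems vocals \
#       -o task/<task_id> task/<task_id>/extracted_48k.wav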
def extract_vocals_post(task_id: str) -> str:
    vocals = path.join("task", task_id, "htdemucs", "extracted_48k", "vocals.wav")
    if not path.exists(vocals):
        raise gr.Error("Vocals file not found")
    spec = path.join("task", task_id, "vocals.png")
    if not path.exists(spec):
        y, sr = librosa.load(vocals, sr=16000)
        fig = plot_spec(y, sr)
        fig.savefig(spec)
        plt.close(fig)
    return spec
def diarize_audio(task_id: str):
    vocals = path.join("task", task_id, "htdemucs", "extracted_48k", "vocals.wav")
    if not path.exists(vocals):
        raise gr.Error("Vocals file not found")
    diarization_json = path.join("task", task_id, "diarization.json")
    if not path.exists(diarization_json):
        result = pipeline(vocals)
        with open(diarization_json, "w") as f:
            diarization = []
            for turn, _, speaker in result.itertracks(yield_label=True):
                diarization.append(
                    {
                        "speaker": speaker,
                        "start": turn.start,
                        "end": turn.end,
                        "duration": turn.duration,
                    }
                )
            f.write(dumps(diarization))
    with open(diarization_json, "r") as f:
        diarization = loads(f.read())
    filtered_json = path.join("task", task_id, "filtered_diarization.json")
    if not path.exists(filtered_json):
        # Drop segments shorter than 2 seconds and group the rest by speaker.
        filtered_segments = {}
        for turn in diarization:
            speaker = turn["speaker"]
            if turn["duration"] >= 2.0:
                if speaker not in filtered_segments:
                    filtered_segments[speaker] = []
                filtered_segments[speaker].append(turn)
        # Drop speakers with less than 60 seconds of total speech.
        filtered_segments = {
            speaker: segments
            for speaker, segments in filtered_segments.items()
            if sum(segment["duration"] for segment in segments) >= 60
        }
        with open(filtered_json, "w") as f:
            f.write(dumps(filtered_segments))
    with open(filtered_json, "r") as f:
        filtered_segments = loads(f.read())
    return filtered_segments
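# Shape of the mapping returned above, for reference (speaker labels are
# assigned by pyannote, times are in seconds):
#
#   {
#       "SPEAKER_00": [
#           {"speaker": "SPEAKER_00", "start": 12.3, "end": 15.9, "duration": 3.6},
#           ...
#       ]
#   }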
def generate_clips(task_id: str, speaker: str) -> Tuple[str, str, str]:
    video = path.join("task", task_id, "video.mp4")
    if not path.exists(video):
        raise gr.Error("Video file not found")
    filtered_json = path.join("task", task_id, "filtered_diarization.json")
    if not path.exists(filtered_json):
        raise gr.Error("Diarization not found")
    with open(filtered_json, "r") as f:
        filtered_segments = loads(f.read())
    if speaker not in filtered_segments:
        raise gr.Error("Speaker not found")
    # Build a single ffmpeg filter_complex that trims every segment of this
    # speaker from the source video and concatenates them into one clip.
    mp4 = path.join("task", task_id, f"{speaker}.mp4")
    if not path.exists(mp4):
        cmd = f'ffmpeg -i {video} -filter_complex "'
        for i, segment in enumerate(filtered_segments[speaker]):
            start = segment["start"]
            end = segment["end"]
            # setpts/asetpts reset each trimmed segment's timestamps to zero
            # so the concat filter can splice them back to back.
            cmd += f"[0:v]trim=start={start}:end={end},setpts=PTS-STARTPTS[v{i}];"
            cmd += f"[0:a]atrim=start={start}:end={end},asetpts=PTS-STARTPTS[a{i}];"
        for i in range(len(filtered_segments[speaker])):
            cmd += f"[v{i}][a{i}]"
        cmd += f'concat=n={len(filtered_segments[speaker])}:v=1:a=1[outv][outa]" -map [outv] -map [outa] -y {mp4}'
        os.system(cmd)
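        # For two segments, the assembled command looks roughly like this
        # (wrapped here for readability; the real command is a single line):
        #
        #   ffmpeg -i video.mp4 -filter_complex "
        #     [0:v]trim=start=1.0:end=3.0,setpts=PTS-STARTPTS[v0];
        #     [0:a]atrim=start=1.0:end=3.0,asetpts=PTS-STARTPTS[a0];
        #     [0:v]trim=start=5.0:end=8.0,setpts=PTS-STARTPTS[v1];
        #     [0:a]atrim=start=5.0:end=8.0,asetpts=PTS-STARTPTS[a1];
        #     [v0][a0][v1][a1]concat=n=2:v=1:a=1[outv][outa]"
        #     -map [outv] -map [outa] -y SPEAKER_00.mp4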
    # Also export each segment as an individual WAV, first from the original
    # audio and then from the separated vocals, and zip them for download.
    segments = path.join("task", task_id, f"{speaker}")
    if not path.exists(segments):
        os.makedirs(segments)
        for i, segment in enumerate(filtered_segments[speaker]):
            start = segment["start"]
            end = segment["end"]
            name = path.join(segments, f"{i}_{start:.2f}_{end:.2f}.wav")
            cmd = f"ffmpeg -i {video} -ss {start} -to {end} -f wav {name}"
            os.system(cmd)
    segments_zip = path.join("task", task_id, f"{speaker}.zip")
    if not path.exists(segments_zip):
        with zipfile.ZipFile(segments_zip, "w") as zipf:
            files = [f for f in os.listdir(segments) if f.endswith(".wav")]
            for file in files:
                zipf.write(path.join(segments, file), file)
    vocals = path.join("task", task_id, "htdemucs", "extracted_48k", "vocals.wav")
    vocal_segments = path.join("task", task_id, f"{speaker}_vocals")
    if not path.exists(vocal_segments):
        os.makedirs(vocal_segments)
        for i, segment in enumerate(filtered_segments[speaker]):
            start = segment["start"]
            end = segment["end"]
            name = path.join(vocal_segments, f"{i}_{start:.2f}_{end:.2f}.wav")
            cmd = f"ffmpeg -i {vocals} -ss {start} -to {end} -f wav {name}"
            os.system(cmd)
    vocal_segments_zip = path.join("task", task_id, f"{speaker}_vocals.zip")
    if not path.exists(vocal_segments_zip):
        with zipfile.ZipFile(vocal_segments_zip, "w") as zipf:
            files = [f for f in os.listdir(vocal_segments) if f.endswith(".wav")]
            for file in files:
                zipf.write(path.join(vocal_segments, file), file)
    return mp4, segments_zip, vocal_segments_zip
with gr.Blocks() as app:
    gr.Markdown("# Video Speaker Diarization")
    gr.Markdown(
        """
        First, upload a video file, and we will inspect its audio track.
        """
    )
    original_video = gr.Video(label="Upload a video", show_download_button=True)
    preprocess_btn = gr.Button(value="Preprocess", variant="primary")
    preprocess_btn_label = gr.Markdown("Press the button!")
    task_id = gr.Textbox(label="Task ID", visible=False)
    with gr.Column(visible=False) as preprocess_output:
        gr.Markdown(
            """
            Now you can see the spectrogram of the extracted audio.
            Next, let's remove the background music from the audio.
            """
        )
        with gr.Row():
            extracted_audio = gr.Audio(label="Extracted Audio", type="filepath")
            extracted_audio_spec = gr.Image(label="Extracted Audio Spectrogram")
        extract_vocals_btn = gr.Button(
            value="Remove Background Music", variant="primary"
        )
        extract_vocals_btn_label = gr.Markdown("Press the button!")
    with gr.Column(visible=False) as extract_vocals_output:
        with gr.Row():
            vocals = gr.Audio(label="Vocals", type="filepath")
            vocals_spec = gr.Image(label="Vocals Spectrogram")
        diarize_btn = gr.Button(value="Diarize", variant="primary")
        diarize_btn_label = gr.Markdown("Press the button!")
    with gr.Column(visible=False) as diarize_output:
        gr.Markdown(
            """
            Now select a speaker from the dropdown below to generate that speaker's clips.
            """
        )
        with gr.Row():
            speaker_select = gr.Dropdown(label="Speaker", choices=[])
            diarization_result = gr.Markdown("", height=400)
        generate_clips_btn = gr.Button(value="Generate Clips", variant="primary")
        generate_clips_btn_label = gr.Markdown("Press the button!")
    with gr.Column(visible=False) as generate_clips_output:
        speaker_clip = gr.Video(label="Speaker Clip")
        speaker_clip_zip = gr.File(label="Download Audio Segments")
        speaker_clip_vocal_zip = gr.File(label="Download Vocal Segments")
    def preprocess(video: str):
        task_id_val, extracted_audio_val = extract_audio(video)
        return {
            preprocess_output: gr.Column(visible=True),
            task_id: task_id_val,
            extracted_audio: extracted_audio_val,
            preprocess_btn_label: gr.Markdown("", visible=False),
        }

    # .success() chains the spectrogram step so it only runs after the
    # preprocess handler completes without raising an error.
    preprocess_btn.click(
        fn=preprocess,
        inputs=[original_video],
        outputs=[
            preprocess_output,
            task_id,
            extracted_audio,
            preprocess_btn_label,
        ],
        api_name="preprocess",
    ).success(
        fn=extract_audio_post,
        inputs=[task_id],
        outputs=[extracted_audio_spec],
        api_name="preprocess-post",
    )
    def extract_vocals_fn(task_id: str):
        vocals_val = extract_vocals(task_id)
        return {
            extract_vocals_output: gr.Column(visible=True),
            vocals: vocals_val,
            extract_vocals_btn_label: gr.Markdown("", visible=False),
        }

    extract_vocals_btn.click(
        fn=extract_vocals_fn,
        inputs=[task_id],
        outputs=[extract_vocals_output, vocals, extract_vocals_btn_label],
        api_name="extract-vocals",
    ).success(
        fn=extract_vocals_post,
        inputs=[task_id],
        outputs=[vocals_spec],
        api_name="extract-vocals-post",
    )
    def diarize_fn(task_id: str):
        filtered_segments = diarize_audio(task_id)
        # Dropdown choices are (label, value) pairs: show each speaker's total
        # speech time in the label, keep the raw speaker id as the value.
        choices = []
        for speaker in filtered_segments:
            total = sum(segment["duration"] for segment in filtered_segments[speaker])
            choices.append((f"{speaker} ({total:.2f}s)", speaker))
        info = ""
        for speaker, segments in filtered_segments.items():
            total = sum(segment["duration"] for segment in segments)
            info += f"### Speaker {speaker} ({total:.2f}s)\n"
            for segment in segments:
                start = segment["start"]
                end = segment["end"]
                info += f"- {start:.2f} - {end:.2f} ({segment['duration']:.2f}s)\n"
        return {
            diarize_output: gr.Column(visible=True),
            speaker_select: gr.Dropdown(label="Speaker", choices=choices),
            diarization_result: gr.Markdown(info),
            diarize_btn_label: gr.Markdown("", visible=False),
        }

    diarize_btn.click(
        fn=diarize_fn,
        inputs=[task_id],
        outputs=[diarize_output, speaker_select, diarization_result, diarize_btn_label],
        api_name="diarize",
    )
    def generate_clips_fn(task_id: str, speaker: str):
        speaker_clip_val, zip_val, vocal_zip_val = generate_clips(task_id, speaker)
        return {
            generate_clips_output: gr.Column(visible=True),
            speaker_clip: speaker_clip_val,
            speaker_clip_zip: zip_val,
            speaker_clip_vocal_zip: vocal_zip_val,
            generate_clips_btn_label: gr.Markdown("", visible=False),
        }

    generate_clips_btn.click(
        fn=generate_clips_fn,
        inputs=[task_id, speaker_select],
        outputs=[
            generate_clips_output,
            speaker_clip,
            speaker_clip_zip,
            speaker_clip_vocal_zip,
            generate_clips_btn_label,
        ],
        api_name="generate_clips",
    )
app.launch()