JacobLinCool committed on
Commit 7c6792a
1 Parent(s): 2029ea8

feat: gradio app

Files changed (5)
  1. .gitignore +5 -0
  2. app.py +334 -0
  3. requirements.txt +10 -0
  4. utils.py +18 -0
  5. zero.py +21 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+ .DS_Store
+ __pycache__/
+ *.pyc
+
+ task/
app.py ADDED
@@ -0,0 +1,334 @@
+ from matplotlib import pyplot as plt
+ from accelerate import Accelerator
+ from zero import zero
+ import gradio as gr
+ from typing import Tuple
+ import os
+ from os import path
+ from utils import plot_spec
+ import librosa
+ from hashlib import md5
+ from demucs.separate import main as demucs
+ from pyannote.audio import Pipeline
+ from json import dumps, loads
+ import shutil
+
+ accelerator = Accelerator()
+ device = accelerator.device
+ print(f"Running on {device}")
+
+ pipeline = Pipeline.from_pretrained(
+     "pyannote/speaker-diarization-3.1", use_auth_token=os.environ["HF_TOKEN"]
+ )
+ pipeline.to(device)
+
+
+ tasks = []
+ os.makedirs("task", exist_ok=True)
+ for task in os.listdir("task"):
+     if path.isdir(path.join("task", task)):
+         tasks.append(task)
+
+
+ def gen_task_id(location: str):
+     # use the md5 hash of the video file as the task id
+     with open(location, "rb") as f:
+         return md5(f.read()).hexdigest()
+
+
+ def extract_audio(video: str) -> Tuple[str, str, str]:
+     task_id = gen_task_id(video)
+     os.makedirs(path.join("task", task_id), exist_ok=True)
+
+     videodest = path.join("task", task_id, "video.mp4")
+     if not path.exists(videodest):
+         shutil.copy(video, videodest)
+
+     wav48k = path.join("task", task_id, "extracted_48k.wav")
+     if not path.exists(wav48k):
+         os.system(
+             f"ffmpeg -i {videodest} -vn -ar 48000 {wav48k}"
+         )
+
+     spec = path.join("task", task_id, "extracted_48k.png")
+     if not path.exists(spec):
+         y, sr = librosa.load(wav48k, sr=16000)
+         fig = plot_spec(y, sr)
+         fig.savefig(spec)
+         plt.close(fig)
+
+     return (task_id, wav48k, spec)
+
+
+ @zero()
+ def extract_vocals(task_id: str) -> Tuple[str, str]:
+     audio = path.join("task", task_id, "extracted_48k.wav")
+     if not path.exists(audio):
+         raise gr.Error("Audio file not found")
+
+     vocals = path.join("task", task_id, "htdemucs", "extracted_48k", "vocals.wav")
+
+     if not path.exists(vocals):
+         demucs(
+             [
+                 "-d",
+                 str(device),
+                 "-n",
+                 "htdemucs",
+                 "--two-stems",
+                 "vocals",
+                 "-o",
+                 path.join("task", task_id),
+                 audio,
+             ]
+         )
+
+     spec = path.join("task", task_id, "vocals.png")
+     if not path.exists(spec):
+         y, sr = librosa.load(vocals, sr=16000)
+         fig = plot_spec(y, sr)
+         fig.savefig(spec)
+         plt.close(fig)
+
+     return (vocals, spec)
+
+
+ @zero(duration=60 * 2)
+ def diarize_audio(task_id: str):
+     vocals = path.join("task", task_id, "htdemucs", "extracted_48k", "vocals.wav")
+     if not path.exists(vocals):
+         raise gr.Error("Vocals file not found")
+
+     diarization_json = path.join("task", task_id, "diarization.json")
+     if not path.exists(diarization_json):
+         result = pipeline(vocals)
+         with open(diarization_json, "w") as f:
+             diarization = []
+             for turn, _, speaker in result.itertracks(yield_label=True):
+                 diarization.append(
+                     {
+                         "speaker": speaker,
+                         "start": turn.start,
+                         "end": turn.end,
+                         "duration": turn.duration,
+                     }
+                 )
+             f.write(dumps(diarization))
+     with open(diarization_json, "r") as f:
+         diarization = loads(f.read())
+
+     filtered_json = path.join("task", task_id, "filtered_diarization.json")
+     if not path.exists(filtered_json):
+         # Filter out segments shorter than 2 seconds and group by speaker
+         filtered_segments = {}
+         for turn in diarization:
+             speaker = turn["speaker"]
+             if turn["duration"] >= 2.0:
+                 if speaker not in filtered_segments:
+                     filtered_segments[speaker] = []
+                 filtered_segments[speaker].append(turn)
+
+         # Filter out speakers with less than 60 seconds of speech
+         filtered_segments = {
+             speaker: segments
+             for speaker, segments in filtered_segments.items()
+             if sum(segment["duration"] for segment in segments) >= 60
+         }
+
+         with open(filtered_json, "w") as f:
+             f.write(dumps(filtered_segments))
+     with open(filtered_json, "r") as f:
+         filtered_segments = loads(f.read())
+
+     return filtered_segments
+
+
+ def generate_clips(task_id: str, speaker: str) -> Tuple[str, str]:
+     video = path.join("task", task_id, "video.mp4")
+     if not path.exists(video):
+         raise gr.Error("Video file not found")
+
+     filtered_json = path.join("task", task_id, "filtered_diarization.json")
+     if not path.exists(filtered_json):
+         raise gr.Error("Diarization not found")
+
+     with open(filtered_json, "r") as f:
+         filtered_segments = loads(f.read())
+
+     if speaker not in filtered_segments:
+         raise gr.Error("Speaker not found")
+
+     mp4 = path.join("task", task_id, f"{speaker}.mp4")
+     if not path.exists(mp4):
+         cmd = f'ffmpeg -i {video} -filter_complex "'
+         for i, segment in enumerate(filtered_segments[speaker]):
+             start = segment["start"]
+             end = segment["end"]
+             cmd += f"[0:v]trim=start={start}:end={end},setpts=PTS-STARTPTS[v{i}];"
+             cmd += f"[0:a]atrim=start={start}:end={end},asetpts=PTS-STARTPTS[a{i}];"
+         for i in range(len(filtered_segments[speaker])):
+             cmd += f"[v{i}][a{i}]"
+         cmd += f'concat=n={len(filtered_segments[speaker])}:v=1:a=1[outv][outa]" -map [outv] -map [outa] -y {mp4}'
+         os.system(cmd)
+
+     segments = path.join("task", task_id, f"{speaker}")
+     if not path.exists(segments):
+         os.makedirs(segments)
+         for i, segment in enumerate(filtered_segments[speaker]):
+             start = segment["start"]
+             end = segment["end"]
+             name = path.join(segments, f"{i}_{start:.2f}_{end:.2f}.wav")
+             cmd = f"ffmpeg -i {video} -ss {start} -to {end} -f wav {name}"
+             os.system(cmd)
+
+     segments_zip = path.join("task", task_id, f"{speaker}.zip")
+     if not path.exists(segments_zip):
+         os.system(f"zip -r {segments_zip} {segments}")
+
+     return mp4, segments_zip
+
+
+ with gr.Blocks() as app:
+     gr.Markdown("# Video Speaker Diarization")
+
+     gr.Markdown(
+         """
+         First, upload a video file and we will extract and inspect its audio track.
+         """
+     )
+     original_video = gr.Video(label="Upload a video", show_download_button=True)
+     preprocess_btn = gr.Button(value="Preprocess", variant="primary")
+     preprocess_btn_label = gr.Markdown("Press the button!")
+
+     with gr.Column(visible=False) as preprocess_output:
+         gr.Markdown(
+             """
+             Now you can see the spectrogram of the extracted audio.
+
+             Next, let's remove the background music from the audio.
+             """
+         )
+         task_id = gr.Textbox(label="Task ID", visible=False)
+         extracted_audio = gr.Audio(label="Extracted Audio", type="filepath")
+         extracted_audio_spec = gr.Image(label="Extracted Audio Spectrogram")
+
+         extract_vocals_btn = gr.Button(
+             value="Remove Background Music", variant="primary"
+         )
+         extract_vocals_btn_label = gr.Markdown("Press the button!")
+
+     with gr.Column(visible=False) as extract_vocals_output:
+         vocals = gr.Audio(label="Vocals", type="filepath")
+         vocals_spec = gr.Image(label="Vocals Spectrogram")
+
+         diarize_btn = gr.Button(value="Diarize", variant="primary")
+         diarize_btn_label = gr.Markdown("Press the button!")
+
+     with gr.Column(visible=False) as diarize_output:
+         gr.Markdown(
+             """
+             Now you can select a speaker from the dropdown below to generate clips of that speaker.
+             """
+         )
+         speaker_select = gr.Dropdown(label="Speaker", choices=[])
+         diarization_result = gr.Markdown("")
+
+         generate_clips_btn = gr.Button(value="Generate Clips", variant="primary")
+         generate_clips_btn_label = gr.Markdown("Press the button!")
+
+     with gr.Column(visible=False) as generate_clips_output:
+         speaker_clip = gr.Video(label="Speaker Clip")
+         speaker_clip_zip = gr.File(label="Download Audio Segments")
+
+     def preprocess(video: str):
+         task_id_val, extracted_audio_val, extracted_audio_spec_val = extract_audio(
+             video
+         )
+         return {
+             preprocess_output: gr.Column(visible=True),
+             task_id: task_id_val,
+             extracted_audio: extracted_audio_val,
+             extracted_audio_spec: extracted_audio_spec_val,
+             preprocess_btn_label: gr.Markdown("", visible=False),
+         }
+
+     preprocess_btn.click(
+         fn=preprocess,
+         inputs=[original_video],
+         outputs=[
+             preprocess_output,
+             task_id,
+             extracted_audio,
+             extracted_audio_spec,
+             preprocess_btn_label,
+         ],
+         api_name="preprocess",
+     )
+
+     def extract_vocals_fn(task_id: str):
+         vocals_val, vocals_spec_val = extract_vocals(task_id)
+         return {
+             extract_vocals_output: gr.Column(visible=True),
+             vocals: vocals_val,
+             vocals_spec: vocals_spec_val,
+             extract_vocals_btn_label: gr.Markdown("", visible=False),
+         }
+
+     extract_vocals_btn.click(
+         fn=extract_vocals_fn,
+         inputs=[task_id],
+         outputs=[extract_vocals_output, vocals, vocals_spec, extract_vocals_btn_label],
+         api_name="extract_vocals",
+     )
+
+     def diarize_fn(task_id: str):
+         filtered_segments = diarize_audio(task_id)
+         choices = []
+         for speaker in filtered_segments:
+             total = sum(segment["duration"] for segment in filtered_segments[speaker])
+             choices.append((f"{speaker} ({total:.2f}s)", speaker))
+
+         info = ""
+         for speaker, segments in filtered_segments.items():
+             total = sum(segment["duration"] for segment in segments)
+             info += f"### Speaker {speaker}: ({total:.2f}s)\n"
+             for segment in segments:
+                 start = segment["start"]
+                 end = segment["end"]
+                 info += f"- {start:.2f} - {end:.2f} ({segment['duration']:.2f}s)\n"
+         return {
+             diarize_output: gr.Column(visible=True),
+             speaker_select: gr.Dropdown(label="Speaker", choices=choices),
+             diarization_result: gr.Markdown(info),
+             diarize_btn_label: gr.Markdown("", visible=False),
+         }
+
+     diarize_btn.click(
+         fn=diarize_fn,
+         inputs=[task_id],
+         outputs=[diarize_output, speaker_select, diarization_result, diarize_btn_label],
+         api_name="diarize",
+     )
+
+     def generate_clips_fn(task_id: str, speaker: str):
+         speaker_clip_val, zip_val = generate_clips(task_id, speaker)
+         return {
+             generate_clips_output: gr.Column(visible=True),
+             speaker_clip: speaker_clip_val,
+             speaker_clip_zip: zip_val,
+             generate_clips_btn_label: gr.Markdown("", visible=False),
+         }
+
+     generate_clips_btn.click(
+         fn=generate_clips_fn,
+         inputs=[task_id, speaker_select],
+         outputs=[
+             generate_clips_output,
+             speaker_clip,
+             speaker_clip_zip,
+             generate_clips_btn_label,
+         ],
+         api_name="generate_clips",
+     )
+
+ app.launch()
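
Since every button handler above is registered with an api_name, the pipeline can also be driven end to end from a script. A minimal sketch using gradio_client; the local URL, the input file name, the speaker label, and the position of the task ID in the returned tuple are assumptions, not part of this commit:

from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860")  # assumed URL of a locally launched app

# /preprocess returns its outputs in registration order; the hidden task ID
# textbox is assumed to be the second element of the returned tuple
result = client.predict(handle_file("input.mp4"), api_name="/preprocess")
task_id = result[1]

client.predict(task_id, api_name="/extract_vocals")
client.predict(task_id, api_name="/diarize")

# "SPEAKER_00" is a hypothetical label taken from the diarization output
outputs = client.predict(task_id, "SPEAKER_00", api_name="/generate_clips")

The exact payload shape for the Video input may vary with the gradio_client version; this sketch assumes the 1.x client that pairs with gradio 4.37.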
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch==2.2.0
+ soundfile==0.12.1
+ numpy==1.26.0
+ librosa==0.9.2
+ einops==0.8.0
+ gradio==4.37.2
+ accelerate==0.31.0
+ matplotlib==3.8.3
+ demucs==4.0.1
+ pyannote-audio==3.3.1
utils.py ADDED
@@ -0,0 +1,18 @@
+ import numpy as np
+ import librosa
+ import matplotlib.pyplot as plt
+ from librosa.display import specshow
+
+
+ def plot_spec(y: np.ndarray, sr: int, title: str = "Spectrogram") -> plt.Figure:
+     y[np.isnan(y)] = 0
+     y[np.isinf(y)] = 0
+     stft = librosa.stft(y=y)
+     D = librosa.amplitude_to_db(np.abs(stft), ref=np.max)
+
+     fig = plt.figure(figsize=(10, 4))
+     specshow(D, sr=sr, y_axis="linear", x_axis="time", cmap="viridis")
+     plt.title(title)
+     plt.tight_layout()
+
+     return fig
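
plot_spec is also usable on its own. A quick sketch, with assumed file names; the 16 kHz load rate mirrors how app.py renders its spectrograms:

import librosa
from utils import plot_spec

y, sr = librosa.load("speech.wav", sr=16000)  # "speech.wav" is an assumed input
fig = plot_spec(y, sr, title="speech.wav")
fig.savefig("speech.png")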
zero.py ADDED
@@ -0,0 +1,21 @@
+ import os
+
+ zero_is_available = "SPACES_ZERO_GPU" in os.environ
+
+ if zero_is_available:
+     import spaces  # type: ignore
+
+     print("ZeroGPU is available")
+ else:
+     print("ZeroGPU is not available")
+
+
+ # a decorator that applies the spaces.GPU decorator if zero is available
+ def zero(duration=60):
+     def wrapper(func):
+         if zero_is_available:
+             return spaces.GPU(func, duration=duration)
+         else:
+             return func
+
+     return wrapper
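
Usage mirrors spaces.GPU: decorate any GPU-bound function and it runs unchanged on hardware without ZeroGPU, as app.py does with extract_vocals and diarize_audio. A minimal sketch; separate_stems is a hypothetical function, not part of this commit:

from zero import zero

@zero(duration=120)  # request up to 120 s of ZeroGPU time when available
def separate_stems(audio_path: str):
    ...  # GPU work goes here; on other hardware the function is returned undecorated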