|
|
|
|
|
""" |
|
Gradio demo for voice activity detection (VAD).

Reference: https://huggingface.co/spaces/sayakpaul/demo-docker-gradio
|
""" |
|
import argparse |
|
import json |
|
import platform |
|
from typing import Tuple |
|
|
|
import gradio as gr |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from PIL import Image |
|
|
|
from project_settings import project_path, temp_directory |
|
from toolbox.webrtcvad.vad import WebRTCVad |
|
from toolbox.vad.vad import Vad, WebRTCVoiceClassifier, SileroVoiceClassifier, CallVoiceClassifier, process_speech_probs |
|
|
|
|
|
def get_args():
    """Parse the command-line options of the demo.

    Returns:
        The parsed ``argparse.Namespace``. Its only option is
        ``ring_vad_examples_file`` (str), which defaults to
        ``ring_vad_examples.json`` under ``project_path``.
    """
    default_examples_file = (project_path / "ring_vad_examples.json").as_posix()

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--ring_vad_examples_file",
        type=str,
        default=default_examples_file,
    )
    return arg_parser.parse_args()
|
|
|
|
|
# Most recently built detector; rebound by `click_ring_vad_button` on every
# run and `None` until the first run. The annotation is a string so that it
# documents the optional type without being evaluated at import time.
vad: "Vad | None" = None
|
|
|
|
|
def click_ring_vad_button(audio: Tuple[int, np.ndarray],
                          model_name: str,
                          agg: int = 3,
                          frame_length_ms: int = 30,
                          frame_step_ms: int = 30,
                          padding_length_ms: int = 300,
                          max_silence_length_ms: int = 300,
                          start_ring_rate: float = 0.9,
                          end_ring_rate: float = 0.1,
                          max_speech_length_s: float = 2.0,
                          min_speech_length_s: float = 0.3,
                          ):
    """Run voice activity detection on the given audio and plot the result.

    Args:
        audio: ``(sample_rate, signal)`` pair, or ``None`` when no audio was
            provided.
        model_name: ``"webrtcvad"``, ``"silerovad"`` or ``"call_voice"``.
        agg: passed to ``WebRTCVoiceClassifier``; only used for ``"webrtcvad"``.
        frame_length_ms: forwarded to ``Vad``; must be 10, 20 or 30 when
            ``model_name`` is ``"webrtcvad"``.
        frame_step_ms: forwarded to ``Vad``.
        padding_length_ms: forwarded to ``Vad``.
        max_silence_length_ms: forwarded to ``Vad``.
        start_ring_rate: forwarded to ``Vad``; also drawn as a horizontal
            line on the plot.
        end_ring_rate: forwarded to ``Vad``.
        max_speech_length_s: forwarded to ``Vad``.
        min_speech_length_s: forwarded to ``Vad``.

    Returns:
        ``(image, vad_segments)`` on success, where ``image`` is the rendered
        plot as a PIL image and ``vad_segments`` is the list of
        ``(start, end)`` pairs produced by the detector. On invalid input or
        a detector failure, ``(None, message)`` with a human-readable message.

    Side effects:
        Rebinds the module-level ``vad`` to the detector built for this call
        and overwrites ``temp.jpg`` inside ``temp_directory``.
    """
    global vad

    if audio is None:
        return None, "please upload audio."
    sample_rate, signal = audio

    # The plot divides by len(signal); reject empty audio up front instead of
    # failing later with ZeroDivisionError.
    if len(signal) == 0:
        return None, "audio is empty."

    if model_name == "webrtcvad":
        if frame_length_ms not in (10, 20, 30):
            return None, "only 10, 20, 30 available for `frame_length_ms`."
        model = WebRTCVoiceClassifier(agg=agg)
    elif model_name == "silerovad":
        model = SileroVoiceClassifier(
            model_path=(project_path / "pretrained_models/silero_vad/silero_vad.jit").as_posix()
        )
    elif model_name == "call_voice":
        model = CallVoiceClassifier(
            model_path=(project_path / "trained_models/cnn_voicemail_common_20231130").as_posix()
        )
    else:
        return None, "`model_name` not valid."

    vad = Vad(model=model,
              start_ring_rate=start_ring_rate,
              end_ring_rate=end_ring_rate,
              frame_length_ms=frame_length_ms,
              frame_step_ms=frame_step_ms,
              padding_length_ms=padding_length_ms,
              max_silence_length_ms=max_silence_length_ms,
              max_speech_length_s=max_speech_length_s,
              min_speech_length_s=min_speech_length_s,
              sample_rate=sample_rate,
              )

    # UI boundary: any detector failure is reported as text in the output box
    # instead of propagating.
    try:
        vad_segments = list(vad.vad(signal))
        # Also collect whatever the detector reports via last_vad_segments().
        vad_segments += vad.last_vad_segments()
    except Exception as e:
        return None, str(e)

    speech_probs = process_speech_probs(
        signal=signal,
        speech_probs=vad.speech_probs,
        frame_step=vad.frame_step,
    )

    time_axis = np.arange(0, len(signal)) / sample_rate
    temp_image_file = temp_directory / "temp.jpg"

    # Explicit figure/axes rather than pyplot's implicit "current figure", and
    # always close the figure: pyplot keeps every figure alive until closed,
    # so one un-closed figure per click would accumulate in the server.
    fig, ax = plt.subplots(figsize=(12, 5))
    try:
        # NOTE(review): dividing by 32768 assumes 16-bit PCM samples - confirm.
        ax.plot(time_axis, signal / 32768, color="b")
        ax.plot(time_axis, speech_probs, color="gray")

        ax.axhline(y=start_ring_rate, xmin=0.0, xmax=1.0, color="gray", linestyle="-")
        # Red overlay whose width, as a fraction of the axes, equals one
        # frame's duration divided by the duration of the whole signal.
        ax.axhline(y=start_ring_rate, xmin=0.0,
                   xmax=frame_length_ms / 1000 / len(signal) * sample_rate,
                   color="red", linestyle="-")

        for start, end in vad_segments:
            ax.axvline(x=start, ymin=0.15, ymax=0.85, color="g", linestyle="--")
            ax.axvline(x=end, ymin=0.15, ymax=0.85, color="r", linestyle="--")

        fig.savefig(temp_image_file)
    finally:
        plt.close(fig)

    # Image.open is lazy; force the pixel data to load while the file is open
    # so the handle can be closed before the image is returned.
    with open(temp_image_file, "rb") as f:
        image = Image.open(f)
        image.load()

    return image, vad_segments
|
|
|
|
|
def main():
    """Build the Gradio VAD demo page and launch it on port 7860.

    Reads the example rows from the JSON file named by
    ``--ring_vad_examples_file``, lays out the input controls and the
    image/text outputs, wires both the examples and the "run" button to
    ``click_ring_vad_button``, then launches the app. The server binds to
    ``127.0.0.1`` on Windows and ``0.0.0.0`` elsewhere; ``share`` is off.
    """
    args = get_args()

    brief_description = """
## Voice Activity Detection

"""

    with open(args.ring_vad_examples_file, "r", encoding="utf-8") as f:
        ring_vad_examples = json.load(f)

    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)

        with gr.Row():
            with gr.Column(scale=5):
                with gr.Tabs():
                    with gr.TabItem("ring_vad"):
                        gr.Markdown(value="")

                        with gr.Row():
                            with gr.Column(scale=1):
                                ring_wav = gr.Audio(label="wav")

                                with gr.Row():
                                    ring_model_name = gr.Dropdown(choices=["webrtcvad", "silerovad", "call_voice"], value="webrtcvad", label="model_name")
                                    ring_agg = gr.Dropdown(choices=[1, 2, 3], value=3, label="agg")

                                with gr.Row():
                                    ring_frame_length_ms = gr.Slider(minimum=0, maximum=1000, value=30, label="frame_length_ms")
                                    ring_frame_step_ms = gr.Slider(minimum=0, maximum=100, value=30, label="frame_step_ms")

                                with gr.Row():
                                    ring_padding_length_ms = gr.Slider(minimum=0, maximum=1000, value=300, label="padding_length_ms")
                                    ring_max_silence_length_ms = gr.Slider(minimum=0, maximum=1000, value=300, step=0.1, label="max_silence_length_ms")

                                with gr.Row():
                                    ring_start_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.05, label="start_ring_rate")
                                    ring_end_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="end_ring_rate")

                                with gr.Row():
                                    ring_max_speech_length_s = gr.Slider(minimum=0.0, maximum=10.0, value=2.0, step=0.05, label="max_speech_length_s")
                                    ring_min_speech_length_s = gr.Slider(minimum=0.0, maximum=2.0, value=0.3, step=0.05, label="min_speech_length_s")

                                ring_button = gr.Button("run", variant="primary")

                            with gr.Column(scale=1):
                                ring_image = gr.Image(label="image", height=300, width=720, show_label=False)
                                ring_end_points = gr.TextArea(label="end_points", max_lines=35)

                        # Single source of truth for the callback wiring. The
                        # order matches the positional parameters of
                        # click_ring_vad_button.
                        ring_inputs = [
                            ring_wav,
                            ring_model_name, ring_agg,
                            ring_frame_length_ms, ring_frame_step_ms,
                            ring_padding_length_ms, ring_max_silence_length_ms,
                            ring_start_ring_rate, ring_end_ring_rate,
                            ring_max_speech_length_s, ring_min_speech_length_s,
                        ]
                        ring_outputs = [ring_image, ring_end_points]

                        gr.Examples(
                            examples=ring_vad_examples,
                            inputs=ring_inputs,
                            outputs=ring_outputs,
                            fn=click_ring_vad_button,
                        )

        ring_button.click(
            click_ring_vad_button,
            inputs=ring_inputs,
            outputs=ring_outputs,
        )

    blocks.queue().launch(
        # The original conditional evaluated to False on every platform.
        share=False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=7860,
    )
|
|
|
|
|
# Start the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|