#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://huggingface.co/spaces/sayakpaul/demo-docker-gradio
"""
import argparse
import json
import platform
from typing import Optional, Tuple

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

from project_settings import project_path, temp_directory
from toolbox.vad.vad import Vad, WebRTCVoiceClassifier, SileroVoiceClassifier, CallVoiceClassifier, process_speech_probs


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ring_vad_examples_file",
        default=(project_path / "ring_vad_examples.json").as_posix(),
        type=str
    )
    args = parser.parse_args()
    return args
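

# Most recent Vad instance; rebuilt on every click so the plotting code can read
# back vad.speech_probs and vad.frame_step after segmentation.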
vad: Optional[Vad] = None


def click_ring_vad_button(audio: Tuple[int, np.ndarray],
                          model_name: str,
                          agg: int = 3,
                          frame_length_ms: int = 30,
                          frame_step_ms: int = 30,
                          padding_length_ms: int = 300,
                          max_silence_length_ms: int = 300,
                          start_ring_rate: float = 0.9,
                          end_ring_rate: float = 0.1,
                          max_speech_length_s: float = 2.0,
                          min_speech_length_s: float = 0.3,
                          ):
    global vad

    if audio is None:
        return None, "Please upload audio."
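
    # With the default type="numpy", gr.Audio yields (sample_rate, samples),
    # where samples is an int16 ndarray.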
    sample_rate, signal = audio

    if model_name == "webrtcvad" and frame_length_ms not in (10, 20, 30):
        return None, "`frame_length_ms` must be 10, 20, or 30 for webrtcvad."
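
    # Pick the frame-level voice classifier backend:
    #   - webrtcvad: WebRTC's lightweight detector, aggressiveness `agg` (1-3 in this UI)
    #   - silerovad: the Silero VAD TorchScript (.jit) model
    #   - call_voice: a project-specific CNN voicemail/voice classifier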
    if model_name == "webrtcvad":
        model = WebRTCVoiceClassifier(agg=agg)
    elif model_name == "silerovad":
        model = SileroVoiceClassifier(model_path=(project_path / "pretrained_models/silero_vad/silero_vad.jit").as_posix())
    elif model_name == "call_voice":
        model = CallVoiceClassifier(model_path=(project_path / "trained_models/cnn_voicemail_common_20231130").as_posix())
    else:
        return None, "`model_name` not valid."
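
    # The Vad wrapper appears to apply ring-buffer hysteresis to per-frame speech
    # probabilities: a segment opens once the voiced fraction within the padding
    # window exceeds start_ring_rate, and closes once it drops below end_ring_rate.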
    vad = Vad(model=model,
              start_ring_rate=start_ring_rate,
              end_ring_rate=end_ring_rate,
              frame_length_ms=frame_length_ms,
              frame_step_ms=frame_step_ms,
              padding_length_ms=padding_length_ms,
              max_silence_length_ms=max_silence_length_ms,
              max_speech_length_s=max_speech_length_s,
              min_speech_length_s=min_speech_length_s,
              sample_rate=sample_rate,
              )
    try:
        vad_segments = list()
        segments = vad.vad(signal)
        vad_segments += segments
        # Flush any segment still open at the end of the stream.
        segments = vad.last_vad_segments()
        vad_segments += segments
    except Exception as e:
        return None, str(e)
    # speech_probs
    speech_probs = process_speech_probs(
        signal=signal,
        speech_probs=vad.speech_probs,
        frame_step=vad.frame_step,
    )
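
    # process_speech_probs presumably expands the frame-level probabilities to one
    # value per sample so they share a time axis with the waveform plotted below.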
    time = np.arange(0, len(signal)) / sample_rate
    plt.figure(figsize=(12, 5))
    # Normalize the int16 waveform to [-1, 1] so it shares a scale with the probabilities.
    plt.plot(time, signal / 32768, color="b")
    plt.plot(time, speech_probs, color="gray")
    # Horizontal threshold line at start_ring_rate; the first frame-length of the
    # axis is highlighted in red.
    plt.axhline(y=start_ring_rate, xmin=0.0, xmax=1.0, color="gray", linestyle="-")
    plt.axhline(y=start_ring_rate, xmin=0.0, xmax=frame_length_ms / 1000 / len(signal) * sample_rate, color="red", linestyle="-")
    # Mark segment boundaries: green for speech start, red for speech end.
    for start, end in vad_segments:
        plt.axvline(x=start, ymin=0.15, ymax=0.85, color="g", linestyle="--")
        plt.axvline(x=end, ymin=0.15, ymax=0.85, color="r", linestyle="--")

    temp_image_file = temp_directory / "temp.jpg"
    plt.savefig(temp_image_file)
    # Close the figure so repeated clicks don't accumulate open matplotlib figures.
    plt.close()
    image = Image.open(temp_image_file)
    return image, vad_segments


def main():
    args = get_args()

    brief_description = """
## Voice Activity Detection
"""
    # examples
    with open(args.ring_vad_examples_file, "r", encoding="utf-8") as f:
        ring_vad_examples = json.load(f)
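    # Each example row is assumed to match the `inputs` list passed to gr.Examples
    # below: [wav, model_name, agg, frame_length_ms, frame_step_ms, padding_length_ms,
    # max_silence_length_ms, start_ring_rate, end_ring_rate, max_speech_length_s,
    # min_speech_length_s].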
    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)
        with gr.Row():
            with gr.Column(scale=5):
                with gr.Tabs():
                    with gr.TabItem("ring_vad"):
                        gr.Markdown(value="")

                        with gr.Row():
                            with gr.Column(scale=1):
                                ring_wav = gr.Audio(label="wav")
                                with gr.Row():
                                    ring_model_name = gr.Dropdown(choices=["webrtcvad", "silerovad", "call_voice"], value="webrtcvad", label="model_name")
                                    ring_agg = gr.Dropdown(choices=[1, 2, 3], value=3, label="agg")
                                with gr.Row():
                                    ring_frame_length_ms = gr.Slider(minimum=0, maximum=1000, value=30, label="frame_length_ms")
                                    ring_frame_step_ms = gr.Slider(minimum=0, maximum=100, value=30, label="frame_step_ms")
                                with gr.Row():
                                    ring_padding_length_ms = gr.Slider(minimum=0, maximum=1000, value=300, label="padding_length_ms")
                                    ring_max_silence_length_ms = gr.Slider(minimum=0, maximum=1000, value=300, step=0.1, label="max_silence_length_ms")
                                with gr.Row():
                                    ring_start_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.05, label="start_ring_rate")
                                    ring_end_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="end_ring_rate")
                                with gr.Row():
                                    ring_max_speech_length_s = gr.Slider(minimum=0.0, maximum=10.0, value=2.0, step=0.05, label="max_speech_length_s")
                                    ring_min_speech_length_s = gr.Slider(minimum=0.0, maximum=2.0, value=0.3, step=0.05, label="min_speech_length_s")
                                ring_button = gr.Button("run", variant="primary")
                            with gr.Column(scale=1):
                                ring_image = gr.Image(label="image", height=300, width=720, show_label=False)
                                ring_end_points = gr.TextArea(label="end_points", max_lines=35)
                        gr.Examples(
                            examples=ring_vad_examples,
                            inputs=[
                                ring_wav,
                                ring_model_name, ring_agg,
                                ring_frame_length_ms, ring_frame_step_ms,
                                ring_padding_length_ms, ring_max_silence_length_ms,
                                ring_start_ring_rate, ring_end_ring_rate,
                                ring_max_speech_length_s, ring_min_speech_length_s
                            ],
                            outputs=[ring_image, ring_end_points],
                            fn=click_ring_vad_button
                        )
                        # click event
                        ring_button.click(
                            click_ring_vad_button,
                            inputs=[
                                ring_wav,
                                ring_model_name, ring_agg,
                                ring_frame_length_ms, ring_frame_step_ms,
                                ring_padding_length_ms, ring_max_silence_length_ms,
                                ring_start_ring_rate, ring_end_ring_rate,
                                ring_max_speech_length_s, ring_min_speech_length_s
                            ],
                            outputs=[ring_image, ring_end_points],
                        )
    blocks.queue().launch(
        share=False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=7860
    )
    return


if __name__ == "__main__":
    main()