qgyd2021's picture
update
7e17176
raw
history blame
6.51 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://huggingface.co/spaces/sayakpaul/demo-docker-gradio
"""
import argparse
import json
import platform
from typing import Tuple
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from project_settings import project_path, temp_directory
from toolbox.webrtcvad.vad import WebRTCVad
from toolbox.vad.vad import Vad, WebRTCVoiceClassifier, SileroVoiceClassifier
def get_args():
    """Parse the demo's command-line options.

    Returns:
        argparse.Namespace whose ``ring_vad_examples_file`` attribute is the
        path of the examples file read by ``main`` (defaults to
        ``ring_vad_examples.json`` under ``project_path``).
    """
    default_examples_file = (project_path / "ring_vad_examples.json").as_posix()

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--ring_vad_examples_file",
        type=str,
        default=default_examples_file,
    )
    return cli.parse_args()
# Most recently built detector. click_ring_vad_button rebinds it each time it
# constructs a new Vad; it stays None until that first happens.
vad: "Vad | None" = None
def click_ring_vad_button(audio: Tuple[int, np.ndarray],
                          model_name: str,
                          agg: int = 3,
                          frame_duration_ms: int = 30,
                          padding_duration_ms: int = 300,
                          silence_duration_threshold: float = 0.3,
                          start_ring_rate: float = 0.9,
                          end_ring_rate: float = 0.1,
                          ):
    """Detect voice segments in an uploaded clip and plot their endpoints.

    Gradio callback for the "retrieval" button and the examples table. It
    builds a fresh ``Vad`` around the selected classifier (also stored in the
    module-level ``vad``), runs it over the whole signal, appends the segments
    returned by ``last_vad_segments()``, and draws the waveform with every
    segment's start/end marked by vertical lines.

    Args:
        audio: ``(sample_rate, signal)`` as delivered by ``gr.Audio``, or
            ``None`` when nothing has been uploaded.
        model_name: ``"webrtcvad"`` or ``"silerovad"``; anything else is
            rejected with a message.
        agg: aggressiveness passed to ``WebRTCVoiceClassifier`` (unused for
            silerovad).
        frame_duration_ms: frame length passed to ``Vad``; for webrtcvad only
            10, 20 or 30 are accepted.
        padding_duration_ms: forwarded to ``Vad``.
        silence_duration_threshold: forwarded to ``Vad``.
        start_ring_rate: forwarded to ``Vad``.
        end_ring_rate: forwarded to ``Vad``.

    Returns:
        ``(image, vad_segments)`` on success, where ``image`` is the rendered
        plot and ``vad_segments`` the list of ``(start, end)`` pairs reported
        by ``Vad`` (plotted on the seconds axis, so presumably in seconds);
        ``(None, message)`` when the input is invalid or building/running the
        detector raises.
    """
    global vad

    if audio is None:
        return None, "please upload audio."
    sample_rate, signal = audio

    if model_name == "webrtcvad" and frame_duration_ms not in (10, 20, 30):
        return None, "only 10, 20, 30 available for `frame_duration_ms`."

    # Model loading and Vad construction may raise as well (presumably e.g. a
    # missing Silero checkpoint), so they share the detection's error boundary
    # and are reported to the UI as a message instead of an unhandled error.
    try:
        if model_name == "webrtcvad":
            model = WebRTCVoiceClassifier(agg=agg)
        elif model_name == "silerovad":
            model = SileroVoiceClassifier(
                model_name=(project_path / "pretrained_models/silero_vad/silero_vad.jit").as_posix()
            )
        else:
            return None, "`model_name` not valid."

        vad = Vad(model=model,
                  start_ring_rate=start_ring_rate,
                  end_ring_rate=end_ring_rate,
                  frame_duration_ms=frame_duration_ms,
                  padding_duration_ms=padding_duration_ms,
                  silence_duration_threshold=silence_duration_threshold,
                  sample_rate=sample_rate,
                  )
        # Segments found while processing the signal, followed by whatever
        # last_vad_segments() returns (presumably a segment still open at the
        # end of the input).
        vad_segments = list(vad.vad(signal))
        vad_segments += vad.last_vad_segments()
    except Exception as e:
        return None, str(e)

    timestamps = np.arange(0, len(signal)) / sample_rate
    temp_image_file = temp_directory / "temp.jpg"

    fig = plt.figure(figsize=(12, 5))
    try:
        # NOTE(review): dividing by 32768 assumes int16 PCM samples; confirm
        # gr.Audio always delivers int16 here.
        plt.plot(timestamps, signal / 32768, color='b')
        for start, end in vad_segments:
            # Mark the segment start ("开始端点" = "start endpoint").
            plt.axvline(x=start, ymin=0.25, ymax=0.75, color='g', linestyle='--', label='开始端点')
            # Mark the segment end ("结束端点" = "end endpoint").
            plt.axvline(x=end, ymin=0.25, ymax=0.75, color='r', linestyle='--', label='结束端点')
        plt.savefig(temp_image_file)
    finally:
        # pyplot keeps every figure alive until it is closed; without this the
        # long-running server would accumulate one figure per request.
        plt.close(fig)

    # Decode the JPEG fully into memory so the file handle can be released.
    with open(temp_image_file, "rb") as f:
        image = Image.open(f)
        image.load()
    return image, vad_segments
def main():
    """Build the ring VAD Gradio app and serve it.

    Loads the example rows from ``--ring_vad_examples_file`` (JSON), wires the
    controls to ``click_ring_vad_button`` for both the examples table and the
    "retrieval" button, and launches a queued app on port 7860, bound to
    127.0.0.1 on Windows and 0.0.0.0 elsewhere.
    """
    args = get_args()

    brief_description = """
## Voice Activity Detection
"""

    # examples
    with open(args.ring_vad_examples_file, "r", encoding="utf-8") as f:
        ring_vad_examples = json.load(f)

    # ui
    with gr.Blocks() as blocks:
        gr.Markdown(value=brief_description)
        with gr.Row():
            with gr.Column(scale=5):
                with gr.Tabs():
                    with gr.TabItem("ring_vad"):
                        gr.Markdown(value="")
                        with gr.Row():
                            with gr.Column(scale=1):
                                ring_wav = gr.Audio(label="wav")
                                with gr.Row():
                                    ring_model_name = gr.Dropdown(choices=["webrtcvad", "silerovad"], value="webrtcvad", label="model_name")
                                with gr.Row():
                                    ring_agg = gr.Dropdown(choices=[1, 2, 3], value=3, label="agg")
                                    ring_frame_duration_ms = gr.Slider(minimum=0, maximum=100, value=30, label="frame_duration_ms")
                                with gr.Row():
                                    ring_padding_duration_ms = gr.Slider(minimum=0, maximum=1000, value=300, label="padding_duration_ms")
                                    ring_silence_duration_threshold = gr.Slider(minimum=0, maximum=1.0, value=0.3, step=0.1, label="silence_duration_threshold")
                                with gr.Row():
                                    ring_start_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.1, label="start_ring_rate")
                                    ring_end_ring_rate = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.1, label="end_ring_rate")
                                ring_button = gr.Button("retrieval", variant="primary")
                            with gr.Column(scale=1):
                                ring_image = gr.Image(label="image", height=300, width=720, show_label=False)
                                ring_end_points = gr.TextArea(label="end_points", max_lines=35)

                        # Component order matches click_ring_vad_button's
                        # parameters; shared by the examples table and the button.
                        ring_inputs = [
                            ring_wav,
                            ring_model_name, ring_agg, ring_frame_duration_ms,
                            ring_padding_duration_ms, ring_silence_duration_threshold,
                            ring_start_ring_rate, ring_end_ring_rate,
                        ]
                        ring_outputs = [ring_image, ring_end_points]

                        gr.Examples(
                            examples=ring_vad_examples,
                            inputs=ring_inputs,
                            outputs=ring_outputs,
                            fn=click_ring_vad_button,
                        )

                        # click event
                        ring_button.click(
                            click_ring_vad_button,
                            inputs=ring_inputs,
                            outputs=ring_outputs,
                        )

    blocks.queue().launch(
        share=False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=7860,
    )


if __name__ == "__main__":
    main()