kkvc
up
5049dc3
import argparse
import os
import shutil
from pathlib import Path
import soundfile as sf
import torch
from tqdm import tqdm
from common.log import logger
from common.stdout_wrapper import SAFE_STDOUT
vad_model, utils = torch.hub.load(
repo_or_dir="snakers4/silero-vad",
model="silero_vad",
onnx=True,
trust_repo=True,
)
(get_speech_timestamps, _, read_audio, *_) = utils
def get_stamps(
audio_file, min_silence_dur_ms: int = 700, min_sec: float = 2, max_sec: float = 12
):
"""
min_silence_dur_ms: int (ミリ秒):
このミリ秒数以上を無音だと判断する。
逆に、この秒数以下の無音区間では区切られない。
小さくすると、音声がぶつ切りに小さくなりすぎ、
大きくすると音声一つ一つが長くなりすぎる。
データセットによってたぶん要調整。
min_sec: float (秒):
この秒数より小さい発話は無視する。
max_sec: float (秒):
この秒数より大きい発話は無視する。
"""
sampling_rate = 16000 # 16kHzか8kHzのみ対応
min_ms = int(min_sec * 1000)
wav = read_audio(audio_file, sampling_rate=sampling_rate)
speech_timestamps = get_speech_timestamps(
wav,
vad_model,
sampling_rate=sampling_rate,
min_silence_duration_ms=min_silence_dur_ms,
min_speech_duration_ms=min_ms,
max_speech_duration_s=max_sec,
)
return speech_timestamps
def split_wav(
audio_file,
target_dir="raw",
min_sec=2,
max_sec=12,
min_silence_dur_ms=700,
):
margin = 200 # ミリ秒単位で、音声の前後に余裕を持たせる
speech_timestamps = get_stamps(
audio_file,
min_silence_dur_ms=min_silence_dur_ms,
min_sec=min_sec,
max_sec=max_sec,
)
data, sr = sf.read(audio_file)
total_ms = len(data) / sr * 1000
file_name = os.path.basename(audio_file).split(".")[0]
os.makedirs(target_dir, exist_ok=True)
total_time_ms = 0
# タイムスタンプに従って分割し、ファイルに保存
for i, ts in enumerate(speech_timestamps):
start_ms = max(ts["start"] / 16 - margin, 0)
end_ms = min(ts["end"] / 16 + margin, total_ms)
start_sample = int(start_ms / 1000 * sr)
end_sample = int(end_ms / 1000 * sr)
segment = data[start_sample:end_sample]
sf.write(os.path.join(target_dir, f"{file_name}-{i}.wav"), segment, sr)
total_time_ms += end_ms - start_ms
return total_time_ms / 1000
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--min_sec", "-m", type=float, default=2, help="Minimum seconds of a slice"
)
parser.add_argument(
"--max_sec", "-M", type=float, default=12, help="Maximum seconds of a slice"
)
parser.add_argument(
"--input_dir",
"-i",
type=str,
default="inputs",
help="Directory of input wav files",
)
parser.add_argument(
"--output_dir",
"-o",
type=str,
default="raw",
help="Directory of output wav files",
)
parser.add_argument(
"--min_silence_dur_ms",
"-s",
type=int,
default=700,
help="Silence above this duration (ms) is considered as a split point.",
)
args = parser.parse_args()
input_dir = args.input_dir
output_dir = args.output_dir
min_sec = args.min_sec
max_sec = args.max_sec
min_silence_dur_ms = args.min_silence_dur_ms
wav_files = Path(input_dir).glob("**/*.wav")
wav_files = list(wav_files)
logger.info(f"Found {len(wav_files)} wav files.")
if os.path.exists(output_dir):
logger.warning(f"Output directory {output_dir} already exists, deleting...")
shutil.rmtree(output_dir)
total_sec = 0
for wav_file in tqdm(wav_files, file=SAFE_STDOUT):
time_sec = split_wav(
audio_file=str(wav_file),
target_dir=output_dir,
min_sec=min_sec,
max_sec=max_sec,
min_silence_dur_ms=min_silence_dur_ms,
)
total_sec += time_sec
logger.info(f"Slice done! Total time: {total_sec / 60:.2f} min.")