# fairseq/examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import os
import csv
import tempfile
from collections import defaultdict
from pathlib import Path

import torchaudio
try:
    import webrtcvad
except ImportError:
    raise ImportError("Please install py-webrtcvad: pip install webrtcvad")
import pandas as pd
from tqdm import tqdm

from examples.speech_synthesis.preprocessing.denoiser.pretrained import master64
import examples.speech_synthesis.preprocessing.denoiser.utils as utils
from examples.speech_synthesis.preprocessing.vad import (
    frame_generator, vad_collector, read_wave, write_wave, FS_MS, THRESHOLD,
    SCALE
)
from examples.speech_to_text.data_utils import save_df_to_tsv


log = logging.getLogger(__name__)

PATHS = ["after_denoise", "after_vad"]
MIN_T = 0.05


def generate_tmp_filename(extension="txt"):
    # Build a unique path in the system temp dir; this relies on private
    # tempfile helpers and does not create the file itself.
    return tempfile._get_default_tempdir() + "/" + \
        next(tempfile._get_candidate_names()) + "." + extension


def convert_sr(inpath, sr, output_path=None):
    # Resample via the sox CLI, which must be installed and on PATH.
    if not output_path:
        output_path = generate_tmp_filename("wav")
    cmd = f"sox {inpath} -r {sr} {output_path}"
    os.system(cmd)
    return output_path
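
# Usage sketch (illustrative, not called by the pipeline): the filenames
# below are hypothetical.
#
#     resampled = convert_sr("clip_44k.wav", 16000)  # writes to a temp wav
#     convert_sr("clip_44k.wav", 16000, output_path="clip_16k.wav")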


def apply_vad(vad, inpath):
    audio, sample_rate = read_wave(inpath)
    frames = frame_generator(FS_MS, audio, sample_rate)
    frames = list(frames)
    segments = vad_collector(sample_rate, FS_MS, 300, vad, frames)
    merge_segments = list()
    timestamp_start = 0.0
    timestamp_end = 0.0
    # Trim leading/trailing silence and cap long internal silences at
    # THRESHOLD, padding shorter gaps with zero bytes to preserve timing.
    for i, segment in enumerate(segments):
        merge_segments.append(segment[0])
        if i and timestamp_start:
            sil_duration = segment[1] - timestamp_end
            if sil_duration > THRESHOLD:
                merge_segments.append(int(THRESHOLD / SCALE) * (b'\x00'))
            else:
                merge_segments.append(int((sil_duration / SCALE)) * (b'\x00'))
        timestamp_start = segment[1]
        timestamp_end = segment[2]
    segment = b''.join(merge_segments)
    return segment, sample_rate
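
# Usage sketch (illustrative): run VAD standalone on a wav at one of
# webrtcvad's supported rates (8/16/32/48 kHz); "speech_16k.wav" and
# "trimmed.wav" are hypothetical filenames.
#
#     vad = webrtcvad.Vad(2)  # aggressiveness 0 (least) to 3 (most)
#     segment, sr = apply_vad(vad, "speech_16k.wav")
#     write_wave("trimmed.wav", segment, sr)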


def write(wav, filename, sr=16_000):
    # Peak-normalize only when the signal would otherwise clip (peak > 1)
    wav = wav / max(wav.abs().max().item(), 1)
    torchaudio.save(filename, wav.cpu(), sr, encoding="PCM_S",
                    bits_per_sample=16)
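
# Usage sketch (illustrative): `write` expects a 2D (channels x samples)
# float tensor, as returned by torchaudio.load, and saves 16-bit PCM.
# "in.wav"/"out.wav" are hypothetical paths.
#
#     wav, sr = torchaudio.load("in.wav")
#     write(wav, "out.wav", sr)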


def process(args):
    # make sure at least one of denoise or VAD was requested
    if not args.denoise and not args.vad:
        log.error("No denoise or vad is requested.")
        return

    log.info("Creating out directories...")
    if args.denoise:
        out_denoise = Path(args.output_dir).absolute().joinpath(PATHS[0])
        out_denoise.mkdir(parents=True, exist_ok=True)
    if args.vad:
        out_vad = Path(args.output_dir).absolute().joinpath(PATHS[1])
        out_vad.mkdir(parents=True, exist_ok=True)
log.info("Loading pre-trained speech enhancement model...") | |
model = master64().to(args.device) | |
log.info("Building the VAD model...") | |
vad = webrtcvad.Vad(int(args.vad_agg_level)) | |
# preparing the output dict | |
output_dict = defaultdict(list) | |
log.info(f"Parsing input manifest: {args.audio_manifest}") | |
with open(args.audio_manifest, "r") as f: | |
manifest_dict = csv.DictReader(f, delimiter="\t") | |
for row in tqdm(manifest_dict): | |
filename = str(row["audio"]) | |
final_output = filename | |
keep_sample = True | |
n_frames = row["n_frames"] | |
snr = -1 | |
            if args.denoise:
                output_path_denoise = out_denoise.joinpath(Path(filename).name)
                # convert to 16kHz in case the source uses a different
                # sample rate
                tmp_path = convert_sr(final_output, 16000)
                # load the audio file and generate the enhanced version
                out, sr = torchaudio.load(tmp_path)
                out = out.to(args.device)
                estimate = model(out)
                estimate = (1 - args.dry_wet) * estimate + args.dry_wet * out
                write(estimate[0], str(output_path_denoise), sr)
                snr = utils.cal_snr(out, estimate)
                snr = snr.cpu().detach().numpy()[0][0]
                final_output = str(output_path_denoise)
            if args.vad:
                output_path_vad = out_vad.joinpath(Path(filename).name)
                sr = torchaudio.info(final_output).sample_rate
                # webrtcvad only accepts specific sample rates; resample to
                # the nearest supported rate at or above the input rate
                if sr in [16000, 32000, 48000]:
                    tmp_path = final_output
                elif sr < 16000:
                    tmp_path = convert_sr(final_output, 16000)
                elif sr < 32000:
                    tmp_path = convert_sr(final_output, 32000)
                else:
                    tmp_path = convert_sr(final_output, 48000)

                # apply VAD
                segment, sample_rate = apply_vad(vad, tmp_path)
                if len(segment) < sample_rate * MIN_T:
                    keep_sample = False
                    print((
                        f"WARNING: skip {filename} because it is too short "
                        f"after VAD ({len(segment) / sample_rate} < {MIN_T})"
                    ))
                else:
                    if sample_rate != sr:
                        # convert back to the original sample rate
                        tmp_path = generate_tmp_filename("wav")
                        write_wave(tmp_path, segment, sample_rate)
                        convert_sr(tmp_path, sr,
                                   output_path=str(output_path_vad))
                    else:
                        write_wave(str(output_path_vad), segment, sample_rate)
                    final_output = str(output_path_vad)
                    segment, _ = torchaudio.load(final_output)
                    n_frames = segment.size(1)
            if keep_sample:
                output_dict["id"].append(row["id"])
                output_dict["audio"].append(final_output)
                output_dict["n_frames"].append(n_frames)
                output_dict["tgt_text"].append(row["tgt_text"])
                output_dict["speaker"].append(row["speaker"])
                output_dict["src_text"].append(row["src_text"])
                output_dict["snr"].append(snr)

    out_tsv_path = Path(args.output_dir) / Path(args.audio_manifest).name
    log.info(f"Saving manifest to {out_tsv_path.as_posix()}")
    save_df_to_tsv(pd.DataFrame.from_dict(output_dict), out_tsv_path)
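
# For reference, the input manifest is a TSV whose columns are read above:
# id, audio, n_frames, tgt_text, speaker, src_text. An illustrative row
# (values are made up):
#
#     id    audio           n_frames  tgt_text     speaker  src_text
#     utt1  /data/utt1.wav  52480     hello world  spk0     hello world
#
# The saved manifest keeps these columns and adds an `snr` column.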


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--audio-manifest", "-i", required=True,
                        type=str, help="path to the input manifest.")
    parser.add_argument(
        "--output-dir", "-o", required=True, type=str,
        help="path to the output dir. it will contain files after denoising "
             "and VAD"
    )
    parser.add_argument("--vad-agg-level", "-a", type=int, default=2,
                        help="the aggressiveness level of the VAD [0-3].")
    parser.add_argument(
        "--dry-wet", "-dw", type=float, default=0.01,
        help="the level of linear interpolation between noisy and enhanced "
             "files."
    )
    parser.add_argument(
        "--device", "-d", type=str, default="cpu",
        help="the device to be used for the speech enhancement model: "
             "cpu | cuda."
    )
    parser.add_argument("--denoise", action="store_true",
                        help="apply denoising")
    parser.add_argument("--vad", action="store_true", help="apply VAD")
    args = parser.parse_args()

    process(args)


if __name__ == "__main__":
    main()
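
# Example invocation (paths are illustrative):
#
#     python denoise_and_vad_audio.py \
#         --audio-manifest /data/train.tsv \
#         --output-dir /data/processed \
#         --denoise --vad --device cuda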