conex / espnet /bin /asr_enhance.py
tobiasc's picture
Initial commit
ad16788
raw
history blame contribute delete
No virus
5.78 kB
#!/usr/bin/env python3
import configargparse
from distutils.util import strtobool
import logging
import os
import random
import sys
import numpy as np
from espnet.asr.pytorch_backend.asr import enhance
# NOTE: you need this func to generate our sphinx doc
def get_parser():
parser = configargparse.ArgumentParser(
description="Enhance noisy speech for speech recognition",
config_file_parser_class=configargparse.YAMLConfigFileParser,
formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
)
# general configuration
parser.add("--config", is_config_file=True, help="config file path")
parser.add(
"--config2",
is_config_file=True,
help="second config file path that overwrites the settings in `--config`.",
)
parser.add(
"--config3",
is_config_file=True,
help="third config file path that overwrites the settings "
"in `--config` and `--config2`.",
)
parser.add_argument("--ngpu", default=0, type=int, help="Number of GPUs")
parser.add_argument(
"--backend",
default="chainer",
type=str,
choices=["chainer", "pytorch"],
help="Backend library",
)
parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
parser.add_argument("--seed", default=1, type=int, help="Random seed")
parser.add_argument("--verbose", "-V", default=1, type=int, help="Verbose option")
parser.add_argument(
"--batchsize",
default=1,
type=int,
help="Batch size for beam search (0: means no batch processing)",
)
parser.add_argument(
"--preprocess-conf",
type=str,
default=None,
help="The configuration file for the pre-processing",
)
# task related
parser.add_argument(
"--recog-json", type=str, help="Filename of recognition data (json)"
)
# model (parameter) related
parser.add_argument(
"--model", type=str, required=True, help="Model file parameters to read"
)
parser.add_argument(
"--model-conf", type=str, default=None, help="Model config file"
)
# Outputs configuration
parser.add_argument(
"--enh-wspecifier",
type=str,
default=None,
help="Specify the output way for enhanced speech."
"e.g. ark,scp:outdir,wav.scp",
)
parser.add_argument(
"--enh-filetype",
type=str,
default="sound",
choices=["mat", "hdf5", "sound.hdf5", "sound"],
help="Specify the file format for enhanced speech. "
'"mat" is the matrix format in kaldi',
)
parser.add_argument("--fs", type=int, default=16000, help="The sample frequency")
parser.add_argument(
"--keep-length",
type=strtobool,
default=True,
help="Adjust the output length to match " "with the input for enhanced speech",
)
parser.add_argument(
"--image-dir", type=str, default=None, help="The directory saving the images."
)
parser.add_argument(
"--num-images",
type=int,
default=20,
help="The number of images files to be saved. "
"If negative, all samples are to be saved.",
)
# IStft
parser.add_argument(
"--apply-istft",
type=strtobool,
default=True,
help="Apply istft to the output from the network",
)
parser.add_argument(
"--istft-win-length",
type=int,
default=512,
help="The window length for istft. "
"This option is ignored "
"if stft is found in the preprocess-conf",
)
parser.add_argument(
"--istft-n-shift",
type=str,
default=256,
help="The window type for istft. "
"This option is ignored "
"if stft is found in the preprocess-conf",
)
parser.add_argument(
"--istft-window",
type=str,
default="hann",
help="The window type for istft. "
"This option is ignored "
"if stft is found in the preprocess-conf",
)
return parser
def main(args):
parser = get_parser()
args = parser.parse_args(args)
# logging info
if args.verbose == 1:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
elif args.verbose == 2:
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
else:
logging.basicConfig(
level=logging.WARN,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
logging.warning("Skip DEBUG/INFO messages")
# check CUDA_VISIBLE_DEVICES
if args.ngpu > 0:
cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
if cvd is None:
logging.warning("CUDA_VISIBLE_DEVICES is not set.")
elif args.ngpu != len(cvd.split(",")):
logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
sys.exit(1)
# TODO(kamo): support of multiple GPUs
if args.ngpu > 1:
logging.error("The program only supports ngpu=1.")
sys.exit(1)
# display PYTHONPATH
logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
# seed setting
random.seed(args.seed)
np.random.seed(args.seed)
logging.info("set random seed = %d" % args.seed)
# recog
logging.info("backend = " + args.backend)
if args.backend == "pytorch":
enhance(args)
else:
raise ValueError("Only pytorch is supported.")
if __name__ == "__main__":
main(sys.argv[1:])