conex / espnet /bin /mt_trans.py
tobiasc's picture
Initial commit
ad16788
raw
history blame contribute delete
No virus
5.96 kB
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2019 Kyoto University (Hirofumi Inaguma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Neural machine translation model decoding script."""
import configargparse
import logging
import os
import random
import sys
import numpy as np
# NOTE: this function is needed to generate our Sphinx documentation.
def get_parser():
    """Build the command-line / config-file parser for MT decoding.

    Options can be given on the command line or through up to three YAML
    config files (``--config``, ``--config2``, ``--config3``).

    Returns:
        configargparse.ArgumentParser: Parser with the general, task, model,
        search and multilingual options registered.
    """
    parser = configargparse.ArgumentParser(
        description="Translate text "
        "using a machine translation model on one CPU or GPU",
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
    )
    # general configuration
    parser.add("--config", is_config_file=True, help="Config file path")
    parser.add(
        "--config2",
        is_config_file=True,
        help="Second config file path that overwrites the settings in `--config`",
    )
    parser.add(
        "--config3",
        is_config_file=True,
        help="Third config file path "
        "that overwrites the settings in `--config` and `--config2`",
    )
    parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
    parser.add_argument(
        "--dtype",
        choices=("float16", "float32", "float64"),
        default="float32",
        help="Float precision (only available in --api v2)",
    )
    # NOTE(review): the default stays "chainer" for compatibility, but main()
    # in this file raises ValueError for every backend other than "pytorch",
    # so callers have to pass `--backend pytorch` explicitly.
    parser.add_argument(
        "--backend",
        type=str,
        default="chainer",
        choices=["chainer", "pytorch"],
        help="Backend library",
    )
    parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
    parser.add_argument("--seed", type=int, default=1, help="Random seed")
    parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option")
    parser.add_argument(
        "--batchsize",
        type=int,
        default=1,
        help="Batch size for beam search (0: means no batch processing)",
    )
    parser.add_argument(
        "--preprocess-conf",
        type=str,
        default=None,
        help="The configuration file for the pre-processing",
    )
    parser.add_argument(
        "--api",
        default="v1",
        choices=["v1", "v2"],
        help="Beam search APIs "
        "v1: Default API. It only supports "
        "the ASRInterface.recognize method and DefaultRNNLM. "
        "v2: Experimental API. "
        "It supports any models that implements ScorerInterface.",
    )
    # task related
    parser.add_argument(
        "--trans-json", type=str, help="Filename of translation data (json)"
    )
    parser.add_argument(
        "--result-label",
        type=str,
        required=True,
        help="Filename of result label data (json)",
    )
    # model (parameter) related
    parser.add_argument(
        "--model", type=str, required=True, help="Model file parameters to read"
    )
    parser.add_argument(
        "--model-conf", type=str, default=None, help="Model config file"
    )
    # search related
    parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
    parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
    parser.add_argument("--penalty", type=float, default=0.1, help="Insertion penalty")
    # The help text must not call 0.0 the default: the default is 3.0 and the
    # formatter appends "(default: 3.0)" on its own.
    parser.add_argument(
        "--maxlenratio",
        type=float,
        default=3.0,
        help="Input length ratio to obtain max output length. "
        "If maxlenratio=0.0, it uses an end-detect function "
        "to automatically find maximum hypothesis lengths",
    )
    parser.add_argument(
        "--minlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain min output length",
    )
    # multilingual related
    # NOTE(review): default is the bool False although type=str; kept as-is
    # because the decoding code that consumes it is not visible here.
    parser.add_argument(
        "--tgt-lang",
        default=False,
        type=str,
        help="target language ID (e.g., <en>, <de>, and <fr> etc.)",
    )
    return parser
def main(args):
    """Parse the given arguments and run translation decoding.

    Args:
        args: Command-line argument strings, without the program name.

    Raises:
        NotImplementedError: If ``--dtype`` is not ``float32``.
        ValueError: If ``--backend`` is not ``pytorch``.
        SystemExit: If ``--ngpu`` does not match ``CUDA_VISIBLE_DEVICES`` or
            is greater than 1.
    """
    args = get_parser().parse_args(args)

    # logging info: verbose 1 -> INFO, 2 -> DEBUG, anything else -> WARN
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    verbose_levels = {1: logging.INFO, 2: logging.DEBUG}
    reduced_output = args.verbose not in verbose_levels
    logging.basicConfig(
        level=verbose_levels.get(args.verbose, logging.WARN),
        format=log_format,
    )
    if reduced_output:
        logging.warning("Skip DEBUG/INFO messages")

    # check CUDA_VISIBLE_DEVICES
    if args.ngpu > 0:
        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        if visible_devices is None:
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
        elif len(visible_devices.split(",")) != args.ngpu:
            logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
            sys.exit(1)

        # TODO(mn5k): support of multiple GPUs
        if args.ngpu > 1:
            logging.error("The program only supports ngpu=1.")
            sys.exit(1)

    # display PYTHONPATH
    python_path = os.environ.get("PYTHONPATH", "(None)")
    logging.info("python path = " + python_path)

    # seed setting: both Python's and NumPy's generators
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    logging.info("set random seed = %d" % seed)

    # trans
    logging.info("backend = " + args.backend)
    if args.backend != "pytorch":
        raise ValueError("Only pytorch are supported.")

    # Imported only once the backend is known to be pytorch.
    from espnet.mt.pytorch_backend.mt import trans

    if args.dtype != "float32":
        raise NotImplementedError(
            f"`--dtype {args.dtype}` is only available with `--api v2`"
        )
    trans(args)
# Script entry point: pass the command-line arguments, minus the program
# name, to main().
if __name__ == "__main__":
    main(sys.argv[1:])