File size: 5,778 Bytes
ad16788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
#!/usr/bin/env python3
import configargparse
from distutils.util import strtobool
import logging
import os
import random
import sys

import numpy as np

from espnet.asr.pytorch_backend.asr import enhance


# NOTE: this function is required to generate our Sphinx documentation.
def get_parser():
    """Build the command-line / YAML-config parser for speech enhancement.

    Returns:
        configargparse.ArgumentParser: A parser whose options may be given on
        the command line or through up to three YAML config files
        (``--config``, ``--config2``, ``--config3``). Later config files
        override earlier ones.
    """
    parser = configargparse.ArgumentParser(
        description="Enhance noisy speech for speech recognition",
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
    )
    # general configuration
    parser.add("--config", is_config_file=True, help="config file path")
    parser.add(
        "--config2",
        is_config_file=True,
        help="second config file path that overwrites the settings in `--config`.",
    )
    parser.add(
        "--config3",
        is_config_file=True,
        help="third config file path that overwrites the settings "
        "in `--config` and `--config2`.",
    )

    parser.add_argument("--ngpu", default=0, type=int, help="Number of GPUs")
    # NOTE(review): the default is "chainer", but main() only accepts
    # "pytorch", so callers must pass --backend pytorch explicitly.
    parser.add_argument(
        "--backend",
        default="chainer",
        type=str,
        choices=["chainer", "pytorch"],
        help="Backend library",
    )
    parser.add_argument("--debugmode", default=1, type=int, help="Debugmode")
    parser.add_argument("--seed", default=1, type=int, help="Random seed")
    parser.add_argument("--verbose", "-V", default=1, type=int, help="Verbose option")
    parser.add_argument(
        "--batchsize",
        default=1,
        type=int,
        help="Batch size for beam search (0: means no batch processing)",
    )
    parser.add_argument(
        "--preprocess-conf",
        type=str,
        default=None,
        help="The configuration file for the pre-processing",
    )
    # task related
    parser.add_argument(
        "--recog-json", type=str, help="Filename of recognition data (json)"
    )
    # model (parameter) related
    parser.add_argument(
        "--model", type=str, required=True, help="Model file parameters to read"
    )
    parser.add_argument(
        "--model-conf", type=str, default=None, help="Model config file"
    )

    # Outputs configuration
    parser.add_argument(
        "--enh-wspecifier",
        type=str,
        default=None,
        # Trailing space keeps the two sentences apart after concatenation.
        help="Specify the output way for enhanced speech. "
        "e.g. ark,scp:outdir,wav.scp",
    )
    parser.add_argument(
        "--enh-filetype",
        type=str,
        default="sound",
        choices=["mat", "hdf5", "sound.hdf5", "sound"],
        help="Specify the file format for enhanced speech. "
        '"mat" is the matrix format in kaldi',
    )
    parser.add_argument("--fs", type=int, default=16000, help="The sample frequency")
    # NOTE(review): distutils.util.strtobool is deprecated (PEP 632) and the
    # distutils package is gone in Python 3.12; consider a local replacement.
    parser.add_argument(
        "--keep-length",
        type=strtobool,
        default=True,
        help="Adjust the output length to match with the input for enhanced speech",
    )
    parser.add_argument(
        "--image-dir", type=str, default=None, help="The directory saving the images."
    )
    parser.add_argument(
        "--num-images",
        type=int,
        default=20,
        help="The number of images files to be saved. "
        "If negative, all samples are to be saved.",
    )

    # IStft
    parser.add_argument(
        "--apply-istft",
        type=strtobool,
        default=True,
        help="Apply istft to the output from the network",
    )
    parser.add_argument(
        "--istft-win-length",
        type=int,
        default=512,
        help="The window length for istft. "
        "This option is ignored "
        "if stft is found in the preprocess-conf",
    )
    # type=int so that a user-supplied shift is numeric like the default;
    # with type=str, "--istft-n-shift 128" would arrive as the string "128"
    # while the untouched default stayed the int 256.
    parser.add_argument(
        "--istft-n-shift",
        type=int,
        default=256,
        help="The window shift for istft. "
        "This option is ignored "
        "if stft is found in the preprocess-conf",
    )
    parser.add_argument(
        "--istft-window",
        type=str,
        default="hann",
        help="The window type for istft. "
        "This option is ignored "
        "if stft is found in the preprocess-conf",
    )
    return parser


def main(args):
    """Parse command-line arguments and run speech enhancement.

    Args:
        args (list of str): Command-line arguments without the program name
            (e.g. ``sys.argv[1:]``).

    Raises:
        ValueError: If ``--backend`` is anything other than ``"pytorch"``.

    The process exits with status 1 in two cases: ``--ngpu`` is positive and
    disagrees with the number of entries in ``CUDA_VISIBLE_DEVICES``, or more
    than one GPU is requested.
    """
    parser = get_parser()
    args = parser.parse_args(args)

    # logging info: --verbose 1 -> INFO, 2 -> DEBUG, anything else -> WARNING
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose == 1:
        log_level = logging.INFO
    elif args.verbose == 2:
        log_level = logging.DEBUG
    else:
        log_level = logging.WARNING
    logging.basicConfig(level=log_level, format=log_format)
    if log_level == logging.WARNING:
        logging.warning("Skip DEBUG/INFO messages")

    # check CUDA_VISIBLE_DEVICES
    if args.ngpu > 0:
        cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
        if cvd is None:
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
        elif args.ngpu != len(cvd.split(",")):
            logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
            sys.exit(1)

        # TODO(kamo): support of multiple GPUs
        if args.ngpu > 1:
            logging.error("The program only supports ngpu=1.")
            sys.exit(1)

    # display PYTHONPATH
    logging.info("python path = %s", os.environ.get("PYTHONPATH", "(None)"))

    # seed setting
    random.seed(args.seed)
    np.random.seed(args.seed)
    logging.info("set random seed = %d", args.seed)

    # recog
    logging.info("backend = %s", args.backend)
    if args.backend == "pytorch":
        enhance(args)
    else:
        raise ValueError("Only pytorch is supported.")


if __name__ == "__main__":
    # Hand every argument after the program name to the CLI entry point.
    cli_args = sys.argv[1:]
    main(cli_args)