File size: 2,767 Bytes
8c92a11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
import torch
from models.vocoders.vocoder_inference import VocoderInference
from utils.util import load_config
def build_inference(args, cfg, infer_type="infer_from_dataset"):
    """Instantiate the inference class registered for ``cfg.model_type``.

    Args:
        args: Parsed command-line arguments, forwarded unchanged to the
            inference class constructor.
        cfg: Loaded configuration object; its ``model_type`` attribute
            selects which inference class is used.
        infer_type: Inference mode, forwarded unchanged to the inference
            class constructor.

    Returns:
        An instance of the inference class selected by ``cfg.model_type``.

    Raises:
        KeyError: If ``cfg.model_type`` has no registered inference class.
    """
    supported_inference = {
        "GANVocoder": VocoderInference,
        "DiffusionVocoder": VocoderInference,
    }
    try:
        inference_class = supported_inference[cfg.model_type]
    except KeyError:
        # Keep KeyError so existing callers still catch it, but say what is
        # actually supported instead of echoing only the bad key.
        raise KeyError(
            f"Unsupported model_type {cfg.model_type!r}; expected one of "
            f"{sorted(supported_inference)}"
        ) from None
    return inference_class(args, cfg, infer_type)
def cuda_relevant(deterministic=False):
    """Configure process-wide CUDA / cuDNN backend flags for inference.

    Releases unused cached CUDA memory, enables TF32 for matmul and cuDNN,
    and switches between deterministic execution and cuDNN autotuning.

    Args:
        deterministic: If True, force deterministic algorithms and disable
            cuDNN benchmark autotuning. If False, disable the deterministic
            requirement and enable cuDNN benchmark autotuning.
    """
    torch.cuda.empty_cache()
    # TF32 on Ampere and above
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.allow_tf32 = True
    # Deterministic
    if deterministic:
        # With use_deterministic_algorithms(True), cuBLAS ops on CUDA >= 10.2
        # raise RuntimeError unless this variable is set. setdefault keeps any
        # value the user already exported.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    torch.backends.cudnn.deterministic = deterministic
    torch.backends.cudnn.benchmark = not deterministic
    torch.use_deterministic_algorithms(deterministic)
def build_parser():
    r"""Build argument parser for inference.py.

    Only the options below are exposed on the command line. Anything else
    should be put in an extra config YAML file.

    Returns:
        argparse.ArgumentParser: Parser for the vocoder inference CLI.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="JSON/YAML file for configurations.",
    )
    parser.add_argument(
        "--infer_mode",
        type=str,
        required=False,
        help="Inference mode, passed to the inference class as infer_type "
        "(e.g. infer_from_dataset).",
    )
    parser.add_argument(
        "--infer_datasets",
        nargs="+",
        default=None,
        help="One or more dataset names to run inference on.",
    )
    parser.add_argument(
        "--feature_folder",
        type=str,
        default=None,
        help="Folder containing input features.",
    )
    parser.add_argument(
        "--audio_folder",
        type=str,
        default=None,
        help="Folder containing input audio files.",
    )
    parser.add_argument(
        "--vocoder_dir",
        type=str,
        required=True,
        help="Vocoder checkpoint directory. Searching behavior is the same as "
        "the acoustics one.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="result",
        help="Output directory. Default: ./result",
    )
    parser.add_argument(
        "--log_level",
        type=str,
        default="warning",
        help="Logging level. Default: warning",
    )
    parser.add_argument(
        "--keep_cache",
        action="store_true",
        default=False,
        help="Keep cache files. Only applicable to inference from files.",
    )
    return parser
def main():
    """Command-line entry point for vocoder inference.

    Parses the CLI, loads the configuration file named by ``--config``,
    applies the CUDA backend settings, then builds the inference object for
    the configured model type and runs it.
    """
    cli_args = build_parser().parse_args()
    config = load_config(cli_args.config)

    # Backend flags are applied before the inference object is constructed.
    cuda_relevant()

    inferencer = build_inference(cli_args, config, cli_args.infer_mode)
    inferencer.inference()


if __name__ == "__main__":
    main()
|