JustinLin610's picture
first commit
ee21b96
raw
history blame
2.38 kB
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez
import logging
import torch.hub
from .demucs import Demucs
from .utils import deserialize_model
logger = logging.getLogger(__name__)
ROOT = "https://dl.fbaipublicfiles.com/adiyoss/denoiser/"
DNS_48_URL = ROOT + "dns48-11decc9d8e3f0998.th"
DNS_64_URL = ROOT + "dns64-a7761ff99a7d5bb6.th"
MASTER_64_URL = ROOT + "master64-8a5dfb4bb92753dd.th"
def _demucs(pretrained, url, **kwargs):
model = Demucs(**kwargs)
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu')
model.load_state_dict(state_dict)
return model
def dns48(pretrained=True):
return _demucs(pretrained, DNS_48_URL, hidden=48)
def dns64(pretrained=True):
return _demucs(pretrained, DNS_64_URL, hidden=64)
def master64(pretrained=True):
return _demucs(pretrained, MASTER_64_URL, hidden=64)
def add_model_flags(parser):
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument(
"-m", "--model_path", help="Path to local trained model."
)
group.add_argument(
"--dns48", action="store_true",
help="Use pre-trained real time H=48 model trained on DNS."
)
group.add_argument(
"--dns64", action="store_true",
help="Use pre-trained real time H=64 model trained on DNS."
)
group.add_argument(
"--master64", action="store_true",
help="Use pre-trained real time H=64 model trained on DNS and Valentini."
)
def get_model(args):
"""
Load local model package or torchhub pre-trained model.
"""
if args.model_path:
logger.info("Loading model from %s", args.model_path)
pkg = torch.load(args.model_path)
model = deserialize_model(pkg)
elif args.dns64:
logger.info("Loading pre-trained real time H=64 model trained on DNS.")
model = dns64()
elif args.master64:
logger.info(
"Loading pre-trained real time H=64 model trained on DNS and Valentini."
)
model = master64()
else:
logger.info("Loading pre-trained real time H=48 model trained on DNS.")
model = dns48()
logger.debug(model)
return model