|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
|
import torch.hub |
|
|
|
from .demucs import Demucs |
|
from .utils import deserialize_model |
|
|
|
logger = logging.getLogger(__name__)

# Base URL of the public bucket that hosts the pre-trained denoiser weights.
ROOT = "https://dl.fbaipublicfiles.com/adiyoss/denoiser/"

# Per-model checkpoint URLs. The hex suffix appears to follow torch.hub's
# "name-<sha256 prefix>.ext" naming convention; note that hash checking is
# not enabled by the loader in this module.
DNS_48_URL = f"{ROOT}dns48-11decc9d8e3f0998.th"
DNS_64_URL = f"{ROOT}dns64-a7761ff99a7d5bb6.th"
MASTER_64_URL = f"{ROOT}master64-8a5dfb4bb92753dd.th"
|
|
|
|
|
def _demucs(pretrained, url, **kwargs):
    """Construct a ``Demucs`` model, optionally loading published weights.

    Args:
        pretrained: when truthy, download the state dict at ``url`` with
            ``torch.hub.load_state_dict_from_url`` (mapped onto the CPU) and
            load it into the freshly built model.
        url: location of the checkpoint state dict.
        **kwargs: forwarded unchanged to the ``Demucs`` constructor.

    Returns:
        The ``Demucs`` instance.
    """
    net = Demucs(**kwargs)
    if not pretrained:
        return net
    weights = torch.hub.load_state_dict_from_url(url, map_location='cpu')
    net.load_state_dict(weights)
    return net
|
|
|
|
|
def dns48(pretrained=True):
    """Return the real-time H=48 Demucs model trained on DNS.

    Args:
        pretrained: if true (the default), load the published weights from
            ``DNS_48_URL``; otherwise skip loading any weights.
    """
    return _demucs(pretrained=pretrained, url=DNS_48_URL, hidden=48)
|
|
|
|
|
def dns64(pretrained=True):
    """Return the real-time H=64 Demucs model trained on DNS.

    Args:
        pretrained: if true (the default), load the published weights from
            ``DNS_64_URL``; otherwise skip loading any weights.
    """
    return _demucs(pretrained=pretrained, url=DNS_64_URL, hidden=64)
|
|
|
|
|
def master64(pretrained=True):
    """Return the real-time H=64 Demucs model trained on DNS and Valentini.

    Args:
        pretrained: if true (the default), load the published weights from
            ``MASTER_64_URL``; otherwise skip loading any weights.
    """
    return _demucs(pretrained=pretrained, url=MASTER_64_URL, hidden=64)
|
|
|
|
|
def add_model_flags(parser):
    """Register the model-selection options on an ``argparse`` parser.

    All options share one optional, mutually exclusive group, so at most one
    of ``-m/--model_path``, ``--dns48``, ``--dns64`` and ``--master64`` may be
    passed on a single command line.

    Args:
        parser: the ``argparse.ArgumentParser`` to extend in place.
    """
    choices = parser.add_mutually_exclusive_group(required=False)
    choices.add_argument("-m", "--model_path", help="Path to local trained model.")

    # (flag, help text) for every pre-trained checkpoint switch, in the order
    # they should appear in --help output.
    pretrained_switches = (
        ("--dns48", "Use pre-trained real time H=48 model trained on DNS."),
        ("--dns64", "Use pre-trained real time H=64 model trained on DNS."),
        ("--master64", "Use pre-trained real time H=64 model trained on DNS and Valentini."),
    )
    for flag, description in pretrained_switches:
        choices.add_argument(flag, action="store_true", help=description)
|
|
|
|
|
def _load_local_model(path):
    """Rebuild a model from a file on disk that was written with ``torch.save``.

    Accepts a serialized model package, a training checkpoint wrapping such a
    package, or a bare state dict (loaded into ``Demucs(hidden=64)``).
    """
    # NOTE(review): torch.load unpickles arbitrary Python objects; only pass
    # files from a trusted source here.
    pkg = torch.load(path, map_location='cpu')
    if isinstance(pkg, dict) and 'model' in pkg:
        # Training checkpoint whose 'model' entry is a serialized package.
        # NOTE(review): key names ('model', 'best_state', 'state') are assumed
        # to match the training code's checkpoint layout -- confirm.
        package = pkg['model']
        if 'best_state' in pkg:
            # Prefer the best weights recorded during training when present.
            package['state'] = pkg['best_state']
        return deserialize_model(package)
    if isinstance(pkg, dict) and 'class' in pkg:
        # Serialized package: deserialize_model rebuilds the saved
        # architecture, so e.g. an H=48 model is no longer forced into H=64.
        return deserialize_model(pkg)
    # Bare state dict: the architecture is not stored in the file, so keep the
    # historical default of a hidden=64 Demucs.
    model = Demucs(hidden=64)
    model.load_state_dict(pkg)
    return model


def get_model(args):
    """
    Load local model package or torchhub pre-trained model.

    Args:
        args: parsed namespace carrying the options registered by
            ``add_model_flags``. ``model_path`` takes precedence, then
            ``dns64``, then ``master64``; anything else (including
            ``--dns48`` or no flag at all) selects the H=48 DNS model.

    Returns:
        The loaded model.
    """
    if args.model_path:
        logger.info("Loading model from %s", args.model_path)
        model = _load_local_model(args.model_path)
    elif args.dns64:
        logger.info("Loading pre-trained real time H=64 model trained on DNS.")
        model = dns64()
    elif args.master64:
        logger.info("Loading pre-trained real time H=64 model trained on DNS and Valentini.")
        model = master64()
    else:
        logger.info("Loading pre-trained real time H=48 model trained on DNS.")
        model = dns48()
    logger.debug(model)
    return model
|
|