File size: 7,508 Bytes
7bc29af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import sys
from pathlib import Path
import subprocess
import julius
import torch as th
import torchaudio as ta
from .audio import AudioFile, convert_audio_channels
from .pretrained import is_pretrained, load_pretrained
from .utils import apply_model, load_model
def load_track(track, device, audio_channels, samplerate):
    """Load an audio file as a tensor on `device`, or terminate the process.

    Two backends are tried in order:

    1. `AudioFile(track).read(...)` (the ffmpeg-backed reader, judging by the
       error messages recorded for it below).
    2. `torchaudio.load`, whose output is then remixed with
       `convert_audio_channels` and resampled with `julius.resample_frac`.

    Args:
        track: path of the audio file to read.
        device: device the returned tensor is moved to.
        audio_channels: number of channels requested from the backends.
        samplerate: sample rate requested from the backends.

    Returns:
        The decoded audio tensor, located on `device`.

    Raises:
        SystemExit: with status 1 when neither backend could read the file.
            The error collected for each backend is reported on stderr first.
    """
    errors = {}
    wav = None

    try:
        wav = AudioFile(track).read(
            streams=0,
            samplerate=samplerate,
            channels=audio_channels).to(device)
    except FileNotFoundError:
        # NOTE(review): presumably raised when the ffmpeg executable cannot be
        # spawned rather than when `track` is missing -- confirm in `AudioFile`.
        errors['ffmpeg'] = 'Ffmpeg is not installed.'
    except subprocess.CalledProcessError:
        errors['ffmpeg'] = 'FFmpeg could not read the file.'

    if wav is None:
        # First backend failed: fall back to torchaudio.
        try:
            wav, sr = ta.load(str(track))
        except RuntimeError as err:
            # A RuntimeError may carry no arguments; never let the error
            # reporting itself fail with an IndexError.
            errors['torchaudio'] = err.args[0] if err.args else repr(err)
        else:
            wav = convert_audio_channels(wav, audio_channels)
            wav = wav.to(device)
            wav = julius.resample_frac(wav, sr, samplerate)

    if wav is None:
        # Diagnostics go to stderr, like every other failure path in this file.
        print(f"Could not load file {track}. "
              "Maybe it is not a supported file format? ", file=sys.stderr)
        for backend, error in errors.items():
            print(f"When trying to load using {backend}, got the following error: {error}",
                  file=sys.stderr)
        sys.exit(1)
    return wav
def encode_mp3(wav, path, bitrate=320, samplerate=44100, channels=2, verbose=False):
    """Encode `wav` with the `lameenc` encoder and write the result to `path`.

    Args:
        wav: tensor holding the audio; it is transposed (dims 0 and 1 swapped),
            converted to a NumPy array and handed to the encoder as raw bytes.
            NOTE(review): presumably int16 samples laid out as
            (channels, samples) -- confirm against the caller.
        path: destination of the MP3 file, opened in binary write mode.
        bitrate: bit rate given to the encoder.
        samplerate: input sample rate given to the encoder.
        channels: channel count given to the encoder.
        verbose: when false, the encoder is silenced.

    Raises:
        SystemExit: with status 1 when `lameenc` cannot be imported; an
            installation hint is printed on stderr first.
    """
    try:
        import lameenc
    except ImportError:
        install_hint = ("Failed to call lame encoder. Maybe it is not installed? "
                        "On windows, run `python.exe -m pip install -U lameenc`, "
                        "on OSX/Linux, run `python3 -m pip install -U lameenc`, "
                        "then try again.")
        print(install_hint, file=sys.stderr)
        sys.exit(1)

    lame = lameenc.Encoder()
    lame.set_bit_rate(bitrate)
    lame.set_in_sample_rate(samplerate)
    lame.set_channels(channels)
    lame.set_quality(2)  # 2-highest, 7-fastest
    if not verbose:
        lame.silence()

    # Swap the two dims before serialising so the byte stream is ordered along
    # what was originally dim 1.
    raw_pcm = wav.transpose(0, 1).numpy().tobytes()
    # Encoded frames first, then whatever the encoder still buffers.
    chunks = [lame.encode(raw_pcm), lame.flush()]
    with open(path, "wb") as sink:
        sink.write(b"".join(chunks))
def _build_parser():
    """Create the argument parser of the `demucs.separate` command line."""
    parser = argparse.ArgumentParser("demucs.separate",
                                     description="Separate the sources for the given tracks")
    parser.add_argument("tracks", nargs='+', type=Path, default=[], help='Path to tracks')
    parser.add_argument("-n",
                        "--name",
                        default="demucs_quantized",
                        help="Model name. See README.md for the list of pretrained models. "
                             "Default is demucs_quantized.")
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument("-o",
                        "--out",
                        type=Path,
                        default=Path("separated"),
                        help="Folder where to put extracted tracks. A subfolder "
                             "with the model name will be created.")
    parser.add_argument("--models",
                        type=Path,
                        default=Path("models"),
                        help="Path to trained models. "
                             "Also used to store downloaded pretrained models")
    parser.add_argument("-d",
                        "--device",
                        default="cuda" if th.cuda.is_available() else "cpu",
                        help="Device to use, default is cuda if available else cpu")
    parser.add_argument("--shifts",
                        default=0,
                        type=int,
                        help="Number of random shifts for equivariant stabilization. "
                             "Increase separation time but improves quality for Demucs. "
                             "10 was used in the original paper.")
    parser.add_argument("--overlap",
                        default=0.25,
                        type=float,
                        help="Overlap between the splits.")
    parser.add_argument("--no-split",
                        action="store_false",
                        dest="split",
                        default=True,
                        help="Doesn't split audio in chunks. This can use large amounts of memory.")
    parser.add_argument("--float32",
                        action="store_true",
                        help="Convert the output wavefile to use pcm f32 format instead of s16. "
                             "This should not make a difference if you just plan on listening "
                             "to the audio but might be needed to compute exactly metrics like "
                             "SDR etc.")
    # Shares its destination with --float32; the explicit default keeps
    # `float32` False unless --float32 is actually passed.
    parser.add_argument("--int16",
                        action="store_false",
                        dest="float32",
                        default=False,
                        help="Opposite of --float32, here for compatibility.")
    parser.add_argument("--mp3", action="store_true",
                        help="Convert the output wavs to mp3.")
    parser.add_argument("--mp3-bitrate",
                        default=320,
                        type=int,
                        help="Bitrate of converted mp3.")
    return parser


def _load_requested_model(args):
    """Return the model named `args.name`: a local `.th` file first, else a pretrained one."""
    model_path = args.models / (args.name + ".th")
    if model_path.is_file():
        return load_model(model_path)
    if is_pretrained(args.name):
        return load_pretrained(args.name)
    print(f"No pre-trained model {args.name}", file=sys.stderr)
    sys.exit(1)


def main():
    """Command line entry point: separate every given track into its sources.

    For each track passed on the command line the audio is loaded with
    `load_track`, normalised, run through `apply_model`, de-normalised, and
    every resulting source is written to `<out>/<model name>/<track name>/`
    either as a wav file (`torchaudio.save`) or as an mp3 (`encode_mp3`).
    Tracks whose path does not exist are reported on stderr and skipped.

    Raises:
        SystemExit: with status 1 when the requested model is neither a file
            in the models folder nor a known pretrained model.
    """
    args = _build_parser().parse_args()

    model = _load_requested_model(args)
    model.to(args.device)

    out = args.out / args.name
    out.mkdir(parents=True, exist_ok=True)
    print(f"Separated tracks will be stored in {out.resolve()}")

    for track in args.tracks:
        if not track.exists():
            print(
                f"File {track} does not exist. If the path contains spaces, "
                "please try again after surrounding the entire path with quotes \"\".",
                file=sys.stderr)
            continue
        print(f"Separating track {track}")
        wav = load_track(track, args.device, model.audio_channels, model.samplerate)

        # Normalise with the statistics of the mean over dim 0
        # (presumably the channel axis -- TODO confirm).
        ref = wav.mean(0)
        offset = ref.mean()
        # A silent (constant) track has a zero standard deviation; dividing by
        # it would turn the whole signal, and hence every output, into NaN.
        scale = ref.std().clamp(min=1e-8)
        wav = (wav - offset) / scale
        sources = apply_model(model, wav, shifts=args.shifts, split=args.split,
                              overlap=args.overlap, progress=True)
        # Undo the normalisation with the very same statistics.
        sources = sources * scale + offset

        # Folder named after the file name without its last extension.
        track_folder = out / track.name.rsplit(".", 1)[0]
        track_folder.mkdir(exist_ok=True)
        for source, source_name in zip(sources, model.sources):
            # Scale down only when the peak would (almost) exceed full scale.
            source = source / max(1.01 * source.abs().max(), 1)
            if args.mp3 or not args.float32:
                # Convert to 16 bit integers, saturating instead of wrapping.
                source = (source * 2**15).clamp_(-2**15, 2**15 - 1).short()
            source = source.cpu()
            stem = str(track_folder / source_name)
            if args.mp3:
                encode_mp3(source, stem + ".mp3",
                           bitrate=args.mp3_bitrate,
                           samplerate=model.samplerate,
                           channels=model.audio_channels,
                           verbose=args.verbose)
            else:
                wavname = str(track_folder / f"{source_name}.wav")
                ta.save(wavname, source, sample_rate=model.samplerate)
# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|