Spaces:
Running
on
Zero
Running
on
Zero
Upload ./vocoder.py with huggingface_hub
Browse files- vocoder.py +184 -0
vocoder.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
import sys
|
5 |
+
import torchaudio
|
6 |
+
import numpy as np
|
7 |
+
from time import time
|
8 |
+
import torch
|
9 |
+
import typing as tp
|
10 |
+
from omegaconf import OmegaConf
|
11 |
+
from vocos import VocosDecoder
|
12 |
+
from models.soundstream_hubert_new import SoundStream
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
def build_soundstream_model(config):
    """Instantiate the generator class named in *config*.

    Expects ``config.generator.name`` (a class name resolvable in this
    module, e.g. ``SoundStream``) and ``config.generator.config`` (its
    keyword arguments).

    NOTE(review): ``eval`` on a config-supplied string executes arbitrary
    code if the config file is untrusted — consider an explicit registry.
    """
    generator_cls = eval(config.generator.name)
    return generator_cls(**config.generator.config)
|
18 |
+
|
19 |
+
def build_codec_model(config_path, vocal_decoder_path, inst_decoder_path):
    """Load the vocal and instrumental Vocos decoders.

    Both decoders are built from the same hparams file at *config_path*
    and then loaded with their respective state dicts.  Returns a
    ``(vocal_decoder, inst_decoder)`` tuple.
    """
    def _load_decoder(weights_path):
        # The two stems share one architecture; only the weights differ.
        decoder = VocosDecoder.from_hparams(config_path=config_path)
        decoder.load_state_dict(torch.load(weights_path))
        return decoder

    # Evaluated left to right: vocal first, then instrumental, as before.
    return _load_decoder(vocal_decoder_path), _load_decoder(inst_decoder_path)
|
25 |
+
|
26 |
+
def save_audio(wav: torch.Tensor, path: tp.Union[Path, str], sample_rate: int, rescale: bool = False):
    """Write *wav* to disk as an mp3, keeping peaks within +/-0.99.

    If *rescale* is true the whole waveform is scaled down so its peak
    does not exceed the limit (quiet audio is left untouched); otherwise
    samples are hard-clipped to the limit.  The output suffix is forced
    to ``.mp3`` regardless of the suffix given in *path*.
    """
    limit = 0.99
    peak = wav.abs().max()
    if rescale:
        # Scale factor capped at 1 so audio below the limit is unchanged.
        wav = wav * min(limit / peak, 1)
    else:
        wav = wav.clamp(-limit, limit)

    mp3_path = str(Path(path).with_suffix('.mp3'))
    torchaudio.save(mp3_path, wav, sample_rate=sample_rate)
|
36 |
+
|
37 |
+
def process_audio(input_file, output_file, rescale, args, decoder, soundstream):
    """Decode one .npy token file to audio and save it as mp3.

    Parameters
    ----------
    input_file : path to a ``.npy`` file of codec token ids.
    output_file : destination path (suffix is forced to ``.mp3`` by save_audio).
    rescale : forwarded to :func:`save_audio`.
    args : parsed CLI namespace; must provide ``cuda_idx``.
    decoder : Vocos decoder for this stem type.
    soundstream : SoundStream codec used to embed the tokens.

    Returns the decoded waveform tensor on CPU.
    """
    # NOTE(review): allow_pickle=True can execute arbitrary code when
    # loading a malicious .npy — only feed trusted files to this script.
    compressed = np.load(input_file, allow_pickle=True).astype(np.int16)
    print(f"Processing {input_file}")
    print(f"Compressed shape: {compressed.shape}")

    device = f"cuda:{args.cuda_idx}"
    # Fixed 4 kbps bandwidth; stored on args because downstream code may read it.
    args.bw = float(4)
    codes = torch.as_tensor(compressed, dtype=torch.long).unsqueeze(1)
    embedded = soundstream.get_embed(codes.to(device))
    # BUG FIX: the original called torch.tensor() on an existing tensor,
    # which emits a copy-construct warning; clone().detach() is the
    # documented equivalent.  The fallback covers a non-tensor return.
    if torch.is_tensor(embedded):
        embedded = embedded.clone().detach().to(device)
    else:
        embedded = torch.tensor(embedded).to(device)

    start_time = time()
    with torch.no_grad():
        decoder.eval()
        decoder = decoder.to(device)
        out = decoder(embedded)
        out = out.detach().cpu()
    duration = time() - start_time
    # Real-time factor: seconds of audio produced per second of decoding
    # (output assumed to be 44.1 kHz — matches the save below).
    rtf = (out.shape[1] / 44100.0) / duration
    print(f"Decoded in {duration:.2f}s ({rtf:.2f}x RTF)")

    # Guard: os.makedirs('') raises when output_file has no parent component.
    parent_dir = os.path.dirname(str(output_file))
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    save_audio(out, output_file, 44100, rescale=rescale)
    print(f"Saved: {output_file}")
    return out
|
61 |
+
|
62 |
+
def find_matching_pairs(input_folder):
    """Pair instrumental/vocal ``.npy`` files that share a base name.

    *input_folder* is either a directory scanned for ``*.npy`` files or a
    ``.lst`` text file listing one path per line.  Returns a list of
    ``(instrumental_path, vocal_path, base_name)`` tuples.
    """
    if str(input_folder).endswith('.lst'):  # Convert to string
        with open(input_folder, 'r') as listing:
            files = [entry.strip() for entry in listing if entry.strip()]
    else:
        files = list(Path(input_folder).glob('*.npy'))
    print(f"found {len(files)} npy.")

    instrumental_files = {}
    vocal_files = {}
    for entry in files:
        path = entry if isinstance(entry, Path) else Path(entry)
        stem_lower = path.stem.lower()
        if 'instrumental' in stem_lower:
            # Key strips every occurrence of 'instrumental' from the stem.
            instrumental_files[stem_lower.replace('instrumental', '')] = path
        elif 'vocal' in stem_lower:
            # Key strips only the LAST occurrence of 'vocal' — deliberately
            # asymmetric with the instrumental branch above.
            cut = stem_lower.rfind('vocal')
            if cut != -1:
                key = stem_lower[:cut] + stem_lower[cut + len('vocal'):]
            else:
                key = stem_lower
            vocal_files[key] = path

    # Keep only base names present in both maps.
    return [
        (instrumental_files[key], vocal_files[key], key)
        for key in instrumental_files
        if key in vocal_files
    ]
|
100 |
+
|
101 |
+
def main():
    """CLI entry point: reconstruct stems and mixes from paired NPY token files."""
    parser = argparse.ArgumentParser(description='High fidelity neural audio codec using Vocos decoder.')
    parser.add_argument('--input_folder', type=Path, required=True, help='Input folder containing NPY files.')
    parser.add_argument('--output_base', type=Path, required=True, help='Base output folder.')
    parser.add_argument('--resume_path', type=str, default='./final_ckpt/ckpt_00360000.pth', help='Path to model checkpoint.')
    parser.add_argument('--config_path', type=str, default='./config.yaml', help='Path to Vocos config file.')
    parser.add_argument('--vocal_decoder_path', type=str, default='/aifs4su/mmcode/codeclm/xcodec_mini_infer_newdecoder/decoders/decoder_131000.pth', help='Path to Vocos decoder weights.')
    parser.add_argument('--inst_decoder_path', type=str, default='/aifs4su/mmcode/codeclm/xcodec_mini_infer_newdecoder/decoders/decoder_151000.pth', help='Path to Vocos decoder weights.')
    # BUG FIX: process_audio reads args.cuda_idx but no such option was
    # defined, so every run previously died with AttributeError.
    parser.add_argument('--cuda_idx', type=int, default=0, help='CUDA device index used for decoding.')
    parser.add_argument('-r', '--rescale', action='store_true', help='Rescale output to avoid clipping.')
    args = parser.parse_args()

    # Validate inputs
    if not args.input_folder.exists():
        sys.exit(f"Input folder {args.input_folder} does not exist.")
    if not os.path.isfile(args.config_path):
        sys.exit(f"{args.config_path} file does not exist.")

    # Create output directories
    mix_dir = args.output_base / 'mix'
    stems_dir = args.output_base / 'stems'
    os.makedirs(mix_dir, exist_ok=True)
    os.makedirs(stems_dir, exist_ok=True)

    # Initialize models
    config_ss = OmegaConf.load("./final_ckpt/config.yaml")
    soundstream = build_soundstream_model(config_ss)
    parameter_dict = torch.load(args.resume_path)
    soundstream.load_state_dict(parameter_dict['codec_model'])
    soundstream.eval()
    # NOTE(review): soundstream is never moved to CUDA here although
    # process_audio feeds it CUDA tensors — confirm get_embed handles the
    # device transfer internally.

    vocal_decoder, inst_decoder = build_codec_model(args.config_path, args.vocal_decoder_path, args.inst_decoder_path)

    # Find matching pairs and skip songs whose mix already exists.
    pairs = find_matching_pairs(args.input_folder)
    print(f"Found {len(pairs)} matching pairs")
    pairs = [p for p in pairs if not os.path.exists(mix_dir / f'{p[2]}.mp3')]
    print(f"{len(pairs)} to reconstruct...")

    for instrumental_file, vocal_file, base_name in tqdm(pairs):
        print(f"\nProcessing pair: {base_name}")
        # Create stems directory for this song
        song_stems_dir = stems_dir / base_name
        os.makedirs(song_stems_dir, exist_ok=True)

        try:
            # Process instrumental
            instrumental_output = process_audio(
                instrumental_file,
                song_stems_dir / 'instrumental.mp3',
                args.rescale,
                args,
                inst_decoder,
                soundstream
            )

            # Process vocal
            vocal_output = process_audio(
                vocal_file,
                song_stems_dir / 'vocal.mp3',
                args.rescale,
                args,
                vocal_decoder,
                soundstream
            )
        except IndexError as e:
            print(e)
            continue

        # Create and save mix; mismatched stem lengths make the addition raise.
        try:
            mix_output = instrumental_output + vocal_output
            save_audio(mix_output, mix_dir / f'{base_name}.mp3', 44100, args.rescale)
            print(f"Created mix: {mix_dir / f'{base_name}.mp3'}")
        except RuntimeError as e:
            print(e)
            print(f"mix {base_name} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
|
179 |
+
|
180 |
+
# Script entry point: only run the pipeline when executed directly.
if __name__ == '__main__':
    main()
|
182 |
+
|
183 |
+
# Example Usage
|
184 |
+
# python reconstruct_separately.py --input_folder test_samples --output_base test
|