KingNish commited on
Commit
54616be
·
verified ·
1 Parent(s): 1a38965

Upload ./vocoder.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vocoder.py +184 -0
vocoder.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from pathlib import Path
4
+ import sys
5
+ import torchaudio
6
+ import numpy as np
7
+ from time import time
8
+ import torch
9
+ import typing as tp
10
+ from omegaconf import OmegaConf
11
+ from vocos import VocosDecoder
12
+ from models.soundstream_hubert_new import SoundStream
13
+ from tqdm import tqdm
14
+
15
+ def build_soundstream_model(config):
16
+ model = eval(config.generator.name)(**config.generator.config)
17
+ return model
18
+
19
+ def build_codec_model(config_path, vocal_decoder_path, inst_decoder_path):
20
+ vocal_decoder = VocosDecoder.from_hparams(config_path=config_path)
21
+ vocal_decoder.load_state_dict(torch.load(vocal_decoder_path))
22
+ inst_decoder = VocosDecoder.from_hparams(config_path=config_path)
23
+ inst_decoder.load_state_dict(torch.load(inst_decoder_path))
24
+ return vocal_decoder, inst_decoder
25
+
26
+ def save_audio(wav: torch.Tensor, path: tp.Union[Path, str], sample_rate: int, rescale: bool = False):
27
+ limit = 0.99
28
+ mx = wav.abs().max()
29
+ if rescale:
30
+ wav = wav * min(limit / mx, 1)
31
+ else:
32
+ wav = wav.clamp(-limit, limit)
33
+
34
+ path = str(Path(path).with_suffix('.mp3'))
35
+ torchaudio.save(path, wav, sample_rate=sample_rate)
36
+
37
+ def process_audio(input_file, output_file, rescale, args, decoder, soundstream):
38
+ compressed = np.load(input_file, allow_pickle=True).astype(np.int16)
39
+ print(f"Processing {input_file}")
40
+ print(f"Compressed shape: {compressed.shape}")
41
+
42
+ args.bw = float(4)
43
+ compressed = torch.as_tensor(compressed, dtype=torch.long).unsqueeze(1)
44
+ compressed = soundstream.get_embed(compressed.to(f"cuda:{args.cuda_idx}"))
45
+ compressed = torch.tensor(compressed).to(f"cuda:{args.cuda_idx}")
46
+
47
+ start_time = time()
48
+ with torch.no_grad():
49
+ decoder.eval()
50
+ decoder = decoder.to(f"cuda:{args.cuda_idx}")
51
+ out = decoder(compressed)
52
+ out = out.detach().cpu()
53
+ duration = time() - start_time
54
+ rtf = (out.shape[1] / 44100.0) / duration
55
+ print(f"Decoded in {duration:.2f}s ({rtf:.2f}x RTF)")
56
+
57
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
58
+ save_audio(out, output_file, 44100, rescale=rescale)
59
+ print(f"Saved: {output_file}")
60
+ return out
61
+
62
+ def find_matching_pairs(input_folder):
63
+ if str(input_folder).endswith('.lst'): # Convert to string
64
+ with open(input_folder, 'r') as file:
65
+ files = [line.strip() for line in file if line.strip()]
66
+ else:
67
+ files = list(Path(input_folder).glob('*.npy'))
68
+ print(f"found {len(files)} npy.")
69
+ instrumental_files = {}
70
+ vocal_files = {}
71
+
72
+ for file in files:
73
+ if not isinstance(file, Path):
74
+ file = Path(file)
75
+ name = file.stem
76
+ if 'instrumental' in name.lower():
77
+ base_name = name.lower().replace('instrumental', '')#.strip('_')
78
+ instrumental_files[base_name] = file
79
+ elif 'vocal' in name.lower():
80
+ # base_name = name.lower().replace('vocal', '').strip('_')
81
+ last_index = name.lower().rfind('vocal')
82
+ if last_index != -1:
83
+ # Create a new string with the last 'vocal' removed
84
+ base_name = name.lower()[:last_index] + name.lower()[last_index + len('vocal'):]
85
+ else:
86
+ base_name = name.lower()
87
+ vocal_files[base_name] = file
88
+
89
+ # Find matching pairs
90
+ pairs = []
91
+ for base_name in instrumental_files.keys():
92
+ if base_name in vocal_files:
93
+ pairs.append((
94
+ instrumental_files[base_name],
95
+ vocal_files[base_name],
96
+ base_name
97
+ ))
98
+
99
+ return pairs
100
+
101
+ def main():
102
+ parser = argparse.ArgumentParser(description='High fidelity neural audio codec using Vocos decoder.')
103
+ parser.add_argument('--input_folder', type=Path, required=True, help='Input folder containing NPY files.')
104
+ parser.add_argument('--output_base', type=Path, required=True, help='Base output folder.')
105
+ parser.add_argument('--resume_path', type=str, default='./final_ckpt/ckpt_00360000.pth', help='Path to model checkpoint.')
106
+ parser.add_argument('--config_path', type=str, default='./config.yaml', help='Path to Vocos config file.')
107
+ parser.add_argument('--vocal_decoder_path', type=str, default='/aifs4su/mmcode/codeclm/xcodec_mini_infer_newdecoder/decoders/decoder_131000.pth', help='Path to Vocos decoder weights.')
108
+ parser.add_argument('--inst_decoder_path', type=str, default='/aifs4su/mmcode/codeclm/xcodec_mini_infer_newdecoder/decoders/decoder_151000.pth', help='Path to Vocos decoder weights.')
109
+ parser.add_argument('-r', '--rescale', action='store_true', help='Rescale output to avoid clipping.')
110
+ args = parser.parse_args()
111
+
112
+ # Validate inputs
113
+ if not args.input_folder.exists():
114
+ sys.exit(f"Input folder {args.input_folder} does not exist.")
115
+ if not os.path.isfile(args.config_path):
116
+ sys.exit(f"{args.config_path} file does not exist.")
117
+ # if not os.path.isfile(args.decoder_path):
118
+ # sys.exit(f"{args.decoder_path} file does not exist.")
119
+
120
+ # Create output directories
121
+ mix_dir = args.output_base / 'mix'
122
+ stems_dir = args.output_base / 'stems'
123
+ os.makedirs(mix_dir, exist_ok=True)
124
+ os.makedirs(stems_dir, exist_ok=True)
125
+
126
+ # Initialize models
127
+ config_ss = OmegaConf.load("./final_ckpt/config.yaml")
128
+ soundstream = build_soundstream_model(config_ss)
129
+ parameter_dict = torch.load(args.resume_path)
130
+ soundstream.load_state_dict(parameter_dict['codec_model'])
131
+ soundstream.eval()
132
+
133
+ vocal_decoder, inst_decoder = build_codec_model(args.config_path, args.vocal_decoder_path, args.inst_decoder_path)
134
+
135
+ # Find and process matching pairs
136
+ pairs = find_matching_pairs(args.input_folder)
137
+ print(f"Found {len(pairs)} matching pairs")
138
+ pairs = [p for p in pairs if not os.path.exists(mix_dir / f'{p[2]}.mp3')]
139
+ print(f"{len(pairs)} to reconstruct...")
140
+
141
+ for instrumental_file, vocal_file, base_name in tqdm(pairs):
142
+ print(f"\nProcessing pair: {base_name}")
143
+ # Create stems directory for this song
144
+ song_stems_dir = stems_dir / base_name
145
+ os.makedirs(song_stems_dir, exist_ok=True)
146
+
147
+ try:
148
+ # Process instrumental
149
+ instrumental_output = process_audio(
150
+ instrumental_file,
151
+ song_stems_dir / 'instrumental.mp3',
152
+ args.rescale,
153
+ args,
154
+ inst_decoder,
155
+ soundstream
156
+ )
157
+
158
+ # Process vocal
159
+ vocal_output = process_audio(
160
+ vocal_file,
161
+ song_stems_dir / 'vocal.mp3',
162
+ args.rescale,
163
+ args,
164
+ vocal_decoder,
165
+ soundstream
166
+ )
167
+ except IndexError as e:
168
+ print(e)
169
+ continue
170
+
171
+ # Create and save mix
172
+ try:
173
+ mix_output = instrumental_output + vocal_output
174
+ save_audio(mix_output, mix_dir / f'{base_name}.mp3', 44100, args.rescale)
175
+ print(f"Created mix: {mix_dir / f'{base_name}.mp3'}")
176
+ except RuntimeError as e:
177
+ print(e)
178
+ print(f"mix {base_name} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
179
+
180
+ if __name__ == '__main__':
181
+ main()
182
+
183
+ # Example Usage
184
+ # python reconstruct_separately.py --input_folder test_samples --output_base test