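"""Reconstruct audio stems and mixes from codec-token .npy files.

For each matching instrumental/vocal pair, the token arrays are embedded with
the SoundStream codec, decoded to waveforms by stem-specific Vocos decoders,
saved as individual stems, and summed into a full mix.
"""
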
import argparse
import os
from pathlib import Path
import sys
import torchaudio
import numpy as np
from time import time
import torch
import typing as tp
from omegaconf import OmegaConf
from vocos import VocosDecoder
from models.soundstream_hubert_new import SoundStream
from tqdm import tqdm


def build_soundstream_model(config):
    # Instantiate the codec class named in the config (e.g. SoundStream) with its config kwargs.
    model = eval(config.generator.name)(**config.generator.config)
    return model


def build_codec_model(config_path, vocal_decoder_path, inst_decoder_path):
    # Both decoders share the same Vocos config but load separately trained weights.
    vocal_decoder = VocosDecoder.from_hparams(config_path=config_path)
    vocal_decoder.load_state_dict(torch.load(vocal_decoder_path))
    inst_decoder = VocosDecoder.from_hparams(config_path=config_path)
    inst_decoder.load_state_dict(torch.load(inst_decoder_path))
    return vocal_decoder, inst_decoder


def save_audio(wav: torch.Tensor, path: tp.Union[Path, str], sample_rate: int, rescale: bool = False):
    # Avoid clipping: either rescale the whole waveform to the limit or hard-clamp it.
    limit = 0.99
    mx = wav.abs().max()
    if rescale:
        wav = wav * min(limit / mx, 1)
    else:
        wav = wav.clamp(-limit, limit)
    path = str(Path(path).with_suffix('.mp3'))
    torchaudio.save(path, wav, sample_rate=sample_rate)


def process_audio(input_file, output_file, rescale, args, decoder, soundstream):
    # Load the quantized codec tokens saved as a NumPy array.
    compressed = np.load(input_file, allow_pickle=True).astype(np.int16)
    print(f"Processing {input_file}")
    print(f"Compressed shape: {compressed.shape}")

    args.bw = 4.0
    compressed = torch.as_tensor(compressed, dtype=torch.long).unsqueeze(1)
    # Turn the token indices into continuous embeddings with the codec.
    compressed = soundstream.get_embed(compressed.to(f"cuda:{args.cuda_idx}"))
    compressed = torch.as_tensor(compressed).to(f"cuda:{args.cuda_idx}")

    start_time = time()
    with torch.no_grad():
        decoder.eval()
        decoder = decoder.to(f"cuda:{args.cuda_idx}")
        out = decoder(compressed)
        out = out.detach().cpu()
    duration = time() - start_time
    speed = (out.shape[1] / 44100.0) / duration
    print(f"Decoded in {duration:.2f}s ({speed:.2f}x real-time)")

    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    save_audio(out, output_file, 44100, rescale=rescale)
    print(f"Saved: {output_file}")
    return out


def find_matching_pairs(input_folder):
    # Accept either a .lst file listing .npy paths or a folder of .npy files.
    if str(input_folder).endswith('.lst'):
        with open(input_folder, 'r') as file:
            files = [line.strip() for line in file if line.strip()]
    else:
        files = list(Path(input_folder).glob('*.npy'))
    print(f"Found {len(files)} .npy files.")

    instrumental_files = {}
    vocal_files = {}
    for file in files:
        if not isinstance(file, Path):
            file = Path(file)
        name = file.stem
        if 'instrumental' in name.lower():
            base_name = name.lower().replace('instrumental', '')
            instrumental_files[base_name] = file
        elif 'vocal' in name.lower():
            # Remove only the last occurrence of 'vocal' to form the base name.
            last_index = name.lower().rfind('vocal')
            if last_index != -1:
                base_name = name.lower()[:last_index] + name.lower()[last_index + len('vocal'):]
            else:
                base_name = name.lower()
            vocal_files[base_name] = file

    # Find matching pairs: an instrumental and a vocal file sharing the same base name.
    pairs = []
    for base_name in instrumental_files.keys():
        if base_name in vocal_files:
            pairs.append((
                instrumental_files[base_name],
                vocal_files[base_name],
                base_name,
            ))
    return pairs


def main():
    parser = argparse.ArgumentParser(description='High fidelity neural audio codec using Vocos decoder.')
    parser.add_argument('--input_folder', type=Path, required=True, help='Input folder (or .lst file) containing NPY files.')
    parser.add_argument('--output_base', type=Path, required=True, help='Base output folder.')
    parser.add_argument('--resume_path', type=str, default='./final_ckpt/ckpt_00360000.pth', help='Path to codec model checkpoint.')
    parser.add_argument('--config_path', type=str, default='./config.yaml', help='Path to Vocos config file.')
    parser.add_argument('--vocal_decoder_path', type=str, default='/aifs4su/mmcode/codeclm/xcodec_mini_infer_newdecoder/decoders/decoder_131000.pth', help='Path to vocal Vocos decoder weights.')
    parser.add_argument('--inst_decoder_path', type=str, default='/aifs4su/mmcode/codeclm/xcodec_mini_infer_newdecoder/decoders/decoder_151000.pth', help='Path to instrumental Vocos decoder weights.')
    # Needed by process_audio, which reads args.cuda_idx.
    parser.add_argument('--cuda_idx', type=int, default=0, help='CUDA device index to decode on.')
    parser.add_argument('-r', '--rescale', action='store_true', help='Rescale output to avoid clipping.')
    args = parser.parse_args()

    # Validate inputs
    if not args.input_folder.exists():
        sys.exit(f"Input folder {args.input_folder} does not exist.")
    if not os.path.isfile(args.config_path):
        sys.exit(f"{args.config_path} file does not exist.")

    # Create output directories
    mix_dir = args.output_base / 'mix'
    stems_dir = args.output_base / 'stems'
    os.makedirs(mix_dir, exist_ok=True)
    os.makedirs(stems_dir, exist_ok=True)

    # Initialize models: the SoundStream codec plus the two Vocos decoders.
    config_ss = OmegaConf.load("./final_ckpt/config.yaml")
    soundstream = build_soundstream_model(config_ss)
    parameter_dict = torch.load(args.resume_path)
    soundstream.load_state_dict(parameter_dict['codec_model'])
    soundstream.eval()
    vocal_decoder, inst_decoder = build_codec_model(args.config_path, args.vocal_decoder_path, args.inst_decoder_path)

    # Find matching pairs and skip songs whose mix has already been written.
    pairs = find_matching_pairs(args.input_folder)
    print(f"Found {len(pairs)} matching pairs")
    pairs = [p for p in pairs if not os.path.exists(mix_dir / f'{p[2]}.mp3')]
    print(f"{len(pairs)} to reconstruct...")

    for instrumental_file, vocal_file, base_name in tqdm(pairs):
        print(f"\nProcessing pair: {base_name}")

        # Create a stems directory for this song
        song_stems_dir = stems_dir / base_name
        os.makedirs(song_stems_dir, exist_ok=True)

        try:
            # Decode the instrumental stem
            instrumental_output = process_audio(
                instrumental_file,
                song_stems_dir / 'instrumental.mp3',
                args.rescale,
                args,
                inst_decoder,
                soundstream
            )
            # Decode the vocal stem
            vocal_output = process_audio(
                vocal_file,
                song_stems_dir / 'vocal.mp3',
                args.rescale,
                args,
                vocal_decoder,
                soundstream
            )
        except IndexError as e:
            print(e)
            continue

        # Create and save the mix by summing the two stems
        try:
            mix_output = instrumental_output + vocal_output
            save_audio(mix_output, mix_dir / f'{base_name}.mp3', 44100, args.rescale)
            print(f"Created mix: {mix_dir / f'{base_name}.mp3'}")
        except RuntimeError as e:
            print(e)
            print(f"mix {base_name} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")


if __name__ == '__main__':
    main()


# Example usage:
# python reconstruct_separately.py --input_folder test_samples --output_base test
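#
# A .lst file listing .npy paths is also accepted as --input_folder (see
# find_matching_pairs); the file and folder names below are illustrative only:
# python reconstruct_separately.py --input_folder songs.lst --output_base out --rescale --cuda_idx 0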