KingNish's picture
Upload ./vocoder.py with huggingface_hub
54616be verified
raw
history blame
7.39 kB
import argparse
import os
from pathlib import Path
import sys
import torchaudio
import numpy as np
from time import time
import torch
import typing as tp
from omegaconf import OmegaConf
from vocos import VocosDecoder
from models.soundstream_hubert_new import SoundStream
from tqdm import tqdm
def build_soundstream_model(config):
model = eval(config.generator.name)(**config.generator.config)
return model
def build_codec_model(config_path, vocal_decoder_path, inst_decoder_path):
vocal_decoder = VocosDecoder.from_hparams(config_path=config_path)
vocal_decoder.load_state_dict(torch.load(vocal_decoder_path))
inst_decoder = VocosDecoder.from_hparams(config_path=config_path)
inst_decoder.load_state_dict(torch.load(inst_decoder_path))
return vocal_decoder, inst_decoder
def save_audio(wav: torch.Tensor, path: tp.Union[Path, str], sample_rate: int, rescale: bool = False):
limit = 0.99
mx = wav.abs().max()
if rescale:
wav = wav * min(limit / mx, 1)
else:
wav = wav.clamp(-limit, limit)
path = str(Path(path).with_suffix('.mp3'))
torchaudio.save(path, wav, sample_rate=sample_rate)
def process_audio(input_file, output_file, rescale, args, decoder, soundstream):
compressed = np.load(input_file, allow_pickle=True).astype(np.int16)
print(f"Processing {input_file}")
print(f"Compressed shape: {compressed.shape}")
args.bw = float(4)
compressed = torch.as_tensor(compressed, dtype=torch.long).unsqueeze(1)
compressed = soundstream.get_embed(compressed.to(f"cuda:{args.cuda_idx}"))
compressed = torch.tensor(compressed).to(f"cuda:{args.cuda_idx}")
start_time = time()
with torch.no_grad():
decoder.eval()
decoder = decoder.to(f"cuda:{args.cuda_idx}")
out = decoder(compressed)
out = out.detach().cpu()
duration = time() - start_time
rtf = (out.shape[1] / 44100.0) / duration
print(f"Decoded in {duration:.2f}s ({rtf:.2f}x RTF)")
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_audio(out, output_file, 44100, rescale=rescale)
print(f"Saved: {output_file}")
return out
def find_matching_pairs(input_folder):
if str(input_folder).endswith('.lst'): # Convert to string
with open(input_folder, 'r') as file:
files = [line.strip() for line in file if line.strip()]
else:
files = list(Path(input_folder).glob('*.npy'))
print(f"found {len(files)} npy.")
instrumental_files = {}
vocal_files = {}
for file in files:
if not isinstance(file, Path):
file = Path(file)
name = file.stem
if 'instrumental' in name.lower():
base_name = name.lower().replace('instrumental', '')#.strip('_')
instrumental_files[base_name] = file
elif 'vocal' in name.lower():
# base_name = name.lower().replace('vocal', '').strip('_')
last_index = name.lower().rfind('vocal')
if last_index != -1:
# Create a new string with the last 'vocal' removed
base_name = name.lower()[:last_index] + name.lower()[last_index + len('vocal'):]
else:
base_name = name.lower()
vocal_files[base_name] = file
# Find matching pairs
pairs = []
for base_name in instrumental_files.keys():
if base_name in vocal_files:
pairs.append((
instrumental_files[base_name],
vocal_files[base_name],
base_name
))
return pairs
def main():
parser = argparse.ArgumentParser(description='High fidelity neural audio codec using Vocos decoder.')
parser.add_argument('--input_folder', type=Path, required=True, help='Input folder containing NPY files.')
parser.add_argument('--output_base', type=Path, required=True, help='Base output folder.')
parser.add_argument('--resume_path', type=str, default='./final_ckpt/ckpt_00360000.pth', help='Path to model checkpoint.')
parser.add_argument('--config_path', type=str, default='./config.yaml', help='Path to Vocos config file.')
parser.add_argument('--vocal_decoder_path', type=str, default='/aifs4su/mmcode/codeclm/xcodec_mini_infer_newdecoder/decoders/decoder_131000.pth', help='Path to Vocos decoder weights.')
parser.add_argument('--inst_decoder_path', type=str, default='/aifs4su/mmcode/codeclm/xcodec_mini_infer_newdecoder/decoders/decoder_151000.pth', help='Path to Vocos decoder weights.')
parser.add_argument('-r', '--rescale', action='store_true', help='Rescale output to avoid clipping.')
args = parser.parse_args()
# Validate inputs
if not args.input_folder.exists():
sys.exit(f"Input folder {args.input_folder} does not exist.")
if not os.path.isfile(args.config_path):
sys.exit(f"{args.config_path} file does not exist.")
# if not os.path.isfile(args.decoder_path):
# sys.exit(f"{args.decoder_path} file does not exist.")
# Create output directories
mix_dir = args.output_base / 'mix'
stems_dir = args.output_base / 'stems'
os.makedirs(mix_dir, exist_ok=True)
os.makedirs(stems_dir, exist_ok=True)
# Initialize models
config_ss = OmegaConf.load("./final_ckpt/config.yaml")
soundstream = build_soundstream_model(config_ss)
parameter_dict = torch.load(args.resume_path)
soundstream.load_state_dict(parameter_dict['codec_model'])
soundstream.eval()
vocal_decoder, inst_decoder = build_codec_model(args.config_path, args.vocal_decoder_path, args.inst_decoder_path)
# Find and process matching pairs
pairs = find_matching_pairs(args.input_folder)
print(f"Found {len(pairs)} matching pairs")
pairs = [p for p in pairs if not os.path.exists(mix_dir / f'{p[2]}.mp3')]
print(f"{len(pairs)} to reconstruct...")
for instrumental_file, vocal_file, base_name in tqdm(pairs):
print(f"\nProcessing pair: {base_name}")
# Create stems directory for this song
song_stems_dir = stems_dir / base_name
os.makedirs(song_stems_dir, exist_ok=True)
try:
# Process instrumental
instrumental_output = process_audio(
instrumental_file,
song_stems_dir / 'instrumental.mp3',
args.rescale,
args,
inst_decoder,
soundstream
)
# Process vocal
vocal_output = process_audio(
vocal_file,
song_stems_dir / 'vocal.mp3',
args.rescale,
args,
vocal_decoder,
soundstream
)
except IndexError as e:
print(e)
continue
# Create and save mix
try:
mix_output = instrumental_output + vocal_output
save_audio(mix_output, mix_dir / f'{base_name}.mp3', 44100, args.rescale)
print(f"Created mix: {mix_dir / f'{base_name}.mp3'}")
except RuntimeError as e:
print(e)
print(f"mix {base_name} failed! inst: {instrumental_output.shape}, vocal: {vocal_output.shape}")
if __name__ == '__main__':
main()
# Example Usage
# python reconstruct_separately.py --input_folder test_samples --output_base test