KdaiP's picture
Upload 238 files
d358e26 verified
raw
history blame
3.14 kB
import glob
import os
from tqdm import tqdm
from dataclasses import dataclass
import torch
from torch import Tensor
from torch.multiprocessing import Pool, set_start_method
import torchaudio
from config import MelConfig, TrainConfig
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
@dataclass
class DataConfig:
audio_dir = './audios' # path to audios
output_dir = './vocos_datasets' # path to save processed audios
filelist_path = './filelists/filelist.txt' # path to save filelist
data_config = DataConfig()
train_config = TrainConfig()
mel_config = MelConfig()
audio_dir = data_config.audio_dir
output_dir = data_config.output_dir
filelist_path = data_config.filelist_path
segment_size = train_config.segment_size
output_audio_dir = os.path.join(output_dir, 'audios')
# Ensure output directories exist
os.makedirs(output_audio_dir, exist_ok=True)
os.makedirs(os.path.dirname(filelist_path), exist_ok=True)
def load_and_resample_audio(audio_path, target_sr, segment_size, device='cpu') -> Tensor:
try:
y, sr = torchaudio.load(audio_path)
except Exception as e:
print(str(e))
return None
y.to(device)
# Convert to mono
if y.size(0) > 1:
y = y[0, :].unsqueeze(0) # shape: [2, time] -> [time] -> [1, time]
# resample audio to target sample_rate
if sr != target_sr:
y = torchaudio.functional.resample(y, sr, target_sr)
if y.size(-1) < segment_size:
y = torch.nn.functional.pad(y, (0, segment_size - y.size(-1)), "constant", 0)
return y
def find_audio_files(directory) -> list:
extensions = ['wav', 'mp3', 'flac']
files_found = []
for extension in extensions:
pattern = os.path.join(directory, '**', f'*.{extension}')
files_found.extend(glob.glob(pattern, recursive=True))
return files_found
@ torch.inference_mode()
def process_audio(audio_path):
audio = load_and_resample_audio(audio_path, mel_config.sample_rate, segment_size, device=device) # shape: [1, time]
if audio is not None:
# get output path
audio_name, _ = os.path.splitext(os.path.basename(audio_path))
output_audio_path = os.path.join(output_audio_dir, audio_name + '.wav')
# save resampled audio and mel features
torchaudio.save(output_audio_path, audio.cpu(), mel_config.sample_rate)
return output_audio_path
def main():
set_start_method('spawn') # CUDA must use spawn method
audio_files = find_audio_files(audio_dir)
results = []
with Pool(processes=8) as pool:
for result in tqdm(pool.imap(process_audio, audio_files), total=len(audio_files)):
if result is not None:
results.append(f'{result}\n')
# save filelist
with open(filelist_path, 'w', encoding='utf-8') as f:
f.writelines(results)
print(f"filelist file has been saved to {filelist_path}")
# faster and use much less CPU
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == '__main__':
main()