from ldm.data.preprocess.NAT_mel import MelNet
import os
from tqdm import tqdm
from glob import glob
import math
import pandas as pd
import logging
import audioread
from tqdm.contrib.concurrent import process_map
import torch
import torchaudio
import numpy as np
from torch.distributed import init_process_group
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import torch.multiprocessing as mp
from argparse import Namespace
import json


class tsv_dataset(Dataset):
    """Dataset over audio files listed in a .tsv file (or a directory of .tsv files).

    Each row must provide an `audio_path` column.
    """

    def __init__(self, tsv_path, sr, mode='none', hop_size=None, target_mel_length=None) -> None:
        super().__init__()
        if os.path.isdir(tsv_path):
            files = glob(os.path.join(tsv_path, '*.tsv'))
            df = pd.concat([pd.read_csv(file, sep='\t') for file in files])
        else:
            df = pd.read_csv(tsv_path, sep='\t')
        self.audio_paths = []
        self.sr = sr
        self.mode = mode
        self.target_mel_length = target_mel_length
        self.hop_size = hop_size
        for t in tqdm(df.itertuples()):
            self.audio_paths.append(getattr(t, 'audio_path'))

    def __len__(self):
        return len(self.audio_paths)

    def pad_wav(self, wav):
        # wav is expected in shape (1, wav_len)
        wav_length = wav.shape[-1]
        assert wav_length > 100, "wav is too short, %s" % wav_length
        if self.target_mel_length is None:
            return wav
        # +1 frame because the final mel crops the last frame: mel = mel[:, :-1]
        segment_length = (self.target_mel_length + 1) * self.hop_size
        if wav_length == segment_length:
            return wav
        if wav_length > segment_length:
            return wav[:, :segment_length]
        # wav_length < segment_length: zero-pad on the right
        temp_wav = torch.zeros((1, segment_length), dtype=torch.float32)
        temp_wav[:, :wav_length] = wav
        return temp_wav

    def __getitem__(self, index):
        audio_path = self.audio_paths[index]
        wav, orisr = torchaudio.load(audio_path)
        if wav.shape[0] != 1:  # stereo to mono: (2, wav_len) -> (1, wav_len)
            wav = wav.mean(0, keepdim=True)
        wav = torchaudio.functional.resample(wav, orig_freq=orisr, new_freq=self.sr)
        if self.mode == 'pad':
            assert self.target_mel_length is not None
            wav = self.pad_wav(wav)
        return audio_path, wav


def process_audio_by_tsv(rank, args):
    if args.num_gpus > 1:
        init_process_group(backend=args.dist_config['dist_backend'],
                           init_method=args.dist_config['dist_url'],
                           world_size=args.dist_config['world_size'] * args.num_gpus,
                           rank=rank)

    sr = args.audio_sample_rate
    dataset = tsv_dataset(args.tsv_path, sr=sr, mode=args.mode,
                          hop_size=args.hop_size, target_mel_length=args.batch_max_length)
    sampler = DistributedSampler(dataset, shuffle=False) if args.num_gpus > 1 else None
    # batch_size must be 1 because waveform lengths differ across files
    loader = DataLoader(dataset, sampler=sampler, batch_size=1, num_workers=16, drop_last=False)

    device = torch.device('cuda:{:d}'.format(rank))
    mel_net = MelNet(args.__dict__)
    mel_net.to(device)
    # if args.num_gpus > 1:
    #     # RuntimeError: DistributedDataParallel is not needed when a module
    #     # doesn't have any parameter that requires a gradient.
    #     mel_net = DistributedDataParallel(mel_net, device_ids=[rank]).to(device)

    loader = tqdm(loader) if rank == 0 else loader
    for batch in loader:
        audio_paths, wavs = batch
        wavs = wavs.to(device)
        if args.save_resample:
            for audio_path, wav in zip(audio_paths, wavs):
                psplits = audio_path.split('/')
                root, wav_name = psplits[0], psplits[-1]
                # save the resampled waveform, mirroring the source tree under <root>_<sr>/
                resample_root, resample_name = root + f'_{sr}', wav_name[:-4] + '_audio.npy'
                resample_dir_name = os.path.join(resample_root, *psplits[1:-1])
                resample_path = os.path.join(resample_dir_name, resample_name)
                os.makedirs(resample_dir_name, exist_ok=True)
                np.save(resample_path, wav.cpu().numpy().squeeze(0))
        if args.save_mel:
            mode = args.mode
            batch_max_length = args.batch_max_length
            for audio_path, wav in zip(audio_paths, wavs):
                psplits = audio_path.split('/')
                root, wav_name = psplits[0], psplits[-1]
                mel_root, mel_name = root + f'_mel{mode}{sr}nfft{args.fft_size}', wav_name[:-4] + '_mel.npy'
                mel_dir_name = os.path.join(mel_root, *psplits[1:-1])
                mel_path = os.path.join(mel_dir_name, mel_name)
                if not os.path.exists(mel_path):
                    mel_spec = mel_net(wav).cpu().numpy().squeeze(0)  # (mel_bins, mel_len)
                    if mel_spec.shape[1] <= batch_max_length:
                        if mode == 'tile':  # 'pad' mode is handled in the dataset via pad_wav
                            n_repeat = math.ceil((batch_max_length + 1) / mel_spec.shape[1])
                            mel_spec = np.tile(mel_spec, reps=(1, n_repeat))
                        elif mode == 'none' or mode == 'pad':
                            pass
                        else:
                            raise ValueError(f'mode:{mode} is not supported')
                    mel_spec = mel_spec[:, :batch_max_length]
                    os.makedirs(mel_dir_name, exist_ok=True)
                    np.save(mel_path, mel_spec)


def split_list(i_list, num):
    # split i_list into num roughly equal chunks
    each_num = math.ceil(len(i_list) / num)  # fixed: was math.ceil(i_list / num)
    result = []
    for i in range(num):
        s = each_num * i
        e = each_num * (i + 1)
        result.append(i_list[s:e])
    return result


def drop_bad_wav(item):
    # returns the row index for unreadable or too-short files, False otherwise
    index, path = item
    try:
        with audioread.audio_open(path) as f:
            totalsec = f.duration
            if totalsec < 0.1:
                return index
    except Exception:
        print(f"corrupted wav:{path}")
        return index
    return False


def drop_bad_wavs(tsv_path):  # e.g. 'audioset.csv'
    df = pd.read_csv(tsv_path, sep='\t')
    item_list = []
    for item in tqdm(df.itertuples()):
        item_list.append((item[0], getattr(item, 'audio_path')))
    r = process_map(drop_bad_wav, item_list, max_workers=16, chunksize=16)
    # identity check: `x != False` would wrongly treat row index 0 as good
    bad_indices = [x for x in r if x is not False]
    print(bad_indices)
    with open('bad_wavs.json', 'w') as f:
        x = [item_list[i] for i in bad_indices]
        json.dump(x, f)
    df = df.drop(bad_indices, axis=0)
    df.to_csv(tsv_path, sep='\t', index=False)


if __name__ == '__main__':
    logging.basicConfig(filename='example.log',
                        level=logging.INFO,
                        format='%(asctime)s %(message)s',
                        datefmt='%m/%d/%Y %I:%M:%S %p')
    tsv_path = './musiccap.tsv'
    if os.path.isdir(tsv_path):
        files = glob(os.path.join(tsv_path, '*.tsv'))
        for file in files:
            drop_bad_wavs(file)
    else:
        drop_bad_wavs(tsv_path)

    num_gpus = 1
    args = {
        'audio_sample_rate': 16000,
        'audio_num_mel_bins': 80,
        'fft_size': 1024,  # 4000 Hz: 512, 16000 Hz: 1024
        'win_size': 1024,
        'hop_size': 256,
        'fmin': 0,
        'fmax': 8000,
        'batch_max_length': 1560,  # 4000: 312 (nfft=512, hop_len=128, mel_len=313), 16000: 624, 22050: 848
        'tsv_path': tsv_path,
        'num_gpus': num_gpus,
        'mode': 'none',
        'save_resample': False,
        'save_mel': True,
    }
    args = Namespace(**args)
    args.dist_config = {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54189",
        "world_size": 1,
    }
    if args.num_gpus > 1:
        mp.spawn(process_audio_by_tsv, nprocs=args.num_gpus, args=(args,))
    else:
        process_audio_by_tsv(0, args=args)
    print("done")
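
# ----------------------------------------------------------------------
# Illustrative usage sketch (not part of the original script; the file
# names below are hypothetical). The tsv consumed above only needs an
# `audio_path` column; any other columns are ignored by tsv_dataset:
#
#   import pandas as pd
#   rows = [{'name': 'clip_0', 'audio_path': 'data/audio/clip_0.wav'},
#           {'name': 'clip_1', 'audio_path': 'data/audio/clip_1.wav'}]
#   pd.DataFrame(rows).to_csv('musiccap.tsv', sep='\t', index=False)
#
# With the defaults above (mode='none', sr=16000, fft_size=1024), each
# data/audio/clip_X.wav would then be written out as
# data_melnone16000nfft1024/audio/clip_X_mel.npy.
# ----------------------------------------------------------------------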