Spaces:
Sleeping
Sleeping
import sys | |
import numpy as np | |
import torch | |
import logging | |
import pandas as pd | |
import glob | |
# Module logger, created as a child of the application's 'main' logger.
logger = logging.getLogger('main.' + __name__)
# Prepend the current working directory to the import search path.
sys.path.insert(0, '.')  # nopep8
class JoinManifestSpecs(torch.utils.data.Dataset):
    """Spectrogram dataset mixing a "main" manifest corpus with an "other" corpus.

    Both corpora are tab-separated manifests (``*.tsv``) collected from one or
    more directories. For the train and valid splits each item comes from the
    main corpus with probability ``main_prob``. Otherwise it comes from a
    uniformly random row of the other corpus, whose plain caption is wrapped
    into the structured caption ``'<caption, all>'``. The test split always
    reads the main corpus, so items are deterministic and ``f_name`` always
    refers to the requested main-corpus row.

    Manifest columns used: ``mel_path`` (both corpora), ``ori_cap`` and
    ``caption`` (main), ``caption`` (other), ``name`` (main, test split only).

    Args:
        split: ``'train'`` (main rows ``100:``), ``'valid'``/``'val'`` (the
            first 100 main rows) or ``'test'`` (all main rows, with names made
            unique by :meth:`add_name_num`).
        main_spec_dir_path: Comma-separated string (or list) of directories
            holding the main manifests.
        other_spec_dir_path: Same, for the other-corpus manifests.
        mel_num: Number of rows kept from axis 0 of each loaded array;
            ``None`` keeps all of them.
        spec_crop_len: Length that axis 1 is padded/cropped to; ``None``
            leaves axis 1 as loaded.
        pad_value: Fill value for padding and for unreadable spectrogram files.
        main_prob: Probability of drawing from the main corpus (train/valid).
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``split`` is unknown or a directory list yields no
            ``*.tsv`` manifest.
    """

    def __init__(self, split, main_spec_dir_path, other_spec_dir_path, mel_num=None,
                 spec_crop_len=None, pad_value=-5, main_prob=0.5, **kwargs):
        super().__init__()
        self.main_prob = main_prob
        self.split = split
        self.batch_max_length = spec_crop_len
        self.batch_min_length = 50  # not read anywhere in this class
        self.mel_num = mel_num
        self.pad_value = pad_value

        self.df_main = self._read_manifests(main_spec_dir_path)
        self.df_other = self._read_manifests(other_spec_dir_path)

        if split == 'train':
            dataset = self.df_main.iloc[100:]
        elif split in ('valid', 'val'):
            dataset = self.df_main.iloc[:100]
        elif split == 'test':
            self.df_main = self.add_name_num(self.df_main)
            dataset = self.df_main
        else:
            raise ValueError(f'Unknown split {split}')
        # Re-number rows 0..n-1 so positional and label access agree. This is a
        # non-inplace call because ``dataset`` may be a slice of ``df_main``.
        self.dataset = dataset.reset_index()
        print('dataset len:', len(self.dataset))

    @staticmethod
    def _read_manifests(dir_paths):
        """Concatenate every ``*.tsv`` manifest found in ``dir_paths`` into one DataFrame."""
        if isinstance(dir_paths, str):
            dir_paths = dir_paths.split(',')
        manifest_files = []
        for dir_path in dir_paths:
            # glob.glob returns files in filesystem-dependent order. Sorting keeps
            # the row order, and therefore the train/valid split, reproducible.
            manifest_files += sorted(glob.glob(f'{dir_path}/*.tsv'))
        if not manifest_files:
            raise ValueError(f'No *.tsv manifest found in {list(dir_paths)}')
        frames = [pd.read_csv(manifest, sep='\t') for manifest in manifest_files]
        return pd.concat(frames, ignore_index=True)

    def add_name_num(self, df):
        """Suffix each ``name`` with its occurrence number so audio-caption pairs are unique.

        A file may appear several times with different captions. The k-th
        occurrence (0-based, in row order) of a name becomes ``f'{name}_{k}'``.
        ``df`` is modified in place and also returned.
        """
        occurrence = df.groupby('name', sort=False).cumcount()
        df['name'] = df['name'] + '_' + occurrence.astype(str)
        return df

    def __getitem__(self, idx):
        """Return ``{'image': spec, 'caption': {...}}``, plus ``'f_name'`` for the test split."""
        # The test split never samples the other corpus. Its items must be
        # reproducible, and 'f_name' must name the requested main-corpus row.
        if self.split == 'test' or np.random.uniform(0, 1) < self.main_prob:
            data = self.dataset.iloc[idx]
            ori_caption = data['ori_cap']
            struct_caption = data['caption']
        else:
            randidx = np.random.randint(0, len(self.df_other))
            data = self.df_other.iloc[randidx]
            ori_caption = data['caption']
            struct_caption = f'<{ori_caption}, all>'

        mel_path = data['mel_path']
        try:
            # Expected to be a 2-D array: axis 0 is cut to mel_num, axis 1 is
            # padded/cropped to spec_crop_len (presumably mel bins x frames).
            spec = np.load(mel_path)
        except Exception:
            # Best effort: an unreadable file becomes an all-padding array so a
            # single bad file does not abort the epoch. That requires a known
            # target shape; otherwise surface the real load error.
            if self.mel_num is None or self.batch_max_length is None:
                raise
            print(f'corrupted:{mel_path}')
            spec = np.full((self.mel_num, self.batch_max_length), self.pad_value,
                           dtype=np.float32)

        if self.batch_max_length is not None and spec.shape[1] < self.batch_max_length:
            spec = np.pad(spec, ((0, 0), (0, self.batch_max_length - spec.shape[1])),
                          mode='constant', constant_values=self.pad_value)

        item = {
            # Slicing with None keeps the whole axis, so unset limits are no-ops.
            'image': spec[:self.mel_num, :self.batch_max_length],
            'caption': {'ori_caption': ori_caption, 'struct_caption': struct_caption},
        }
        if self.split == 'test':
            item['f_name'] = data['name']
        return item

    def __len__(self):
        """Number of main-corpus rows in this split."""
        return len(self.dataset)
class JoinSpecsTrain(JoinManifestSpecs):
    """``JoinManifestSpecs`` fixed to the ``'train'`` split."""

    def __init__(self, specs_dataset_cfg):
        # Every config entry is forwarded as a keyword argument; only the split is fixed here.
        split_name = 'train'
        super(JoinSpecsTrain, self).__init__(split_name, **specs_dataset_cfg)
class JoinSpecsValidation(JoinManifestSpecs):
    """``JoinManifestSpecs`` fixed to the ``'valid'`` split."""

    def __init__(self, specs_dataset_cfg):
        # Every config entry is forwarded as a keyword argument; only the split is fixed here.
        split_name = 'valid'
        super(JoinSpecsValidation, self).__init__(split_name, **specs_dataset_cfg)
class JoinSpecsTest(JoinManifestSpecs):
    """``JoinManifestSpecs`` fixed to the ``'test'`` split."""

    def __init__(self, specs_dataset_cfg):
        # Every config entry is forwarded as a keyword argument; only the split is fixed here.
        split_name = 'test'
        super(JoinSpecsTest, self).__init__(split_name, **specs_dataset_cfg)