File size: 3,887 Bytes
2b5b9ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import sys
import numpy as np
import torch
import logging
import pandas as pd
import glob
logger = logging.getLogger(f'main.{__name__}')

sys.path.insert(0, '.')  # nopep8

class JoinManifestSpecs(torch.utils.data.Dataset):
    """Dataset mixing mel-spectrograms from a "main" and an "other" manifest set.

    Every ``*.tsv`` file found in the given directories is read as a
    tab-separated manifest and concatenated into one table per set.
    ``__getitem__`` reads these columns:

    * main manifests: ``ori_cap``, ``caption``, ``mel_path`` (and ``name``
      for the ``'test'`` split);
    * other manifests: ``caption``, ``mel_path`` (and ``name`` for the
      ``'test'`` split).

    With probability ``main_prob`` an item comes from the split's rows of the
    main set (at the requested index); otherwise a random row of the other
    set is drawn and its structured caption is synthesised as
    ``'<caption, all>'``.

    Args:
        split: ``'train'``, ``'valid'``/``'val'`` or ``'test'``.
        main_spec_dir_path: comma-separated directories holding main manifests.
        other_spec_dir_path: comma-separated directories holding other manifests.
        mel_num: number of leading spectrogram rows to keep (``None`` keeps all).
        spec_crop_len: target length of axis 1; shorter spectrograms are
            right-padded with ``pad_value`` and longer ones are truncated.
            ``None`` disables both padding and truncation.
        pad_value: fill value used for padding and for placeholder spectrograms.
        **kwargs: accepted and ignored, so a shared config mapping can be
            unpacked into the constructor.

    Raises:
        ValueError: if ``split`` is unknown or a manifest set contains no
            ``.tsv`` file.
    """

    # Number of leading main-manifest rows held out as the validation split;
    # the training split is everything after them.
    VALID_SET_SIZE = 100

    def __init__(self, split, main_spec_dir_path, other_spec_dir_path, mel_num=None,
                 spec_crop_len=None, pad_value=-5, **kwargs):
        super().__init__()
        self.main_prob = 0.5
        self.split = split
        self.batch_max_length = spec_crop_len
        self.batch_min_length = 50
        self.mel_num = mel_num
        self.pad_value = pad_value
        self.df_main = self._load_manifests(main_spec_dir_path)
        self.df_other = self._load_manifests(other_spec_dir_path)

        if split == 'train':
            dataset = self.df_main.iloc[self.VALID_SET_SIZE:]
        elif split in ('valid', 'val'):
            dataset = self.df_main.iloc[:self.VALID_SET_SIZE]
        elif split == 'test':
            self.df_main = self.add_name_num(self.df_main)
            dataset = self.df_main
        else:
            raise ValueError(f'Unknown split {split}')
        # Assign the returned frame instead of resetting in place: the train
        # and valid splits are slices of df_main, and in-place mutation of a
        # slice is a chained-assignment hazard in pandas. The previous labels
        # are kept as an 'index' column, as before.
        self.dataset = dataset.reset_index()
        print('dataset len:', len(self.dataset))

    @staticmethod
    def _load_manifests(dir_paths):
        """Read every ``*.tsv`` under the comma-separated ``dir_paths`` into one frame."""
        manifest_files = []
        for dir_path in dir_paths.split(','):
            # glob order depends on the file system; sorting keeps the
            # position-based train/valid split reproducible across machines.
            manifest_files += sorted(glob.glob(f'{dir_path}/*.tsv'))
        if not manifest_files:
            raise ValueError(f'No .tsv manifest found in: {dir_paths}')
        frames = [pd.read_csv(manifest, sep='\t') for manifest in manifest_files]
        return pd.concat(frames, ignore_index=True)

    def add_name_num(self, df):
        """Make ``name`` unique per row by appending ``_<k>``.

        A file may appear with several captions; ``k`` is the 0-based count of
        earlier rows carrying the same name, so each audio-caption pair gets
        its own identifier. ``df`` is modified in place and also returned.
        """
        occurrence = df.groupby('name', sort=False).cumcount()
        df['name'] = df['name'] + '_' + occurrence.astype(str)
        return df

    def _load_spec(self, mel_path):
        """Load the spectrogram at ``mel_path``, or build a constant placeholder.

        When loading fails and both ``mel_num`` and ``batch_max_length`` are
        set, a float32 array of that shape filled with ``pad_value`` is
        returned so one bad file does not abort the run. Without a known
        shape no placeholder can be built and the error is re-raised.
        """
        try:
            return np.load(mel_path)  # original note: mel spec [80, 624]
        except Exception:
            # Deliberately broad (best effort), but unlike a bare ``except``
            # it lets KeyboardInterrupt / SystemExit propagate.
            print(f'corrupted:{mel_path}')
            if self.mel_num is None or self.batch_max_length is None:
                raise
            shape = (self.mel_num, self.batch_max_length)
            return np.full(shape, self.pad_value, dtype=np.float32)

    def __getitem__(self, idx):
        """Return a dict with ``image``, ``caption`` and, for 'test', ``f_name``.

        ``caption`` is ``{'ori_caption': ..., 'struct_caption': ...}``.
        ``idx`` is only used when the main set is drawn; otherwise a random
        row of the other set is returned.
        """
        if np.random.uniform(0, 1) < self.main_prob:
            data = self.dataset.iloc[idx]
            ori_caption = data['ori_cap']
            struct_caption = data['caption']
        else:
            randidx = np.random.randint(0, len(self.df_other))
            data = self.df_other.iloc[randidx]
            ori_caption = data['caption']
            struct_caption = f'<{ori_caption}, all>'

        spec = self._load_spec(data['mel_path'])
        max_len = self.batch_max_length
        if max_len is not None and spec.shape[1] < max_len:
            spec = np.pad(spec, ((0, 0), (0, max_len - spec.shape[1])),
                          mode='constant', constant_values=self.pad_value)

        item = {
            # A ``None`` bound makes the corresponding slice a no-op.
            'image': spec[:self.mel_num, :max_len],
            'caption': {'ori_caption': ori_caption, 'struct_caption': struct_caption},
        }
        if self.split == 'test':
            item['f_name'] = data['name']
        return item

    def __len__(self):
        """Return the number of rows in the selected split of the main set."""
        return len(self.dataset)


class JoinSpecsTrain(JoinManifestSpecs):
    """``JoinManifestSpecs`` fixed to the ``'train'`` split."""

    def __init__(self, specs_dataset_cfg):
        # The config mapping is unpacked into the base constructor's keywords.
        super().__init__(split='train', **specs_dataset_cfg)

class JoinSpecsValidation(JoinManifestSpecs):
    """``JoinManifestSpecs`` fixed to the ``'valid'`` split."""

    def __init__(self, specs_dataset_cfg):
        # The config mapping is unpacked into the base constructor's keywords.
        super().__init__(split='valid', **specs_dataset_cfg)

class JoinSpecsTest(JoinManifestSpecs):
    """``JoinManifestSpecs`` fixed to the ``'test'`` split."""

    def __init__(self, specs_dataset_cfg):
        # The config mapping is unpacked into the base constructor's keywords.
        super().__init__(split='test', **specs_dataset_cfg)