kjysmu committed · verified
Commit 1adb3ce · 1 Parent(s): 47851eb

Upload 6 files

preprocess/.DS_Store ADDED
Binary file (6.15 kB).
 
preprocess/encoder/__init__.py ADDED
@@ -0,0 +1,3 @@
+ "Import all submodules"
+
+ # from model import
preprocess/encoder/mert.py ADDED
@@ -0,0 +1,28 @@
+ import torch
+ import numpy as np
+ from transformers import Wav2Vec2FeatureExtractor, AutoModel
+
+ class FeatureExtractorMERT:
+     def __init__(self, model_name="m-a-p/MERT-v1-95M", device_id=0, sr=24000):
+         self.model_name = model_name
+         self.device_id = device_id
+         self.sr = sr
+         self.device = torch.device(f"cuda:{self.device_id}" if torch.cuda.is_available() else "cpu")
+         self.model = AutoModel.from_pretrained(self.model_name, trust_remote_code=True).to(self.device)
+         self.processor = Wav2Vec2FeatureExtractor.from_pretrained(self.model_name, trust_remote_code=True)
+
+     def extract_features_from_segment(self, segment, sample_rate, save_path):
+         input_audio = segment.float()
+         model_inputs = self.processor(input_audio, sampling_rate=sample_rate, return_tensors="pt")
+         model_inputs = model_inputs.to(self.device)
+
+         with torch.no_grad():
+             model_outputs = self.model(**model_inputs, output_hidden_states=True)
+
+         # Stack and process hidden states
+         all_layer_hidden_states = torch.stack(model_outputs.hidden_states).squeeze()[1:, :, :].unsqueeze(0)
+         all_layer_hidden_states = all_layer_hidden_states.mean(dim=2)
+         features = all_layer_hidden_states.cpu().detach().numpy()
+
+         # Save features
+         np.save(save_path, features)
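A minimal usage sketch for FeatureExtractorMERT (the audio file name below is hypothetical): load a clip with torchaudio, downmix to mono, resample to the model's 24 kHz rate, and write the pooled features to an .npy file. For MERT-v1-95M the saved array should have shape (1, 12, 768), one time-averaged 768-dim vector per transformer layer, though that is worth verifying on your setup.

import torchaudio
import torchaudio.transforms as T
from encoder.mert import FeatureExtractorMERT

extractor = FeatureExtractorMERT()             # defaults: m-a-p/MERT-v1-95M, 24 kHz
waveform, sr = torchaudio.load("example.mp3")  # hypothetical input clip
waveform = waveform.mean(dim=0)                # downmix to mono, shape (num_samples,)
if sr != extractor.sr:
    waveform = T.Resample(sr, extractor.sr)(waveform)
    sr = extractor.sr
extractor.extract_features_from_segment(waveform, sr, "example_mert.npy")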
preprocess/encoder/music2latent.py ADDED
@@ -0,0 +1,29 @@
+ import os
+ import torch
+ import torchaudio
+ import torchaudio.transforms as T
+ import numpy as np
+ from music2latent import EncoderDecoder  # Import your custom model
+
+ class FeatureExtractorM2L:
+     def __init__(self, device_id=0, sr=44100):
+         self.device_id = device_id
+         self.sr = sr
+         self.device = torch.device(f"cuda:{self.device_id}" if torch.cuda.is_available() else "cpu")
+         self.model = EncoderDecoder(device=self.device)
+
+     def extract_features_from_segment(self, segment, sample_rate, save_path):
+         input_audio = segment.unsqueeze(0).to(self.device)  # Add batch dimension and move to the device
+
+         with torch.no_grad():
+             model_outputs = self.model.encode(input_audio, extract_features=True)
+
+         features = model_outputs.mean(dim=-1).cpu().numpy()
+         np.save(save_path, features)
+
+
+
+
+
+
+
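FeatureExtractorM2L mirrors the extract_features_from_segment signature of the MERT extractor, so the AudioProcessor below can swap between them from config. Note that the sample_rate argument is accepted but not used here, so the caller is responsible for resampling to 44.1 kHz beforehand. A minimal sketch (file name hypothetical):

import torchaudio
from encoder.music2latent import FeatureExtractorM2L

extractor = FeatureExtractorM2L(device_id=0, sr=44100)
waveform, sr = torchaudio.load("example.wav")  # hypothetical 44.1 kHz clip
waveform = waveform.mean(dim=0)                # downmix to mono, shape (num_samples,)
extractor.extract_features_from_segment(waveform, sr, "example_m2l.npy")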
preprocess/feature_extractor.py ADDED
@@ -0,0 +1,94 @@
+ import os
+ import torch
+ import torchaudio
+ import torchaudio.transforms as T
+ from tqdm import tqdm
+ import numpy as np
+ from omegaconf import DictConfig
+ import hydra
+ from hydra.utils import to_absolute_path
+ from transformers import Wav2Vec2FeatureExtractor, AutoModel
+
+ from encoder.mert import FeatureExtractorMERT
+ from encoder.music2latent import FeatureExtractorM2L
+
+ class AudioProcessor:
+     def __init__(self, cfg: DictConfig):
+         self.input_directory = cfg.dataset.input_dir
+         self.output_directory = cfg.dataset.output_dir
+         self.segment_duration = cfg.segment_duration
+         self.resample_rate = cfg.model.sr
+         self.device_id = cfg.device_id
+         self.feature_extractor = self._initialize_extractor(cfg.model.name)
+         self.is_split = cfg.is_split
+
+     def _initialize_extractor(self, model_name: str):
+         if "MERT" in model_name:
+             return FeatureExtractorMERT(model_name=model_name, device_id=self.device_id, sr=self.resample_rate)
+         elif "music2latent" == model_name:
+             return FeatureExtractorM2L(device_id=self.device_id, sr=self.resample_rate)
+         else:
+             raise NotImplementedError(f"Feature extraction for model {model_name} is not implemented.")
+
+     def resample_waveform(self, waveform, original_sample_rate, target_sample_rate):
+         if original_sample_rate != target_sample_rate:
+             resampler = T.Resample(original_sample_rate, target_sample_rate)
+             return resampler(waveform), target_sample_rate
+         return waveform, original_sample_rate
+
+     def split_audio(self, waveform, sample_rate):
+         segment_samples = self.segment_duration * sample_rate
+         total_samples = waveform.size(0)
+
+         segments = []
+         for start in range(0, total_samples, segment_samples):
+             end = start + segment_samples
+             if end <= total_samples:
+                 segment = waveform[start:end]
+                 segments.append(segment)
+
+         # In case audio length is shorter than segment length.
+         if len(segments) == 0:
+             segment = waveform
+             segments.append(segment)
+
+         return segments
+
+     def process_audio_file(self, file_path, output_dir):
+         print(f"Processing {file_path}")
+         waveform, sample_rate = torchaudio.load(file_path)
+
+         if waveform.shape[0] > 1:
+             waveform = waveform.mean(dim=0).unsqueeze(0)
+         waveform = waveform.squeeze()
+         waveform, sample_rate = self.resample_waveform(waveform, sample_rate, self.resample_rate)
+
+         if self.is_split:
+             segments = self.split_audio(waveform, sample_rate)
+             for i, segment in enumerate(segments):
+                 segment_save_path = os.path.join(output_dir, f"segment_{i}.npy")
+                 if os.path.exists(segment_save_path):
+                     continue
+                 self.feature_extractor.extract_features_from_segment(segment, sample_rate, segment_save_path)
+         else:
+             segment_save_path = os.path.join(output_dir, f"segment_0.npy")
+             if not os.path.exists(segment_save_path):
+                 self.feature_extractor.extract_features_from_segment(waveform, sample_rate, segment_save_path)
+
+     def process_directory(self):
+         for root, _, files in os.walk(self.input_directory):
+             for file in files:
+                 if file.endswith('.mp3'):
+                     file_path = os.path.join(root, file)
+                     relative_path = os.path.relpath(file_path, self.input_directory)
+                     output_file_dir = os.path.join(self.output_directory, os.path.splitext(relative_path)[0])
+                     os.makedirs(output_file_dir, exist_ok=True)
+                     self.process_audio_file(file_path, output_file_dir)
+
+ @hydra.main(version_base=None, config_path="../config", config_name="prep_config")
+ def main(cfg: DictConfig):
+     processor = AudioProcessor(cfg)
+     processor.process_directory()
+
+ if __name__ == "__main__":
+     main()
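The Hydra config referenced by @hydra.main (config/prep_config.yaml) is not part of this upload, so the values below are assumptions; the sketch only lists the keys AudioProcessor actually reads and shows an equivalent programmatic invocation:

from omegaconf import OmegaConf

# Hypothetical stand-in for config/prep_config.yaml — only the keys read by AudioProcessor.
cfg = OmegaConf.create({
    "dataset": {
        "input_dir": "dataset/jamendo/audio",   # walked recursively for .mp3 files
        "output_dir": "dataset/jamendo/mert",   # one sub-folder per track, holding segment_*.npy
    },
    "model": {"name": "m-a-p/MERT-v1-95M", "sr": 24000},  # or "music2latent" with sr: 44100
    "segment_duration": 30,  # seconds per segment when is_split is true
    "device_id": 0,
    "is_split": True,
})

processor = AudioProcessor(cfg)
processor.process_directory()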
preprocess/jamendo_split.py ADDED
@@ -0,0 +1,83 @@
+ import os
+ import csv
+ import pickle
+ import numpy as np
+ import fire
+ from collections import Counter
+
+ class Split:
+     def read_tsv(self, fn):
+         r = []
+         with open(fn) as tsv:
+             reader = csv.reader(tsv, delimiter='\t')
+             for row in reader:
+                 r.append(row)
+         return r[1:]
+
+     def get_tag_list(self, option):
+         if option == 'top50tags':
+             tag_list = np.load('dataset/jamendo/meta/tag_list_50.npy')
+         else:
+             tag_list = np.load('dataset/jamendo/meta/tag_list.npy')
+             if option == 'genre':
+                 tag_list = tag_list[:87]
+             elif option == 'instrument':
+                 tag_list = tag_list[87:127]
+             elif option == 'moodtheme':
+                 tag_list = tag_list[127:]
+         return list(tag_list)
+
+     def get_npy_array(self, path, tag_list, option, type_='train'):
+         if option=='all':
+             tsv_fn = os.path.join(path, 'autotagging-%s.tsv'%type_)
+         else:
+             tsv_fn = os.path.join(path, 'autotagging_%s-%s.tsv'%(option, type_))
+         rows = self.read_tsv(tsv_fn)
+         dictionary = {}
+         i = 0
+         for row in rows:
+             temp_dict = {}
+             temp_dict['path'] = row[3]
+             temp_dict['duration'] = (float(row[4]) * 12000 - 512) // 256
+             if option == 'all':
+                 temp_dict['tags'] = np.zeros(183)
+             elif option == 'genre':
+                 temp_dict['tags'] = np.zeros(87)
+             elif option == 'instrument':
+                 temp_dict['tags'] = np.zeros(40)
+             elif option == 'moodtheme':
+                 temp_dict['tags'] = np.zeros(56)
+             elif option == 'top50tags':
+                 temp_dict['tags'] = np.zeros(50)
+             tags = row[5:]
+             for tag in tags:
+                 try:
+                     temp_dict['tags'][tag_list.index(tag)] = 1
+                 except:
+                     continue
+             if temp_dict['tags'].sum() > 0 and os.path.exists(os.path.join(self.npy_path, row[3][:-3])+'npy'):
+                 dictionary[i] = temp_dict
+                 i += 1
+         dict_fn = os.path.join(path, '%s_%s_dict.pickle'%(option, type_))
+         with open(dict_fn, 'wb') as pf:
+             pickle.dump(dictionary, pf)
+
+     def run_iter(self, split, option='all'):
+         tag_list = self.get_tag_list(option)
+         path = 'dataset/jamendo/splits/split-%d/' % split
+         self.get_npy_array(path, tag_list, option, type_='train')
+         self.get_npy_array(path, tag_list, option, type_='validation')
+         self.get_npy_array(path, tag_list, option, type_='test')
+
+     def run(self, path):
+         self.npy_path = path
+         for i in range(5):
+             # self.run_iter(i, 'all')
+             self.run_iter(i, 'genre')
+             self.run_iter(i, 'instrument')
+             self.run_iter(i, 'moodtheme')
+             # self.run_iter(i, 'top50tags')
+
+ if __name__ == '__main__':
+     s = Split()
+     fire.Fire({'run': s.run})
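The fire wrapper exposes Split.run as a CLI command; --path should point at the root of the extracted .npy features (the path below is hypothetical). For each of the five MTG-Jamendo splits it writes <option>_<type>_dict.pickle label dictionaries next to the split TSVs, keeping only tracks whose tag vector is non-empty and whose feature file exists. Note that the existence check looks for flat <track>.npy files (row[3][:-3] + 'npy'), not the per-track segment folders written by feature_extractor.py, so the layout under --path may need to be adapted.

python preprocess/jamendo_split.py run --path dataset/jamendo/mert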