jadechoghari committed on
Commit
5085882
·
verified ·
1 Parent(s): 4743cf5

add qa files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. audioldm_train/.DS_Store +0 -0
  2. audioldm_train/__init__.py +1 -0
  3. audioldm_train/__pycache__/__init__.cpython-310.pyc +0 -0
  4. audioldm_train/__pycache__/conditional_models.cpython-310.pyc +0 -0
  5. audioldm_train/__pycache__/dataset_plugin.cpython-310.pyc +0 -0
  6. audioldm_train/conditional_models.py +1354 -0
  7. audioldm_train/config/mos_as_token/qa_mdt.yaml +169 -0
  8. audioldm_train/dataset_plugin.py +508 -0
  9. audioldm_train/losses/__init__.py +1 -0
  10. audioldm_train/losses/__pycache__/__init__.cpython-310.pyc +0 -0
  11. audioldm_train/losses/__pycache__/contperceptual.cpython-310.pyc +0 -0
  12. audioldm_train/losses/contperceptual.py +160 -0
  13. audioldm_train/modules/.DS_Store +0 -0
  14. audioldm_train/modules/__init__.py +0 -0
  15. audioldm_train/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  16. audioldm_train/modules/audiomae/AudioMAE.py +151 -0
  17. audioldm_train/modules/audiomae/README.md +24 -0
  18. audioldm_train/modules/audiomae/__init__.py +0 -0
  19. audioldm_train/modules/audiomae/__pycache__/AudioMAE.cpython-310.pyc +0 -0
  20. audioldm_train/modules/audiomae/__pycache__/__init__.cpython-310.pyc +0 -0
  21. audioldm_train/modules/audiomae/__pycache__/models_mae.cpython-310.pyc +0 -0
  22. audioldm_train/modules/audiomae/__pycache__/models_vit.cpython-310.pyc +0 -0
  23. audioldm_train/modules/audiomae/audiovisual_dataset.py +256 -0
  24. audioldm_train/modules/audiomae/example.py +52 -0
  25. audioldm_train/modules/audiomae/models_mae.py +615 -0
  26. audioldm_train/modules/audiomae/models_vit.py +252 -0
  27. audioldm_train/modules/audiomae/sequence_gen/__init__.py +2 -0
  28. audioldm_train/modules/audiomae/sequence_gen/__pycache__/__init__.cpython-310.pyc +0 -0
  29. audioldm_train/modules/audiomae/sequence_gen/__pycache__/model.cpython-310.pyc +0 -0
  30. audioldm_train/modules/audiomae/sequence_gen/__pycache__/sequence_input.cpython-310.pyc +0 -0
  31. audioldm_train/modules/audiomae/sequence_gen/model.py +329 -0
  32. audioldm_train/modules/audiomae/sequence_gen/sequence_input.py +737 -0
  33. audioldm_train/modules/audiomae/util/__pycache__/patch_embed.cpython-310.pyc +0 -0
  34. audioldm_train/modules/audiomae/util/__pycache__/pos_embed.cpython-310.pyc +0 -0
  35. audioldm_train/modules/audiomae/util/crop.py +43 -0
  36. audioldm_train/modules/audiomae/util/datasets.py +67 -0
  37. audioldm_train/modules/audiomae/util/lars.py +60 -0
  38. audioldm_train/modules/audiomae/util/lr_decay.py +78 -0
  39. audioldm_train/modules/audiomae/util/lr_sched.py +28 -0
  40. audioldm_train/modules/audiomae/util/misc.py +454 -0
  41. audioldm_train/modules/audiomae/util/patch_embed.py +127 -0
  42. audioldm_train/modules/audiomae/util/pos_embed.py +205 -0
  43. audioldm_train/modules/audiomae/util/stat.py +77 -0
  44. audioldm_train/modules/clap/__init__.py +0 -0
  45. audioldm_train/modules/clap/__pycache__/__init__.cpython-310.pyc +0 -0
  46. audioldm_train/modules/clap/open_clip/__init__.py +25 -0
  47. audioldm_train/modules/clap/open_clip/__pycache__/__init__.cpython-310.pyc +0 -0
  48. audioldm_train/modules/clap/open_clip/__pycache__/__init__.cpython-38.pyc +0 -0
  49. audioldm_train/modules/clap/open_clip/__pycache__/factory.cpython-310.pyc +0 -0
  50. audioldm_train/modules/clap/open_clip/__pycache__/factory.cpython-38.pyc +0 -0
audioldm_train/.DS_Store ADDED
Binary file (6.15 kB). View file
 
audioldm_train/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import utilities
audioldm_train/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (169 Bytes). View file
 
audioldm_train/__pycache__/conditional_models.cpython-310.pyc ADDED
Binary file (29.2 kB). View file
 
audioldm_train/__pycache__/dataset_plugin.cpython-310.pyc ADDED
Binary file (10.9 kB). View file
 
audioldm_train/conditional_models.py ADDED
@@ -0,0 +1,1354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ sys.path.append("src")
4
+ import torch
5
+ import logging
6
+ import torch.nn as nn
7
+ from audioldm_train.modules.clap.open_clip import create_model
8
+ from audioldm_train.modules.clap.training.data import get_audio_features
9
+
10
+ import torchaudio
11
+ from transformers import (
12
+ RobertaTokenizer,
13
+ AutoTokenizer,
14
+ T5EncoderModel,
15
+ MT5EncoderModel,
16
+ )
17
+ import torch.nn.functional as F
18
+ from audioldm_train.modules.audiomae.AudioMAE import Vanilla_AudioMAE
19
+ from audioldm_train.modules.phoneme_encoder.encoder import TextEncoder
20
+
21
+ from transformers import SpeechT5Processor, AutoTokenizer, GPT2Model, GPT2Tokenizer
22
+ from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithTextPrenet
23
+
24
+ from audioldm_train.modules.audiomae.sequence_gen.model import CLAP2AudioMAE
25
+ from audioldm_train.modules.audiomae.sequence_gen.sequence_input import (
26
+ Sequence2AudioMAE,
27
+ )
28
+ import numpy as np
29
+ from audioldm_train.modules.audiomae.sequence_gen.model import Prenet
30
+ import json
31
+ with open('offset_pretrained_checkpoints.json', 'r') as config_file:
32
+ config_data = json.load(config_file)
33
+
34
+ """
35
+ The model forward function can return three types of data:
36
+ 1. tensor: used directly as conditioning signal
37
+ 2. dict: contains a main key used as the condition; other keys can be used to pass loss values, intermediate results, etc.
38
+ 3. list: of length 2, in which the first element is a tensor and the second element is the attention mask.
39
+
40
+ The output shape for the cross attention condition should be:
41
+ x,x_mask = [bs, seq_len, emb_dim], [bs, seq_len]
42
+
43
+ All returned data that will be used as diffusion input must be of float type.
44
+ """
45
+
46
+
47
class GPT2WordEmbedding(nn.Module):
    """Turn a list of strings into GPT-2 word-token embeddings.

    Only the ``wte`` embedding table of the pretrained ``gpt2`` checkpoint is
    kept. ``forward`` returns ``[embeddings, attention_mask]``.
    """

    def __init__(self):
        super().__init__()
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        # Pad with the EOS token so that ``padding=True`` can be used in forward.
        tokenizer.pad_token = tokenizer.eos_token
        self.tokenizer = tokenizer
        self.model = GPT2Model.from_pretrained("gpt2").wte
        # Filled in on the first forward pass from the embedding's parameters.
        self.device = None

    def get_unconditional_condition(self, batchsize):
        """Return the embedding of the prompt "random" repeated ``batchsize`` times."""
        return self(["random"] * batchsize)

    def forward(self, text):
        """Tokenize ``text`` (a list of strings) and look up its embeddings.

        Returns:
            ``[input_embed, attn_mask]`` where both tensors live on the device
            of the embedding table.
        """
        assert isinstance(text, list)
        if self.device is None:
            self.device = next(self.model.parameters()).device

        encoded = self.tokenizer(text, return_tensors="pt", padding=True)
        token_ids = encoded["input_ids"].to(self.device)
        mask = encoded["attention_mask"].to(self.device)

        embedded = self.model(token_ids.long())
        return [embedded, mask]
73
+
74
+
75
class ConcateBandWidthCond(nn.Module):
    """Pass-through conditioner for a bandwidth-conditioning mel channel.

    ``forward`` returns its input unchanged. The unconditional condition is an
    all-zero tensor of shape ``(batchsize, latent_t_size, latent_f_size)``.
    """

    def __init__(self, latent_t_size, latent_f_size):
        super().__init__()
        # The module has no real weights; this layer gives it a parameter.
        self.placeholder = nn.Linear(1, 1)
        self.latent_t_size = latent_t_size
        self.latent_f_size = latent_f_size
        # Device of the most recent input; ``None`` until forward has run once.
        self.device = None

    def get_unconditional_condition(self, batchsize):
        """Return a zero tensor of shape ``(batchsize, latent_t_size, latent_f_size)``.

        The tensor is created on the device recorded by ``forward``. If
        ``forward`` has not been called yet, the device of the placeholder
        parameter is used instead of passing ``None`` to ``.to``.
        """
        device = self.device
        if device is None:
            device = self.placeholder.weight.device
        return torch.zeros(
            (batchsize, self.latent_t_size, self.latent_f_size), device=device
        )

    def forward(self, mel_spec_bandwidth_cond_extra_channel):
        """Record the input's device (first call only) and return the input as is."""
        if self.device is None:
            self.device = mel_spec_bandwidth_cond_extra_channel.device

        return mel_spec_bandwidth_cond_extra_channel
93
+
94
+
95
class BandwidthEncoder(nn.Module):
    """Embed a (lower cutoff, higher cutoff) pair into a 256-dim vector.

    Each cutoff index is looked up in a 1000-entry, 128-dim embedding table,
    passed through a shared linear layer, and the two results are concatenated.
    """

    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(1000, 128)
        nn.init.normal_(self.emb.weight, 0.0, 128**-0.5)
        self.linear_bandwidth = nn.Linear(128, 128)
        # Plain tensor (not a buffer), so it is moved by hand in _sync_device.
        self.unconditional_condition = torch.zeros((1, 256))
        self.device = None

    def _sync_device(self):
        """Keep ``self.device`` and the unconditional tensor on the parameters' device."""
        device = next(self.linear_bandwidth.parameters()).device
        if self.device is None or self.device != device:
            self.device = device
            self.unconditional_condition = self.unconditional_condition.to(device)

    def get_unconditional_condition(self, batchsize):
        """Return a ``(batchsize, 256)`` zero tensor on the module's device.

        This also works before ``forward`` has ever been called, and after the
        module has been moved to another device.
        """
        self._sync_device()
        return self.unconditional_condition.expand(batchsize, 256)

    def forward(self, bandwidth):
        """Embed ``bandwidth[..., 0]`` and ``bandwidth[..., 1]`` and concatenate them.

        Args:
            bandwidth: tensor whose last axis holds (lower, higher) cutoff
                values; they are cast to ``long`` and used as embedding
                indices, so they must lie in ``[0, 1000)``.

        Returns:
            Tensor with last dimension 256 (lower embedding, then higher).
        """
        self._sync_device()

        # freq_energy_percentile
        lower_cutoff, higher_cutoff = bandwidth[..., 0], bandwidth[..., 1]

        lower_cutoff_emb = self.linear_bandwidth(self.emb(lower_cutoff.long()))
        higher_cutoff_emb = self.linear_bandwidth(self.emb(higher_cutoff.long()))
        cutoff_emb = torch.cat([lower_cutoff_emb, higher_cutoff_emb], dim=-1)
        # [bs, 256]
        return cutoff_emb
122
+
123
+
124
class SpeechT5TextEncoder(nn.Module):
    """Frozen SpeechT5 text encoder used as a conditioning model.

    ``forward`` returns ``[last_hidden_state, attention_mask]``, both as float.
    """

    def __init__(self):
        super().__init__()
        checkpoint = "microsoft/speecht5_tts"
        self.processor = SpeechT5Processor.from_pretrained(checkpoint)
        self.model = SpeechT5EncoderWithTextPrenet.from_pretrained(checkpoint)
        # The encoder is never trained: freeze every weight and use eval mode.
        for weight in self.model.parameters():
            weight.requires_grad = False
        self.model.eval()

    # Required
    def get_unconditional_condition(self, batchsize):
        """Return a single all-zero token of width 768 per item, with a mask of ones."""
        target = self.model.device
        zeros = torch.zeros((batchsize, 1, 768)).to(target)
        ones = torch.ones((batchsize, 1)).to(target)
        return [zeros.float(), ones.float()]

    def forward(self, text):
        """Encode ``text`` without tracking gradients."""
        with torch.no_grad():
            target = self.model.device
            encoded = self.processor(text=text, return_tensors="pt", padding=True)
            input_ids = encoded["input_ids"].to(target)
            attention_mask = encoded["attention_mask"].to(target)
            hidden = self.model(input_ids, attention_mask).last_hidden_state.detach()
            return [hidden.float(), attention_mask.float()]
152
+
153
+
154
class PhonemeEncoder(nn.Module):
    """Encode padded phoneme-index sequences for cross-attention conditioning.

    Example::

        encoder = PhonemeEncoder(40, pad_token_id=0)
        data = torch.randint(0, 39, (2, 250))
        output = encoder(data)

    Args:
        vocabs_size: vocabulary size handed to ``TextEncoder`` as ``n_vocab``.
        pad_length: fixed sequence length; also the length of the learnable
            positional embedding.
        pad_token_id: index marking padding positions. Required.
    """

    def __init__(self, vocabs_size=41, pad_length=250, pad_token_id=None):
        super().__init__()
        assert pad_token_id is not None

        self.device = None
        self.PAD_LENGTH = int(pad_length)
        self.pad_token_id = pad_token_id
        # An all-padding sequence, used as the unconditional input.
        self.pad_token_sequence = torch.tensor([self.pad_token_id] * self.PAD_LENGTH)

        self.text_encoder = TextEncoder(
            n_vocab=vocabs_size,
            out_channels=192,
            hidden_channels=192,
            filter_channels=768,
            n_heads=2,
            n_layers=6,
            kernel_size=3,
            p_dropout=0.1,
        )

        # Shape (1, 192, PAD_LENGTH); added to the encoder output in forward.
        self.learnable_positional_embedding = torch.nn.Parameter(
            torch.zeros((1, 192, self.PAD_LENGTH))
        )
        self.learnable_positional_embedding.requires_grad = True

    # Required
    def get_unconditional_condition(self, batchsize):
        """Encode ``batchsize`` all-padding sequences.

        The pad tokens are moved to the module's device before expanding.
        ``forward`` only moves ``self.pad_token_sequence`` on its first call,
        which is too late for a tensor that was already expanded from it.
        """
        device = self.learnable_positional_embedding.device
        unconditional_tokens = self.pad_token_sequence.to(device).expand(
            batchsize, self.PAD_LENGTH
        )
        return self(unconditional_tokens)  # Need to return float type

    def _get_src_mask(self, phoneme):
        """Boolean mask that is True where ``phoneme`` is not the pad token."""
        return phoneme != self.pad_token_id

    def _get_src_length(self, phoneme):
        """Number of non-pad tokens along the last axis."""
        return torch.sum(self._get_src_mask(phoneme), dim=-1)

    def forward(self, phoneme_idx):
        """Encode phoneme indices.

        Returns:
            ``[text_emb, text_emb_mask]``: the embedding permuted to
            ``(batch, PAD_LENGTH, 192)`` and the encoder mask with its axis 1
            squeezed out.
        """
        if self.device is None:
            self.device = self.learnable_positional_embedding.device
            self.pad_token_sequence = self.pad_token_sequence.to(self.device)

        src_length = self._get_src_length(phoneme_idx)
        # The two middle outputs of the text encoder are not used here.
        text_emb, m, logs, text_emb_mask = self.text_encoder(phoneme_idx, src_length)
        text_emb = text_emb + self.learnable_positional_embedding

        return [
            text_emb.permute(0, 2, 1),
            text_emb_mask.squeeze(1),
        ]  # [2, 250, 192], [2, 250]
233
+
234
+
235
class FlanT5HiddenState(nn.Module):
    """FLAN-T5 encoder returning per-token hidden states and an attention mask.

    Example::

        encoder = FlanT5HiddenState()
        data = ["", "this is not an empty sentence"]
        hidden_states, attention_mask = encoder(data)

    Args:
        text_encoder_name: NOTE(review): accepted but not used; the
            tokenizer and model are always loaded from
            ``"google/flan-t5-large"``.
        freeze_text_encoder: when true, the encoder is put in eval mode and
            its parameters have ``requires_grad`` disabled.
    """

    def __init__(
        self, text_encoder_name=config_data['flan_t5'], freeze_text_encoder=True
    ):
        super().__init__()
        self.freeze_text_encoder = freeze_text_encoder
        ## MODIFIED
        self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
        self.model = T5EncoderModel.from_pretrained("google/flan-t5-large")
        if freeze_text_encoder:
            self.model.eval()
            for p in self.model.parameters():
                p.requires_grad = False
        else:
            print("=> The text encoder is learnable")

        # Cached encoding of the empty prompt, computed on first use.
        self.empty_hidden_state_cfg = None
        self.device = None

    # Required
    def get_unconditional_condition(self, batchsize):
        """Return the empty-prompt encoding repeated ``batchsize`` times, with a mask of ones."""
        param = next(self.model.parameters())
        if self.freeze_text_encoder:
            assert not param.requires_grad

        if self.empty_hidden_state_cfg is None:
            self.empty_hidden_state_cfg, _ = self([""])

        hidden_state = torch.cat([self.empty_hidden_state_cfg] * batchsize).float()
        attention_mask = (
            torch.ones((batchsize, hidden_state.size(1)))
            .to(hidden_state.device)
            .float()
        )
        return [hidden_state, attention_mask]  # Need to return float type

    def forward(self, batch):
        """Encode a list of prompts.

        Raises:
            Exception: any error from ``encode_text`` is logged and then
                re-raised. Previously it was swallowed and ``None`` was
                returned, which made callers fail later with an unrelated
                unpacking error.
        """
        param = next(self.model.parameters())
        if self.freeze_text_encoder:
            assert not param.requires_grad

        if self.device is None:
            self.device = param.device

        try:
            return self.encode_text(batch)
        except Exception as e:
            print(e, batch)
            logging.exception("An error occurred: %s", str(e))
            raise

    def encode_text(self, prompt):
        """Tokenize ``prompt`` (padded to the longest item, at most 128 tokens) and encode it.

        Returns:
            ``[encoder_hidden_states, attention_mask]``; a mask value of 1
            marks a usable token.
        """
        device = self.model.device
        batch = self.tokenizer(
            prompt,
            max_length=128,  # self.tokenizer.model_max_length
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
            device
        )
        # Get text encoding
        if self.freeze_text_encoder:
            with torch.no_grad():
                encoder_hidden_states = self.model(
                    input_ids=input_ids, attention_mask=attention_mask
                )[0]
        else:
            encoder_hidden_states = self.model(
                input_ids=input_ids, attention_mask=attention_mask
            )[0]
        # NOTE(review): detach() is applied in the learnable branch too, so no
        # gradient reaches the encoder through this output - confirm intended.
        return [
            encoder_hidden_states.detach(),
            attention_mask.float(),
        ]  # Attention mask == 1 means usable token
322
+
323
+
324
class FlanT5HiddenStatePaddedSameLength(nn.Module):
    """FLAN-T5 encoder whose output is always padded to 128 tokens.

    Example::

        encoder = FlanT5HiddenStatePaddedSameLength()
        data = ["", "this is not an empty sentence"]
        hidden_states, attention_mask = encoder(data)

    Args:
        text_encoder_name: NOTE(review): accepted but not used; the
            tokenizer and model are always loaded from
            ``"google/flan-t5-large"``.
        freeze_text_encoder: when true, the encoder is put in eval mode and
            its parameters have ``requires_grad`` disabled.
    """

    def __init__(
        self, text_encoder_name="google/flan-t5-large", freeze_text_encoder=True
    ):
        super().__init__()
        self.freeze_text_encoder = freeze_text_encoder
        self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
        self.model = T5EncoderModel.from_pretrained("google/flan-t5-large")
        if freeze_text_encoder:
            self.model.eval()
            for p in self.model.parameters():
                p.requires_grad = False
        else:
            print("=> The text encoder is learnable")

        # Cached encoding of the empty prompt, computed on first use.
        self.empty_hidden_state_cfg = None
        self.device = None

    # Required
    def get_unconditional_condition(self, batchsize):
        """Return the empty-prompt encoding repeated ``batchsize`` times, with a mask of ones."""
        param = next(self.model.parameters())
        if self.freeze_text_encoder:
            assert not param.requires_grad

        if self.empty_hidden_state_cfg is None:
            self.empty_hidden_state_cfg, _ = self([""])

        hidden_state = torch.cat([self.empty_hidden_state_cfg] * batchsize).float()
        attention_mask = (
            torch.ones((batchsize, hidden_state.size(1)))
            .to(hidden_state.device)
            .float()
        )
        return [hidden_state, attention_mask]  # Need to return float type

    def forward(self, batch):
        """Encode a list of prompts.

        Raises:
            Exception: any error from ``encode_text`` is logged and then
                re-raised. Previously it was swallowed and ``None`` was
                returned, which made callers fail later with an unrelated
                unpacking error.
        """
        param = next(self.model.parameters())
        if self.freeze_text_encoder:
            assert not param.requires_grad

        if self.device is None:
            self.device = param.device

        try:
            text_embed = self.encode_text(batch)
            return text_embed
        except Exception as e:
            print(e, batch)
            logging.exception("An error occurred: %s", str(e))
            raise

    def encode_text(self, prompt):
        """Tokenize ``prompt`` with ``padding="max_length"`` (128 tokens) and encode it.

        Returns:
            ``[encoder_hidden_states, attention_mask]``; a mask value of 1
            marks a usable token.
        """
        device = self.model.device
        batch = self.tokenizer(
            prompt,
            max_length=128,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
            device
        )

        # Get text encoding
        if self.freeze_text_encoder:
            with torch.no_grad():
                encoder_hidden_states = self.model(
                    input_ids=input_ids, attention_mask=attention_mask
                )[0]
        else:
            encoder_hidden_states = self.model(
                input_ids=input_ids, attention_mask=attention_mask
            )[0]
        # NOTE(review): detach() is applied in the learnable branch too, so no
        # gradient reaches the encoder through this output - confirm intended.
        return [
            encoder_hidden_states.detach(),
            attention_mask.float(),
        ]  # Attention mask == 1 means usable token
412
+
413
+
414
class CLAPGenAudioMAECond(CLAP2AudioMAE):
    """Conditioning wrapper around ``CLAP2AudioMAE``.

    ``forward`` returns a dict whose ``"crossattn_clap_to_audiomae_feature"``
    entry is either the ground-truth pooled AudioMAE feature or a generated
    sequence, chosen at random per call.

    Args:
        cond_stage_config: forwarded to the parent constructor.
        learnable: when false, ``self.model`` is frozen and the module is put
            in eval mode.
        pretrained_path: optional checkpoint whose ``"state_dict"`` entry is
            loaded into this module.
        use_gt_mae_output: whether the ground-truth AudioMAE feature may be
            used as the condition. Required.
        use_gt_mae_prob: probability of using the ground truth when
            ``use_gt_mae_output`` is true. Required.
    """

    def __init__(
        self,
        cond_stage_config,
        learnable=True,
        pretrained_path=None,
        use_gt_mae_output=None,  # False: does not use AudioMAE GT, True: Use AudioMAE GT
        use_gt_mae_prob=None,  # The prob of using AudioMAE GT
    ):
        super().__init__(base_learning_rate=1e-5, cond_stage_config=cond_stage_config)
        assert use_gt_mae_output is not None and use_gt_mae_prob is not None

        if pretrained_path is not None:
            print("Reload CLAPGenAudioMAECond from %s" % pretrained_path)
            # Load on CPU so a checkpoint saved on an accelerator can still be
            # read on a machine without one; load_state_dict copies the values
            # into the existing parameters on whatever device they live on.
            state_dict = torch.load(pretrained_path, map_location="cpu")["state_dict"]
            self.load_state_dict(state_dict)

        self.use_gt_mae_output = use_gt_mae_output
        self.use_gt_mae_prob = use_gt_mae_prob
        self.learnable = learnable

        if not learnable:
            # Freeze self.model and switch the whole module to eval mode.
            for p in self.model.parameters():
                p.requires_grad = False
            self.eval()

    # Required
    def get_unconditional_condition(self, batchsize):
        """Return the parent's ``cfg_uncond(batchsize)`` result unchanged."""
        return self.cfg_uncond(batchsize)

    def forward(self, batch):
        """Build the condition dict for ``batch``.

        The conditional module can return both tensors and dictionaries; the
        keys of a returned dict correspond to the cond_stage_key.
        """
        ret_dict = {}
        if self.use_gt_mae_output and torch.rand(1).item() < self.use_gt_mae_prob:
            cond_dict = self.get_input(batch)
            # Ground-truth feature as the condition, with an all-ones mask.
            ret_dict["crossattn_clap_to_audiomae_feature"] = [
                cond_dict["crossattn_audiomae_pooled"][0],
                torch.ones_like(cond_dict["crossattn_audiomae_pooled"][1]).float(),
            ]  # Input sequence and mask
        else:
            # Generated sequence as the condition, with an all-ones mask.
            input_embeds, cond_dict = self.generate(batch)
            input_embeds_mask = (
                torch.ones((input_embeds.size(0), input_embeds.size(1)))
                .to(input_embeds.device)
                .float()
            )
            ret_dict["crossattn_clap_to_audiomae_feature"] = [
                input_embeds,
                input_embeds_mask,
            ]  # Input sequence and mask

        # If the following two keys are not in cond_stage_key, then they will not be used as condition
        ret_dict["film_clap_cond1"] = cond_dict[
            "film_clap_cond1"
        ]  # the clap target latent
        ret_dict["crossattn_audiomae_pooled"] = cond_dict[
            "crossattn_audiomae_pooled"
        ]  # audiomae target latent

        if self.learnable and self.training:
            loss = self.training_step(batch, cond_dict=cond_dict)
            ret_dict["noncond_loss_clap2audiomae"] = loss

        return ret_dict
484
+
485
+
486
class SequenceGenAudioMAECond(Sequence2AudioMAE):
    """Conditioning wrapper around ``Sequence2AudioMAE``.

    ``forward`` returns a dict whose ``"crossattn_audiomae_generated"`` entry
    is either the ground-truth pooled AudioMAE feature or a generated
    sequence, plus every entry of the underlying condition dict.

    Args:
        always_output_audiomae_gt: always use the ground-truth feature.
        pretrained_path: optional checkpoint whose ``"state_dict"`` entry is
            loaded into this module.
        force_reload_pretrain_avoid_overwrite: when true, the checkpoint is
            loaded again on the first ``forward`` call.
        learnable: when false, ``self.model`` is frozen and the module is put
            in eval mode.
        use_warmup: always overridden to ``False`` (a warning is printed).
        use_gt_mae_output: whether the ground-truth feature may be used.
            Required.
        use_gt_mae_prob: probability of using the ground truth. Required.

    The remaining arguments are forwarded to the parent constructor.
    """

    def __init__(
        self,
        cond_stage_config,
        base_learning_rate,
        sequence_gen_length,
        sequence_input_key,
        sequence_input_embed_dim,
        batchsize,
        always_output_audiomae_gt=False,
        pretrained_path=None,
        force_reload_pretrain_avoid_overwrite=False,
        learnable=True,
        use_warmup=True,
        use_gt_mae_output=None,  # False: does not use AudioMAE GT, True: Use AudioMAE GT
        use_gt_mae_prob=None,  # The prob of using AudioMAE GT
    ):
        if use_warmup:
            print(
                "Warning: You didn't initialize sequence prediction module with trainer. Set warmup to False. You can still use the warmup scheme from the latent diffusion model."
            )
            use_warmup = False

        super().__init__(
            base_learning_rate=base_learning_rate,
            cond_stage_config=cond_stage_config,
            sequence_gen_length=sequence_gen_length,
            sequence_input_key=sequence_input_key,
            use_warmup=use_warmup,
            sequence_input_embed_dim=sequence_input_embed_dim,
            batchsize=batchsize,
        )

        assert use_gt_mae_output is not None and use_gt_mae_prob is not None
        self.always_output_audiomae_gt = always_output_audiomae_gt
        self.force_reload_pretrain_avoid_overwrite = (
            force_reload_pretrain_avoid_overwrite
        )
        self.pretrained_path = pretrained_path
        # False means forward() still has to reload the checkpoint once.
        self.is_reload = not self.force_reload_pretrain_avoid_overwrite

        self.load_pretrain_model()

        self.use_gt_mae_output = use_gt_mae_output
        self.use_gt_mae_prob = use_gt_mae_prob
        self.learnable = learnable

        if not learnable:
            # Freeze self.model and switch the whole module to eval mode.
            for p in self.model.parameters():
                p.requires_grad = False
            self.eval()

    def load_pretrain_model(self):
        """Load ``pretrained_path`` into this module, if a path was given."""
        if self.pretrained_path is not None:
            print("Reload SequenceGenAudioMAECond from %s" % self.pretrained_path)
            # Load on CPU so a checkpoint saved on an accelerator can still be
            # read on a machine without one; load_state_dict copies the values
            # into the existing parameters.
            state_dict = torch.load(self.pretrained_path, map_location="cpu")[
                "state_dict"
            ]
            self.load_state_dict(state_dict)

    # Required
    def get_unconditional_condition(self, batchsize):
        """Return ``cfg_uncond(batchsize)`` with a ``crossattn_audiomae_generated`` entry added.

        The added entry reuses the pooled AudioMAE feature with an all-ones mask.
        """
        return_dict = self.cfg_uncond(batchsize)
        return_dict["crossattn_audiomae_generated"] = [
            return_dict["crossattn_audiomae_pooled"][0],
            torch.ones_like(return_dict["crossattn_audiomae_pooled"][1]).float(),
        ]
        return return_dict

    def forward(self, batch):
        """Build the condition dict for ``batch``.

        The conditional module can return both tensors and dictionaries; the
        keys of a returned dict correspond to the cond_stage_key.
        """
        ret_dict = {}

        if self.force_reload_pretrain_avoid_overwrite and not self.is_reload:
            self.load_pretrain_model()
            self.is_reload = True

        self.check_module_param_update()

        if self.always_output_audiomae_gt or (
            self.use_gt_mae_output and torch.rand(1).item() < self.use_gt_mae_prob
        ):
            cond_dict = self.get_input(batch)
            ret_dict["crossattn_audiomae_generated"] = [
                cond_dict["crossattn_audiomae_pooled"][0],
                torch.ones_like(cond_dict["crossattn_audiomae_pooled"][1]).float(),
            ]  # Input sequence and mask
        else:
            if not self.training:
                print("--------------> Generate !!!!!!!!!!!!")
            input_embeds, cond_dict = self.generate(batch)
            input_embeds_mask = (
                torch.ones((input_embeds.size(0), input_embeds.size(1)))
                .to(input_embeds.device)
                .float()
            )
            ret_dict["crossattn_audiomae_generated"] = [
                input_embeds,
                input_embeds_mask,
            ]  # Input sequence and mask

        # Keys that are not in cond_stage_key will not be used as condition.
        for key in cond_dict.keys():
            ret_dict[key] = cond_dict[key]

        if self.learnable and self.training:
            loss = self.training_step(batch, cond_dict=cond_dict)
            ret_dict["noncond_loss_clap2audiomae"] = loss

        return ret_dict
603
+
604
+
605
class SequenceGenAudioMAECond_AudioMAE_PostNet(Sequence2AudioMAE):
    """``Sequence2AudioMAE`` conditioner that passes its output through a ``Prenet``.

    ``forward`` returns a dict whose ``"crossattn_audiomae_generated"`` entry
    is the prenet output of either the ground-truth pooled AudioMAE feature or
    a generated sequence, plus every entry of the underlying condition dict.

    Args:
        always_output_audiomae_gt: always use the ground-truth feature.
        pretrained_path: optional checkpoint whose ``"state_dict"`` entry is
            loaded into this module.
        use_ar_gen_loss: forwarded to the parent constructor.
        force_reload_pretrain_avoid_overwrite: when true, the checkpoint is
            loaded again on the first ``forward`` call.
        learnable: when false, ``self.model`` is frozen and the module is put
            in eval mode.
        use_warmup: always overridden to ``False`` (a warning is printed).
        use_gt_mae_output: whether the ground-truth feature may be used.
            Required.
        use_gt_mae_prob: probability of using the ground truth. Required.

    The remaining arguments are forwarded to the parent constructor.
    """

    def __init__(
        self,
        cond_stage_config,
        base_learning_rate,
        sequence_gen_length,
        sequence_input_key,
        sequence_input_embed_dim,
        batchsize,
        always_output_audiomae_gt=False,
        pretrained_path=None,
        use_ar_gen_loss=False,
        force_reload_pretrain_avoid_overwrite=False,
        learnable=True,
        use_warmup=True,
        use_gt_mae_output=None,  # False: does not use AudioMAE GT, True: Use AudioMAE GT
        use_gt_mae_prob=None,  # The prob of using AudioMAE GT
    ):
        if use_warmup:
            print(
                "Warning: You didn't initialize sequence prediction module with trainer. Set warmup to False. You can still use the warmup scheme from the latent diffusion model."
            )
            use_warmup = False

        super().__init__(
            base_learning_rate=base_learning_rate,
            cond_stage_config=cond_stage_config,
            sequence_gen_length=sequence_gen_length,
            sequence_input_key=sequence_input_key,
            use_ar_gen_loss=use_ar_gen_loss,
            use_warmup=use_warmup,
            sequence_input_embed_dim=sequence_input_embed_dim,
            batchsize=batchsize,
        )

        assert use_gt_mae_output is not None and use_gt_mae_prob is not None
        self.always_output_audiomae_gt = always_output_audiomae_gt
        self.force_reload_pretrain_avoid_overwrite = (
            force_reload_pretrain_avoid_overwrite
        )
        self.pretrained_path = pretrained_path
        # False means forward() still has to reload the checkpoint once.
        self.is_reload = not self.force_reload_pretrain_avoid_overwrite

        # NOTE(review): this load runs before self.prenet exists, so prenet
        # weights can only be restored by the reload in forward() - confirm
        # that this ordering is intended.
        self.load_pretrain_model()

        self.prenet = Prenet(in_dim=768, sizes=[768, 768, 768], dropout_rate=0.5)

        self.use_gt_mae_output = use_gt_mae_output
        self.use_gt_mae_prob = use_gt_mae_prob
        self.learnable = learnable

        if not learnable:
            # Freeze self.model and switch the whole module to eval mode.
            for p in self.model.parameters():
                p.requires_grad = False
            self.eval()

    def load_pretrain_model(self):
        """Load ``pretrained_path`` into this module, if a path was given."""
        if self.pretrained_path is not None:
            print("Reload SequenceGenAudioMAECond from %s" % self.pretrained_path)
            # Load on CPU so a checkpoint saved on an accelerator can still be
            # read on a machine without one; load_state_dict copies the values
            # into the existing parameters.
            state_dict = torch.load(self.pretrained_path, map_location="cpu")[
                "state_dict"
            ]
            self.load_state_dict(state_dict)

    # Required
    def get_unconditional_condition(self, batchsize):
        """Return ``cfg_uncond(batchsize)`` with a ``crossattn_audiomae_generated`` entry added.

        The added entry reuses the pooled AudioMAE feature with an all-ones mask.
        """
        return_dict = self.cfg_uncond(batchsize)
        return_dict["crossattn_audiomae_generated"] = [
            return_dict["crossattn_audiomae_pooled"][0],
            torch.ones_like(return_dict["crossattn_audiomae_pooled"][1]).float(),
        ]
        return return_dict

    def forward(self, batch):
        """Build the condition dict for ``batch``.

        The conditional module can return both tensors and dictionaries; the
        keys of a returned dict correspond to the cond_stage_key.
        """
        ret_dict = {}

        if self.force_reload_pretrain_avoid_overwrite and not self.is_reload:
            self.load_pretrain_model()
            self.is_reload = True

        self.check_module_param_update()

        if self.always_output_audiomae_gt or (
            self.use_gt_mae_output and torch.rand(1).item() < self.use_gt_mae_prob
        ):
            cond_dict = self.get_input(batch)
            gt_audiomae = self.prenet(cond_dict["crossattn_audiomae_pooled"][0])
            ret_dict["crossattn_audiomae_generated"] = [
                gt_audiomae,
                torch.ones_like(cond_dict["crossattn_audiomae_pooled"][1]).float(),
            ]  # Input sequence and mask
        else:
            print("--------------> Generate!!!!!!!!!!!!")
            input_embeds, cond_dict = self.generate(batch)
            input_embeds = self.prenet(input_embeds)
            input_embeds_mask = (
                torch.ones((input_embeds.size(0), input_embeds.size(1)))
                .to(input_embeds.device)
                .float()
            )
            ret_dict["crossattn_audiomae_generated"] = [
                input_embeds,
                input_embeds_mask,
            ]  # Input sequence and mask

        # Keys that are not in cond_stage_key will not be used as condition.
        for key in cond_dict.keys():
            ret_dict[key] = cond_dict[key]

        if self.learnable and self.training:
            loss = self.training_step(batch, cond_dict=cond_dict)
            ret_dict["noncond_loss_clap2audiomae"] = loss

        return ret_dict
725
+
726
+
727
class AudioMAEConditionCTPoolRandTFSeparated(nn.Module):
    """Frozen AudioMAE features pooled over time and frequency.

    The AudioMAE output, with its first token dropped, is viewed as a 64 x 8
    (time x frequency) grid of 768-dimensional embeddings and reduced with
    the mean of average pooling and max pooling. In training mode the time
    and frequency pooling factors are sampled independently from
    ``time_pooling_factors`` and ``freq_pooling_factors``; in eval mode
    ``eval_time_pooling`` and ``eval_freq_pooling`` are always used.
    """

    def __init__(
        self,
        time_pooling_factors=[1, 2, 4, 8],
        freq_pooling_factors=[1, 2, 4, 8],
        eval_time_pooling=None,
        eval_freq_pooling=None,
        mask_ratio=0.0,
        regularization=False,
        no_audiomae_mask=True,
        no_audiomae_average=False,
    ):
        super().__init__()
        self.device = None
        # Copy the factor lists so the shared default objects are never
        # aliased between instances.
        self.time_pooling_factors = list(time_pooling_factors)
        self.freq_pooling_factors = list(freq_pooling_factors)
        self.no_audiomae_mask = no_audiomae_mask
        self.no_audiomae_average = no_audiomae_average

        self.eval_freq_pooling = eval_freq_pooling
        self.eval_time_pooling = eval_time_pooling
        self.mask_ratio = mask_ratio
        self.use_reg = regularization

        # The encoder is a fixed feature extractor: eval mode, no gradients.
        self.audiomae = Vanilla_AudioMAE()
        self.audiomae.eval()
        for p in self.audiomae.parameters():
            p.requires_grad = False

    @staticmethod
    def _sample_factor(factors):
        """Draw one pooling factor uniformly at random from ``factors``."""
        return factors[np.random.choice(len(factors))]

    def _eval_pool_sizes(self):
        """Return the eval (time, freq) pooling factors, clamped to the grid size."""
        if self.eval_time_pooling is None or self.eval_freq_pooling is None:
            raise ValueError(
                "eval_time_pooling and eval_freq_pooling must be set to pool "
                "outside training or to build the unconditional condition"
            )
        return min(self.eval_time_pooling, 64), min(self.eval_freq_pooling, 8)

    # Required
    def get_unconditional_condition(self, batchsize):
        """Return ``[zeros, ones]`` shaped like the eval-mode output of ``forward``."""
        param = next(self.audiomae.parameters())
        assert not param.requires_grad
        device = param.device
        time_pool, freq_pool = self._eval_pool_sizes()
        # Pooling floors each axis separately, so count tokens per axis; this
        # matches ``pool`` even when a factor does not divide 64 or 8.
        token_num = (64 // time_pool) * (8 // freq_pool)
        return [
            torch.zeros((batchsize, token_num, 768), device=device, dtype=torch.float32),
            torch.ones((batchsize, token_num), device=device, dtype=torch.float32),
        ]

    def pool(self, representation, time_pool=None, freq_pool=None):
        """Pool AudioMAE tokens and return ``[bs, token_num, embedding_dim]``.

        In training mode any factor left as ``None`` is sampled from the
        configured factor list; explicitly passed factors are used as given.
        In eval mode the arguments are ignored in favour of the eval factors.
        """
        assert representation.size(-1) == 768
        # Drop the first token and move the embedding axis in front of the grid.
        representation = representation[:, 1:, :].transpose(1, 2)
        bs, embedding_dim, _ = representation.size()
        representation = representation.reshape(bs, embedding_dim, 64, 8)

        if self.training:
            if time_pool is None:
                time_pool = min(64, self._sample_factor(self.time_pooling_factors))
            if freq_pool is None:
                freq_pool = min(8, self._sample_factor(self.freq_pooling_factors))
        else:
            time_pool, freq_pool = self._eval_pool_sizes()

        # Functional pooling avoids building and registering new sub-modules
        # on every call.
        kernel = (time_pool, freq_pool)
        pooled = (
            F.avg_pool2d(representation, kernel_size=kernel, stride=kernel)
            + F.max_pool2d(representation, kernel_size=kernel, stride=kernel)
        ) / 2  # [bs, embedding_dim, time_token_num, freq_token_num]
        return pooled.flatten(2).transpose(1, 2)  # [bs, token_num, embedding_dim]

    def regularization(self, x):
        """L2-normalize every 768-dimensional token embedding."""
        assert x.size(-1) == 768
        return F.normalize(x, p=2, dim=-1)

    # Required
    def forward(self, batch, time_pool=None, freq_pool=None):
        """Encode ``batch`` (last two dims 1024 x 128) and return ``[tokens, mask]``.

        ``mask`` is an all-ones float tensor of shape ``[bs, token_num]``.
        """
        assert batch.size(-2) == 1024 and batch.size(-1) == 128

        if self.device is None:
            self.device = batch.device

        batch = batch.unsqueeze(1)
        with torch.no_grad():
            representation = self.audiomae(
                batch,
                mask_ratio=self.mask_ratio,
                no_mask=self.no_audiomae_mask,
                no_average=self.no_audiomae_average,
            )
            representation = self.pool(representation, time_pool, freq_pool)
            if self.use_reg:
                representation = self.regularization(representation)
        mask = torch.ones(
            (representation.size(0), representation.size(1)),
            device=representation.device,
            dtype=torch.float32,
        )
        return [representation, mask]
849
+
850
+
851
class AudioMAEConditionCTPoolRand(nn.Module):
    """Frozen AudioMAE features pooled with one random factor shared by both axes.

    The AudioMAE output, with its first token dropped, is viewed as a 64 x 8
    (time x frequency) grid of 768-dimensional embeddings and reduced with
    the mean of average pooling and max pooling. In training mode the time
    factor is sampled from ``time_pooling_factors`` and the frequency factor
    follows it (capped at 8); in eval mode ``eval_time_pooling`` and
    ``eval_freq_pooling`` are always used.
    """

    def __init__(
        self,
        time_pooling_factors=[1, 2, 4, 8],
        freq_pooling_factors=[1, 2, 4, 8],
        eval_time_pooling=None,
        eval_freq_pooling=None,
        mask_ratio=0.0,
        regularization=False,
        no_audiomae_mask=True,
        no_audiomae_average=False,
    ):
        super().__init__()
        self.device = None
        # Copy the factor lists so the shared default objects are never
        # aliased between instances.
        self.time_pooling_factors = list(time_pooling_factors)
        # Stored for configuration parity; this class derives the frequency
        # factor from the time factor instead of sampling it.
        self.freq_pooling_factors = list(freq_pooling_factors)
        self.no_audiomae_mask = no_audiomae_mask
        self.no_audiomae_average = no_audiomae_average

        self.eval_freq_pooling = eval_freq_pooling
        self.eval_time_pooling = eval_time_pooling
        self.mask_ratio = mask_ratio
        self.use_reg = regularization

        # The encoder is a fixed feature extractor: eval mode, no gradients.
        self.audiomae = Vanilla_AudioMAE()
        self.audiomae.eval()
        for p in self.audiomae.parameters():
            p.requires_grad = False

    def _eval_pool_sizes(self):
        """Return the eval (time, freq) pooling factors, clamped to the grid size."""
        if self.eval_time_pooling is None or self.eval_freq_pooling is None:
            raise ValueError(
                "eval_time_pooling and eval_freq_pooling must be set to pool "
                "outside training or to build the unconditional condition"
            )
        return min(self.eval_time_pooling, 64), min(self.eval_freq_pooling, 8)

    # Required
    def get_unconditional_condition(self, batchsize):
        """Return ``[zeros, ones]`` shaped like the eval-mode output of ``forward``."""
        param = next(self.audiomae.parameters())
        assert not param.requires_grad
        device = param.device
        time_pool, freq_pool = self._eval_pool_sizes()
        # Pooling floors each axis separately, so count tokens per axis; this
        # matches ``pool`` even when a factor does not divide 64 or 8.
        token_num = (64 // time_pool) * (8 // freq_pool)
        return [
            torch.zeros((batchsize, token_num, 768), device=device, dtype=torch.float32),
            torch.ones((batchsize, token_num), device=device, dtype=torch.float32),
        ]

    def pool(self, representation, time_pool=None, freq_pool=None):
        """Pool AudioMAE tokens and return ``[bs, token_num, embedding_dim]``.

        In training mode a missing ``time_pool`` is sampled from
        ``time_pooling_factors`` and a missing ``freq_pool`` is set to
        ``min(8, time_pool)``; explicitly passed factors are used as given.
        In eval mode the arguments are ignored in favour of the eval factors.
        """
        assert representation.size(-1) == 768
        # Drop the first token and move the embedding axis in front of the grid.
        representation = representation[:, 1:, :].transpose(1, 2)
        bs, embedding_dim, _ = representation.size()
        representation = representation.reshape(bs, embedding_dim, 64, 8)

        if self.training:
            if time_pool is None:
                factors = self.time_pooling_factors
                time_pool = min(64, factors[np.random.choice(len(factors))])
            if freq_pool is None:
                # The frequency factor is tied to the time factor.
                freq_pool = min(8, time_pool)
        else:
            time_pool, freq_pool = self._eval_pool_sizes()

        # Functional pooling avoids building and registering new sub-modules
        # on every call.
        kernel = (time_pool, freq_pool)
        pooled = (
            F.avg_pool2d(representation, kernel_size=kernel, stride=kernel)
            + F.max_pool2d(representation, kernel_size=kernel, stride=kernel)
        ) / 2  # [bs, embedding_dim, time_token_num, freq_token_num]
        return pooled.flatten(2).transpose(1, 2)  # [bs, token_num, embedding_dim]

    def regularization(self, x):
        """L2-normalize every 768-dimensional token embedding."""
        assert x.size(-1) == 768
        return F.normalize(x, p=2, dim=-1)

    # Required
    def forward(self, batch, time_pool=None, freq_pool=None):
        """Encode ``batch`` (last two dims 1024 x 128) and return ``[tokens, mask]``.

        ``mask`` is an all-ones float tensor of shape ``[bs, token_num]``.
        """
        assert batch.size(-2) == 1024 and batch.size(-1) == 128

        if self.device is None:
            self.device = batch.device

        batch = batch.unsqueeze(1)
        with torch.no_grad():
            representation = self.audiomae(
                batch,
                mask_ratio=self.mask_ratio,
                no_mask=self.no_audiomae_mask,
                no_average=self.no_audiomae_average,
            )
            representation = self.pool(representation, time_pool, freq_pool)
            if self.use_reg:
                representation = self.regularization(representation)
        mask = torch.ones(
            (representation.size(0), representation.size(1)),
            device=representation.device,
            dtype=torch.float32,
        )
        return [representation, mask]
968
+
969
+
970
class ConditionalToken(nn.Module):
    """Lookup of fixed, non-learned tokens that identify a pooling factor.

    Each supported factor (1, 2, 4, 8, 16, 32, 64) maps to a vector built by
    repeating a two-value pattern ``embedding_dim // 2`` times.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        repeats = embedding_dim // 2
        # The conditional tokens are fixed values, one pattern per factor.
        patterns = {
            1: [1.0, 0.0],
            2: [0.0, 1.0],
            4: [1.0, 1.0],
            8: [-1.0, 0.0],
            16: [0.0, -1.0],
            32: [-1.0, -1.0],
            64: [0.0, 0.0],
        }
        self.pooling_factor_tokens = {
            factor: torch.Tensor(pair * repeats) for factor, pair in patterns.items()
        }
        # Freeze whatever parameters the module holds.
        self.requires_grad_(False)

    def forward(self, condition, batchsize):
        """
        Returns the conditional token for the given condition, expanded to
        shape [batchsize, 1, embedding_dim]. Raises ValueError for a
        condition that has no token.
        """
        if condition not in self.pooling_factor_tokens:
            raise ValueError(f"Unsupported condition: {condition}")
        token = self.pooling_factor_tokens[condition]
        return token.unsqueeze(0).unsqueeze(0).expand(batchsize, 1, self.embedding_dim)
997
+
998
+
999
class AudioMAEConditionCTPoolRandV2(nn.Module):
    """Frozen AudioMAE features, randomly pooled, plus a pooling-factor token.

    The AudioMAE output, with its first token dropped, is viewed as a 64 x 8
    (time x frequency) grid of 768-dimensional embeddings and reduced with
    the mean of average pooling and max pooling. A fixed token identifying
    the time pooling factor is appended to the pooled sequence, so the
    returned sequences are one token longer than the pooled grid.
    """

    def __init__(
        self,
        time_pooling_factors=[1, 2, 4, 8],
        freq_pooling_factors=[1, 2, 4, 8],
        eval_time_pooling=None,
        eval_freq_pooling=None,
        mask_ratio=0.0,
        regularization=False,
        no_audiomae_mask=True,
        no_audiomae_average=False,
    ):
        super().__init__()
        self.device = None
        # Copy the factor lists so the shared default objects are never
        # aliased between instances.
        self.time_pooling_factors = list(time_pooling_factors)
        # Stored for configuration parity; this class derives the frequency
        # factor from the time factor instead of sampling it.
        self.freq_pooling_factors = list(freq_pooling_factors)
        self.no_audiomae_mask = no_audiomae_mask
        self.no_audiomae_average = no_audiomae_average

        self.eval_freq_pooling = eval_freq_pooling
        self.eval_time_pooling = eval_time_pooling
        self.mask_ratio = mask_ratio
        self.use_reg = regularization

        self.pooling_tokens = ConditionalToken(768)

        # The encoder is a fixed feature extractor: eval mode, no gradients.
        self.audiomae = Vanilla_AudioMAE()
        self.audiomae.eval()

        for p in self.audiomae.parameters():
            p.requires_grad = False

    def _eval_pool_sizes(self):
        """Return the eval (time, freq) pooling factors, clamped to the grid size."""
        if self.eval_time_pooling is None or self.eval_freq_pooling is None:
            raise ValueError(
                "eval_time_pooling and eval_freq_pooling must be set to pool "
                "outside training or to build the unconditional condition"
            )
        return min(self.eval_time_pooling, 64), min(self.eval_freq_pooling, 8)

    # Required
    def get_unconditional_condition(self, batchsize):
        """Return ``[sequence, ones]`` shaped like the eval-mode output of ``forward``.

        The sequence is all zeros except for the trailing pooling-factor token.
        """
        param = next(self.audiomae.parameters())
        assert not param.requires_grad
        device = param.device
        time_pool, freq_pool = self._eval_pool_sizes()
        pool_condition_token = self.pooling_tokens(time_pool, batchsize).to(device)
        # Pooling floors each axis separately, so count tokens per axis; this
        # matches ``pool`` even when a factor does not divide 64 or 8.
        token_num = (64 // time_pool) * (8 // freq_pool)

        rep = torch.zeros((batchsize, token_num, 768), device=device, dtype=torch.float32)
        rep = torch.cat([rep, pool_condition_token], dim=1)

        mask = torch.ones((batchsize, token_num + 1), device=device, dtype=torch.float32)
        return [rep, mask]

    def pool(self, representation, time_pool=None, freq_pool=None):
        """Pool AudioMAE tokens; return ``(pooled, time_pool, freq_pool)``.

        ``pooled`` has shape ``[bs, token_num, embedding_dim]``. In training
        mode a missing ``time_pool`` is sampled from ``time_pooling_factors``
        and a missing ``freq_pool`` is set to ``min(8, time_pool)``;
        explicitly passed factors are used as given. In eval mode the
        arguments are ignored in favour of the eval factors.
        """
        assert representation.size(-1) == 768
        # Drop the first token and move the embedding axis in front of the grid.
        representation = representation[:, 1:, :].transpose(1, 2)
        bs, embedding_dim, _ = representation.size()
        representation = representation.reshape(bs, embedding_dim, 64, 8)

        if self.training:
            if time_pool is None:
                factors = self.time_pooling_factors
                time_pool = min(64, factors[np.random.choice(len(factors))])
            if freq_pool is None:
                # The frequency factor is tied to the time factor.
                freq_pool = min(8, time_pool)
        else:
            time_pool, freq_pool = self._eval_pool_sizes()

        # Functional pooling avoids building and registering new sub-modules
        # on every call.
        kernel = (time_pool, freq_pool)
        pooled = (
            F.avg_pool2d(representation, kernel_size=kernel, stride=kernel)
            + F.max_pool2d(representation, kernel_size=kernel, stride=kernel)
        ) / 2  # [bs, embedding_dim, time_token_num, freq_token_num]
        pooled = pooled.flatten(2).transpose(1, 2)
        return pooled, time_pool, freq_pool  # [bs, token_num, embedding_dim]

    def regularization(self, x):
        """L2-normalize every 768-dimensional token embedding."""
        assert x.size(-1) == 768
        return F.normalize(x, p=2, dim=-1)

    # Required
    def forward(self, batch):
        """Encode ``batch`` (last two dims 1024 x 128) and return ``[tokens, mask]``.

        ``tokens`` is the pooled sequence followed by the pooling-factor
        token; ``mask`` is an all-ones float tensor of shape ``[bs, token_num]``.
        """
        assert batch.size(-2) == 1024 and batch.size(-1) == 128

        if self.device is None:
            self.device = batch.device

        batch = batch.unsqueeze(1)

        with torch.no_grad():
            representation = self.audiomae(
                batch,
                mask_ratio=self.mask_ratio,
                no_mask=self.no_audiomae_mask,
                no_average=self.no_audiomae_average,
            )
            representation, time_pool, freq_pool = self.pool(representation)
            if self.use_reg:
                representation = self.regularization(representation)
            pool_condition_token = self.pooling_tokens(
                time_pool, representation.size(0)
            ).to(representation.device)
            representation = torch.cat([representation, pool_condition_token], dim=1)

        mask = torch.ones(
            (representation.size(0), representation.size(1)),
            device=representation.device,
            dtype=torch.float32,
        )
        return [representation, mask]
1126
+
1127
+
1128
class BeatDownbeatConditionConcat(nn.Module):
    """Pass-through conditioner: the input batch is used as the condition as-is.

    The module only remembers the device of the first batch it sees, so the
    unconditional (all-zero) condition can be placed on that same device.
    """

    def __init__(self, latent_t_size, latent_f_size):
        super().__init__()
        self.device = None
        self.latent_f_size = latent_f_size
        self.latent_t_size = latent_t_size

    # Required
    def get_unconditional_condition(self, batchsize):
        """Return zeros of shape [batchsize, latent_t_size, latent_f_size]."""
        shape = (batchsize, self.latent_t_size, self.latent_f_size)
        placeholder = torch.zeros(shape)
        return placeholder.to(self.device)

    # Required
    def forward(self, batch):
        """Record the batch device on first use and return ``batch`` unchanged."""
        first_call = self.device is None
        if first_call:
            self.device = batch.device
        return batch
1146
+
1147
+
1148
class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
    """Frozen CLAP encoder that turns audio or text into one conditioning embedding.

    ``embed_mode`` selects the branch used by ``forward``: ``"audio"``
    embeds waveforms, ``"text"`` embeds captions. With probability
    ``unconditional_prob`` each item's embedding is replaced by the embedding
    of the empty string, which is also what ``get_unconditional_condition``
    returns.
    """

    def __init__(
        self,
        pretrained_path,
        sampling_rate=16000,
        embed_mode="audio",
        amodel="HTSAT-base",
        unconditional_prob=0.1,
        random_mute=False,
        max_random_mute_portion=0.5,
        training_mode=True,
    ):
        super().__init__()
        self.device = "cpu"
        self.precision = "fp32"
        self.amodel = amodel  # or 'PANN-14'
        self.tmodel = "roberta"  # the best text encoder in our training
        self.enable_fusion = False  # False if you do not want to use the fusion model
        self.fusion_type = "aff_2d"
        self.pretrained = pretrained_path
        self.embed_mode = embed_mode
        self.embed_mode_orig = embed_mode
        self.sampling_rate = sampling_rate
        self.unconditional_prob = unconditional_prob
        self.random_mute = random_mute
        self.tokenize = RobertaTokenizer.from_pretrained(config_data["roberta-base"])
        self.max_random_mute_portion = max_random_mute_portion
        self.training_mode = training_mode
        self.model, self.model_cfg = create_model(
            self.amodel,
            self.tmodel,
            self.pretrained,
            precision=self.precision,
            device=self.device,
            enable_fusion=self.enable_fusion,
            fusion_type=self.fusion_type,
        )
        audio_cfg = self.model_cfg["audio_cfg"]
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=audio_cfg["sample_rate"],
            n_fft=audio_cfg["window_size"],
            win_length=audio_cfg["window_size"],
            hop_length=audio_cfg["hop_size"],
            center=True,
            pad_mode="reflect",
            power=2.0,
            norm=None,
            onesided=True,
            n_mels=64,
            f_min=audio_cfg["fmin"],
            f_max=audio_cfg["fmax"],
        )
        # CLAP is used as a fixed feature extractor.
        for p in self.model.parameters():
            p.requires_grad = False
        self.unconditional_token = None
        self.model.eval()

    def get_unconditional_condition(self, batchsize):
        """Return the empty-string text embedding repeated ``batchsize`` times."""
        self.build_unconditional_emb()
        return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)

    def batch_to_list(self, batch):
        """Split ``batch`` along its first dimension into a list of items."""
        return [batch[i] for i in range(batch.size(0))]

    def make_decision(self, probability):
        """Return True with (approximately) the given probability."""
        return float(torch.rand(1)) < probability

    def random_uniform(self, start, end):
        """Draw a float uniformly from ``[start, end)``."""
        val = torch.rand(1).item()
        return start + (end - start) * val

    def _random_mute(self, waveform):
        """Zero one random span per item, in place, and return ``waveform``.

        The span length is at most ``max_random_mute_portion`` of the signal.
        """
        # waveform: [bs, t-steps]
        t_steps = waveform.size(-1)
        for i in range(waveform.size(0)):
            mute_size = int(
                self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
            )
            mute_start = int(self.random_uniform(0, t_steps - mute_size))
            waveform[i, mute_start : mute_start + mute_size] = 0
        return waveform

    def cos_similarity(self, waveform, text):
        """Return the cosine similarity between audio and text embeddings.

        ``self.embed_mode`` is switched temporarily and always restored, even
        when one of the embedding passes raises.
        """
        # waveform: [bs, t_steps]
        original_embed_mode = self.embed_mode
        # Follow the device the CLAP weights live on instead of assuming CUDA.
        device = next(self.model.parameters()).device
        try:
            with torch.no_grad():
                self.embed_mode = "audio"
                audio_emb = self(waveform.to(device))
                self.embed_mode = "text"
                text_emb = self(text)
                similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
        finally:
            self.embed_mode = original_embed_mode
        return similarity.squeeze()

    def build_unconditional_emb(self):
        """Cache the text embedding of the empty string as the unconditional token."""
        # Two empty strings are tokenized and only the first embedding is kept.
        self.unconditional_token = self.model.get_text_embedding(
            self.tokenizer(["", ""])
        )[0:1]

    def forward(self, batch):
        """Embed ``batch`` according to ``self.embed_mode``.

        Returns a detached tensor with a singleton axis inserted at dim 1.
        Raises ValueError when ``self.embed_mode`` is neither ``"audio"`` nor
        ``"text"``.
        """
        # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0
        # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0
        if self.model.training and not self.training_mode:
            print(
                "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters."
            )
            # NOTE(review): the reload hard-codes device="cuda"; confirm
            # whether CPU-only use needs to be supported here.
            self.model, self.model_cfg = create_model(
                self.amodel,
                self.tmodel,
                self.pretrained,
                precision=self.precision,
                device="cuda",
                enable_fusion=self.enable_fusion,
                fusion_type=self.fusion_type,
            )
            for p in self.model.parameters():
                p.requires_grad = False
            self.model.eval()

        if self.unconditional_token is None:
            self.build_unconditional_emb()

        # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
        if self.embed_mode == "audio":
            if not self.training:
                print("INFO: clap model calculate the audio embedding as condition")
            with torch.no_grad():
                # batch: [bs, 1, t-samples]
                if self.sampling_rate != 48000:
                    batch = torchaudio.functional.resample(
                        batch, orig_freq=self.sampling_rate, new_freq=48000
                    )

                audio_data = batch.squeeze(1)
                mel = self.mel_transform(audio_data)
                audio_dict = get_audio_features(
                    audio_data,
                    mel,
                    480000,
                    data_truncating="fusion",
                    data_filling="repeatpad",
                    audio_cfg=self.model_cfg["audio_cfg"],
                )
                # [bs, 512]
                embed = self.model.get_audio_embedding(audio_dict)
        elif self.embed_mode == "text":
            with torch.no_grad():
                text_data = self.tokenizer(batch)

                # ``tokenizer`` squeezes away the batch axis for a single
                # caption, so restore it before embedding.
                if isinstance(batch, str) or (
                    isinstance(batch, list) and len(batch) == 1
                ):
                    for key in text_data.keys():
                        text_data[key] = text_data[key].unsqueeze(0)

                embed = self.model.get_text_embedding(text_data)
        else:
            raise ValueError(f"Unsupported embed_mode: {self.embed_mode!r}")

        embed = embed.unsqueeze(1)
        # Classifier-free guidance dropout: swap items for the unconditional token.
        for i in range(embed.size(0)):
            if self.make_decision(self.unconditional_prob):
                embed[i] = self.unconditional_token
        return embed.detach()

    def tokenizer(self, text):
        """Tokenize ``text``, padded/truncated to 512 tokens, as PyTorch tensors.

        The leading axis of every returned tensor is squeezed when it has size 1.
        """
        result = self.tokenize(
            text,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        return {k: v.squeeze(0) for k, v in result.items()}
1341
+
1342
+
1343
if __name__ == "__main__":
    # Manual smoke test: embed two captions with the text branch of CLAP, then
    # drop into the debugger to inspect `res`. The checkpoint path is
    # machine-specific.
    model = CLAPAudioEmbeddingClassifierFreev2(
        embed_mode="text",
        amodel="HTSAT-tiny",
        pretrained_path="/mnt/bn/lqhaoheliu/exps/checkpoints/audioldm/ckpt/CLAP.pt",
    )
    # For the audio branch, use e.g. torch.randn((6, 1, int(16000*10.24))) instead.
    data = ["text"] * 2
    res = model(data)
    import ipdb

    ipdb.set_trace()
audioldm_train/config/mos_as_token/qa_mdt.yaml ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_directory: "./log/latent_diffusion"
2
+ project: "audioldm"
3
+ precision: "high"
4
+
5
+ # TODO: change this with your project path
6
+ base_root: "/content/qa-mdt"
7
+
8
+ # TODO: change this with your pretrained path
9
+ # TODO: pretrained path is also needed in "base_root/offset_pretrained_checkpoints.json"
10
+ pretrained:
11
+ clap_music: "/content/qa-mdt/checkpoints/clap_music"
12
+ flan_t5: "/content/qa-mdt/checkpoints/flant5"
13
+ hifi-gan: "/content/qa-mdt/checkpoints/hifi-gan/checkpoints"
14
+ roberta-base: "/content/qa-mdt/checkpoints/robertabase"
15
+
16
+ # TODO: lmdb dataset that stores pMOS of the training dataset
17
+ # while in inference, we don't need it !!!
18
+ # while in inference, we don't need it !!!
19
+ # while in inference, we don't need it !!!
20
+ mos_path: ""
21
+
22
+ train_path:
23
+ train_lmdb_path: [] # path list of training lmdb folders
24
+
25
+ val_path:
26
+ val_lmdb_path: [] # path list of validation lmdb folders
27
+ val_key_path: [] # path list of validation lmdb key files
28
+
29
+ variables:
30
+ sampling_rate: &sampling_rate 16000
31
+ mel_bins: &mel_bins 64
32
+ latent_embed_dim: &latent_embed_dim 8
33
+ latent_t_size: &latent_t_size 256 # TODO might need to change
34
+ latent_f_size: &latent_f_size 16 # TODO might need to change
35
+ in_channels: &unet_in_channels 8 # TODO might need to change
36
+ optimize_ddpm_parameter: &optimize_ddpm_parameter true
37
+ optimize_gpt: &optimize_gpt true
38
+ warmup_steps: &warmup_steps 2000
39
+
40
+ # we rewrite the dataset so it may not be needed
41
+ data:
42
+ train: ["audiocaps"]
43
+ val: "audiocaps"
44
+ test: "audiocaps"
45
+ class_label_indices: "audioset_eval_subset"
46
+ dataloader_add_ons: ["waveform_rs_48k"]
47
+
48
+ step:
49
+ validation_every_n_epochs: 10000
50
+ save_checkpoint_every_n_steps: 1000
51
+ # limit_val_batches: 2
52
+ max_steps: 8000000
53
+ save_top_k: 1000
54
+
55
+ preprocessing:
56
+ audio:
57
+ sampling_rate: *sampling_rate
58
+ max_wav_value: 32768.0
59
+ duration: 10.24
60
+ stft:
61
+ filter_length: 1024
62
+ hop_length: 160
63
+ win_length: 1024
64
+ mel:
65
+ n_mel_channels: *mel_bins
66
+ mel_fmin: 0
67
+ mel_fmax: 8000
68
+
69
+ augmentation:
70
+ mixup: 0.0
71
+
72
+ model:
73
+ target: audioldm_train.modules.latent_diffusion.ddpm.LatentDiffusion
74
+ params:
75
+ # Autoencoder
76
+ first_stage_config:
77
+ base_learning_rate: 8.0e-06
78
+ target: audioldm_train.modules.latent_encoder.autoencoder.AutoencoderKL
79
+ params:
80
+ # TODO: change it with your VAE checkpoint
81
+ reload_from_ckpt: "/content/qa-mdt/checkpoints/hifi-gan/checkpoints/vae_mel_16k_64bins.ckpt"
82
+ sampling_rate: *sampling_rate
83
+ batchsize: 1
84
+ monitor: val/rec_loss
85
+ image_key: fbank
86
+ subband: 1
87
+ embed_dim: *latent_embed_dim
88
+ time_shuffle: 1
89
+ lossconfig:
90
+ target: audioldm_train.losses.LPIPSWithDiscriminator
91
+ params:
92
+ disc_start: 50001
93
+ kl_weight: 1000.0
94
+ disc_weight: 0.5
95
+ disc_in_channels: 1
96
+ ddconfig:
97
+ double_z: true
98
+ mel_bins: *mel_bins
99
+ z_channels: 8
100
+ resolution: 256
101
+ downsample_time: false
102
+ in_channels: 1
103
+ out_ch: 1
104
+ ch: 128
105
+ ch_mult:
106
+ - 1
107
+ - 2
108
+ - 4
109
+ num_res_blocks: 2
110
+ attn_resolutions: []
111
+ dropout: 0.0
112
+
113
+ # Other parameters
114
+ base_learning_rate: 8.0e-5
115
+ warmup_steps: *warmup_steps
116
+ optimize_ddpm_parameter: *optimize_ddpm_parameter
117
+ sampling_rate: *sampling_rate
118
+ batchsize: 16
119
+ linear_start: 0.0015
120
+ linear_end: 0.0195
121
+ num_timesteps_cond: 1
122
+ log_every_t: 200
123
+ timesteps: 1000
124
+ unconditional_prob_cfg: 0.1
125
+ parameterization: eps # [eps, x0, v]
126
+ first_stage_key: fbank
127
+ latent_t_size: *latent_t_size
128
+ latent_f_size: *latent_f_size
129
+ channels: *latent_embed_dim
130
+ monitor: val/loss_simple_ema
131
+ scale_by_std: true
132
+
133
+ unet_config:
134
+ # TODO: choose your class, Default: MDT_MOS_AS_TOKEN
135
+ # (Note: 2D-RoPE, SwiGLU and MDT are implemented in two separate classes; to train with all of them, the classes must be modified and merged.)
136
+ target: audioldm_train.modules.diffusionmodules.PixArt.PixArt_MDT_MOS_AS_TOKEN
137
+ params:
138
+ input_size : [256, 16]
139
+ # patch_size: [16,4]
140
+ patch_size : [4, 1]
141
+ overlap_size: [0, 0]
142
+ in_channels : 8
143
+ hidden_size : 1152
144
+ depth : 28
145
+ num_heads : 16
146
+ mlp_ratio : 4.0
147
+ class_dropout_prob : 0.1
148
+ pred_sigma : True
149
+ drop_path : 0.
150
+ window_size : 0
151
+ window_block_indexes : None
152
+ use_rel_pos : False
153
+ cond_dim : 1024
154
+ lewei_scale : 1.0
155
+ overlap: [0, 0]
156
+ use_cfg: true
157
+ mask_ratio: 0.30
158
+ decode_layer: 8
159
+
160
+ cond_stage_config:
161
+ crossattn_flan_t5:
162
+ cond_stage_key: text
163
+ conditioning_key: crossattn
164
+ target: audioldm_train.conditional_models.FlanT5HiddenState
165
+
166
+ evaluation_params:
167
+ unconditional_guidance_scale: 3.5
168
+ ddim_sampling_steps: 200
169
+ n_candidates_per_samples: 3
audioldm_train/dataset_plugin.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import torchaudio
5
+ import matplotlib.pyplot as plt
6
+
7
# Static lookup tables shared by the dataloader add-ons below, keyed by the
# name of the add-on that uses them.
CACHE = {
    "get_vits_phoneme_ids": {
        "PAD_LENGTH": 310,
        "_pad": "_",
        "_punctuation": ';:,.!?¡¿—…"«»“” ',
        "_letters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
        "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ",
        "_special": "♪☎☒☝⚠",
    }
}

# Full symbol inventory: the pad symbol first (id 0), then every character of
# each alphabet in declaration order.
CACHE["get_vits_phoneme_ids"]["symbols"] = [
    CACHE["get_vits_phoneme_ids"]["_pad"],
    *CACHE["get_vits_phoneme_ids"]["_punctuation"],
    *CACHE["get_vits_phoneme_ids"]["_letters"],
    *CACHE["get_vits_phoneme_ids"]["_letters_ipa"],
    *CACHE["get_vits_phoneme_ids"]["_special"],
]
# Symbol -> position in the inventory; a repeated symbol keeps its last position.
CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = dict(
    zip(
        CACHE["get_vits_phoneme_ids"]["symbols"],
        range(len(CACHE["get_vits_phoneme_ids"]["symbols"])),
    )
)
28
+
29
+
30
def get_vits_phoneme_ids(config, dl_output, metadata):
    """Convert ``metadata["phonemes"]`` into a fixed-length tensor of symbol ids.

    A blank id (0) is inserted before every symbol id and after the last
    one; the result is then truncated or right-padded with 0 to exactly
    ``PAD_LENGTH`` entries. ``config`` and ``dl_output`` are not used.

    Returns:
        ``{"phoneme_idx": LongTensor of length PAD_LENGTH}``.

    Raises:
        AssertionError: if ``metadata`` has no ``"phonemes"`` entry.
        KeyError: if a symbol is not in the vocabulary.
    """
    pad_token_id = 0
    pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
    symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]

    assert (
        "phonemes" in metadata.keys()
    ), "You must provide vits phonemes on using addon get_vits_phoneme_ids"
    sequence = [symbol_to_id[symbol] for symbol in metadata["phonemes"]]

    # blank, id_0, blank, id_1, ..., id_n, blank
    interleaved = [0] * (2 * len(sequence) + 1)
    interleaved[1::2] = sequence

    # Truncate over-long inputs so that every item has the same length, as
    # get_vits_phoneme_ids_no_padding already does.
    interleaved = interleaved[:pad_length]
    padding = [pad_token_id] * (pad_length - len(interleaved))

    return {"phoneme_idx": torch.LongTensor(interleaved + padding)}
53
+
54
+
55
def get_vits_phoneme_ids_no_padding(config, dl_output, metadata):
    """Convert ``metadata["phonemes"]`` into symbol ids without interleaved blanks.

    The symbol "⚠" is appended to the phoneme string; symbols missing from
    the vocabulary are reported on stdout and replaced by "_". The id list is
    truncated, then right-padded with 0, to ``PAD_LENGTH`` entries.
    ``config`` and ``dl_output`` are not used.
    """
    pad_token_id = 0
    settings = CACHE["get_vits_phoneme_ids"]
    pad_length = settings["PAD_LENGTH"]
    vocabulary = settings["_symbol_to_id"]

    assert (
        "phonemes" in metadata.keys()
    ), "You must provide vits phonemes on using addon get_vits_phoneme_ids"
    clean_text = metadata["phonemes"] + "⚠"

    ids = []
    for symbol in clean_text:
        if symbol not in vocabulary:
            print("%s is not in the vocabulary. %s" % (symbol, clean_text))
            symbol = "_"
        ids.append(vocabulary[symbol])

    ids = ids[:pad_length]
    padding = [pad_token_id] * (pad_length - len(ids))

    return {"phoneme_idx": torch.LongTensor(ids + padding)}
79
+
80
+
81
def calculate_relative_bandwidth(config, dl_output, metadata):
    """Locate the bins holding 5% and 95% of the cumulative STFT energy.

    ``dl_output["stft"]`` is summed over its first dimension and accumulated;
    the bins whose cumulative energy is closest to 5% and 95% of the total
    are expressed as integers relative to the size of the last dimension,
    scaled by 1000. ``config`` and ``metadata`` are not used.

    Returns:
        ``{"freq_energy_percentile": LongTensor([lower, higher])}``.
    """
    assert "stft" in dl_output.keys()
    stft = dl_output["stft"]

    # The last dimension of the stft feature is the frequency dimension
    n_bins = stft.size(-1)

    cumulative_energy = torch.cumsum(torch.sum(stft, dim=0), dim=0)
    total_energy = cumulative_energy[-1]

    def _relative_bin(fraction):
        # Bin whose cumulative energy is closest to the requested share.
        target = total_energy * fraction
        nearest = torch.argmin(torch.abs(target - cumulative_energy))
        return int((nearest / n_bins) * 1000)

    lower_idx = _relative_bin(0.05)
    higher_idx = _relative_bin(0.95)

    return {"freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx])}
101
+
102
+
103
def calculate_mel_spec_relative_bandwidth_as_extra_channel(config, dl_output, metadata):
    """Build a latent-sized band mask from the mel spectrogram's energy spread.

    The log-mel spectrogram is clipped at 10, exponentiated, summed over its
    first dimension and accumulated along the last one. The bins closest to 5%
    and 95% of the total energy are rescaled to ``latent_f_size`` and the
    columns between them are set to 1 in a ``(latent_t_size, latent_f_size)``
    tensor.

    Args:
        config: Must provide ``config["model"]["params"]["latent_t_size"]`` and
            ``["latent_f_size"]``.
        dl_output: Mapping that must contain ``"log_mel_spec"``.
        metadata: Unused here.

    Returns:
        Dict with ``"mel_spec_bandwidth_cond_extra_channel"`` (the mask) and
        ``"freq_energy_percentile"`` (``LongTensor([lower, higher])`` in latent
        frequency units).

    Raises:
        AssertionError: If ``"log_mel_spec"`` is missing from ``dl_output``.
    """
    # The original asserted on "stft", a key this function never reads; check
    # the key that is actually required.
    assert "log_mel_spec" in dl_output.keys()
    linear_mel_spec = torch.exp(torch.clip(dl_output["log_mel_spec"], max=10))

    # The last dimension of the mel spectrogram is the frequency dimension
    freq_dimensions = linear_mel_spec.size(-1)
    freq_energy_dist = torch.sum(linear_mel_spec, dim=0)
    freq_energy_dist = torch.cumsum(freq_energy_dist, dim=0)
    total_energy = freq_energy_dist[-1]

    percentile_5th = total_energy * 0.05
    percentile_95th = total_energy * 0.95

    lower_idx = torch.argmin(torch.abs(percentile_5th - freq_energy_dist))
    higher_idx = torch.argmin(torch.abs(percentile_95th - freq_energy_dist))

    latent_t_size = config["model"]["params"]["latent_t_size"]
    latent_f_size = config["model"]["params"]["latent_f_size"]

    # Rescale mel-bin indices to latent frequency columns.
    lower_idx = int(latent_f_size * float((lower_idx / freq_dimensions)))
    higher_idx = int(latent_f_size * float((higher_idx / freq_dimensions)))

    bandwidth_condition = torch.zeros((latent_t_size, latent_f_size))
    bandwidth_condition[:, lower_idx:higher_idx] += 1.0

    return {
        "mel_spec_bandwidth_cond_extra_channel": bandwidth_condition,
        "freq_energy_percentile": torch.LongTensor([lower_idx, higher_idx]),
    }
132
+
133
+
134
def waveform_rs_48k(config, dl_output, metadata):
    """Return the waveform at 48 kHz, resampling only when needed.

    Args:
        config: Unused here; kept for the shared add-on signature.
        dl_output: Mapping with ``"waveform"`` (``[1, samples]`` per the
            original comment) and ``"sampling_rate"``.
        metadata: Unused here.

    Returns:
        ``{"waveform_48k": tensor}``; the input tensor itself when the rate is
        already 48000.
    """
    target_rate = 48000
    waveform = dl_output["waveform"]  # [1, samples]
    sampling_rate = dl_output["sampling_rate"]

    if sampling_rate == target_rate:
        return {"waveform_48k": waveform}

    resampled = torchaudio.functional.resample(
        waveform, orig_freq=sampling_rate, new_freq=target_rate
    )
    return {"waveform_48k": resampled}
146
+
147
+
148
def extract_vits_phoneme_and_flant5_text(config, dl_output, metadata):
    """Produce VITS phoneme ids for items with or without phoneme metadata.

    Items carrying ``"phonemes"`` are encoded and given an empty ``"text"``
    entry. Items without it are encoded from an empty phoneme string and get
    no ``"text"`` entry from this function.

    Args:
        config: Forwarded to ``get_vits_phoneme_ids_no_padding``.
        dl_output: Forwarded to ``get_vits_phoneme_ids_no_padding``.
        metadata: Item metadata; must not contain the key ``"phoneme"``.

    Returns:
        The dict returned by ``get_vits_phoneme_ids_no_padding``, optionally
        extended with ``"text"``.

    Raises:
        AssertionError: If ``metadata`` contains the key ``"phoneme"``.
    """
    assert (
        "phoneme" not in metadata.keys()
    ), "The metadata of speech you use seems belong to fastspeech. Please check dataset_root.json"

    if "phonemes" not in metadata.keys():
        # Add empty phoneme sequence
        placeholder = {"phonemes": ""}
        return get_vits_phoneme_ids_no_padding(config, dl_output, placeholder)

    new_item = get_vits_phoneme_ids_no_padding(config, dl_output, metadata)
    # We assume TTS data does not have text description
    new_item["text"] = ""
    return new_item
161
+
162
+
163
def extract_fs2_phoneme_and_flant5_text(config, dl_output, metadata):
    """Produce FastSpeech2-style phoneme ids for items with or without phonemes.

    Items carrying ``"phoneme"`` are encoded and given an empty ``"text"``
    entry. Items without it are encoded from an empty phoneme list and get no
    ``"text"`` entry from this function.

    Args:
        config: Forwarded to ``extract_fs2_phoneme_g2p_en_feature``.
        dl_output: Forwarded to ``extract_fs2_phoneme_g2p_en_feature``.
        metadata: Item metadata.

    Returns:
        The dict returned by ``extract_fs2_phoneme_g2p_en_feature``, optionally
        extended with ``"text"``.
    """
    if "phoneme" not in metadata.keys():
        placeholder = {"phoneme": []}
        return extract_fs2_phoneme_g2p_en_feature(config, dl_output, placeholder)

    new_item = extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata)
    new_item["text"] = ""
    return new_item
171
+
172
+
173
def extract_fs2_phoneme_g2p_en_feature(config, dl_output, metadata):
    """Encode stress-marked phonemes as a 135-entry id tensor.

    Phonemes not present in the table are dropped. The remaining ids are cut
    to 135 entries and right-padded with the pad id, which equals the table
    size (71).

    Args:
        config: Unused here; kept for the shared add-on signature.
        dl_output: Unused here; kept for the shared add-on signature.
        metadata: Mapping that must contain ``"phoneme"``, an iterable of
            phoneme strings.

    Returns:
        ``{"phoneme_idx": LongTensor}`` with 135 entries.

    Raises:
        AssertionError: If ``"phoneme"`` is missing from ``metadata``.
    """
    PAD_LENGTH = 135

    # Position in this tuple is the phoneme's id.
    phoneme_symbols = (
        "K", "IH2", "NG", "OW2", "AH2", "F", "AE0", "IY0", "SH", "G",
        "W", "UW1", "AO2", "AW2", "UW0", "EY2", "UW2", "AE2", "IH0", "P",
        "D", "ER1", "AA1", "EH0", "UH1", "N", "V", "AY1", "EY1", "UH2",
        "EH1", "L", "AA2", "R", "OY1", "Y", "ER2", "S", "AE1", "AH1",
        "JH", "ER0", "EH2", "IY2", "OY2", "AW1", "IH1", "IY1", "OW0", "AO0",
        "AY0", "EY0", "AY2", "UH0", "M", "TH", "T", "OY0", "AW0", "DH",
        "Z", "spn", "AH0", "sp", "AO1", "OW1", "ZH", "B", "AA0", "CH",
        "HH",
    )
    phonemes_lookup_dict = {symbol: idx for idx, symbol in enumerate(phoneme_symbols)}
    pad_token_id = len(phonemes_lookup_dict)

    assert (
        "phoneme" in metadata.keys()
    ), "The dataloader add-on extract_phoneme_g2p_en_feature will output phoneme id, which is not specified in your dataset"

    phonemes = []
    for symbol in metadata["phoneme"]:
        symbol_id = phonemes_lookup_dict.get(symbol)
        if symbol_id is not None:
            phonemes.append(symbol_id)

    # Only warn when more than five pad-lengths' worth of phonemes is dropped.
    if (len(phonemes) / PAD_LENGTH) > 5:
        print(
            "Warning: Phonemes length is too long and is truncated too much! %s"
            % metadata
        )

    phonemes = phonemes[:PAD_LENGTH]
    padded = phonemes + [pad_token_id] * (PAD_LENGTH - len(phonemes))

    return {"phoneme_idx": torch.LongTensor(padded)}
273
+
274
+
275
def extract_phoneme_g2p_en_feature(config, dl_output, metadata):
    """Encode stress-free phonemes as a 250-entry id tensor.

    Phonemes not present in the table are dropped. The remaining ids are cut
    to 250 entries and right-padded with the pad id, which equals the table
    size (40).

    Args:
        config: Unused here; kept for the shared add-on signature.
        dl_output: Unused here; kept for the shared add-on signature.
        metadata: Mapping that must contain ``"phoneme"``, an iterable of
            phoneme strings.

    Returns:
        ``{"phoneme_idx": LongTensor}`` with 250 entries.

    Raises:
        AssertionError: If ``"phoneme"`` is missing from ``metadata``.
    """
    PAD_LENGTH = 250

    # Position in this tuple is the phoneme's id; id 0 is the space symbol.
    phoneme_symbols = (
        " ", "AA", "AE", "AH", "AO", "AW", "AY", "B", "CH", "D",
        "DH", "EH", "ER", "EY", "F", "G", "HH", "IH", "IY", "JH",
        "K", "L", "M", "N", "NG", "OW", "OY", "P", "R", "S",
        "SH", "T", "TH", "UH", "UW", "V", "W", "Y", "Z", "ZH",
    )
    phonemes_lookup_dict = {symbol: idx for idx, symbol in enumerate(phoneme_symbols)}
    pad_token_id = len(phonemes_lookup_dict)

    assert (
        "phoneme" in metadata.keys()
    ), "The dataloader add-on extract_phoneme_g2p_en_feature will output phoneme id, which is not specified in your dataset"

    phonemes = []
    for symbol in metadata["phoneme"]:
        symbol_id = phonemes_lookup_dict.get(symbol)
        if symbol_id is not None:
            phonemes.append(symbol_id)

    # Only warn when more than five pad-lengths' worth of phonemes is dropped.
    if (len(phonemes) / PAD_LENGTH) > 5:
        print(
            "Warning: Phonemes length is too long and is truncated too much! %s"
            % metadata
        )

    phonemes = phonemes[:PAD_LENGTH]
    padded = phonemes + [pad_token_id] * (PAD_LENGTH - len(phonemes))

    return {"phoneme_idx": torch.LongTensor(padded)}
343
+
344
+
345
def extract_kaldi_fbank_feature(config, dl_output, metadata):
    """Compute a normalised 128-bin Kaldi filterbank at 16 kHz.

    The waveform is resampled to 16 kHz when necessary, mean-centred, and
    converted with ``torchaudio.compliance.kaldi.fbank``. The result is
    zero-padded or cut along its first dimension to match
    ``dl_output["log_mel_spec"].size(0)`` and then normalised with fixed
    mean/std constants.

    Args:
        config: Unused here; kept for the shared add-on signature.
        dl_output: Mapping with ``"waveform"`` (``[1, samples]`` per the
            original comment), ``"sampling_rate"`` and ``"log_mel_spec"``.
        metadata: Unused here.

    Returns:
        ``{"ta_kaldi_fbank": tensor}`` (``[1024, 128]`` per the original
        comment — the first dimension follows ``log_mel_spec``).
    """
    norm_mean = -4.2677393
    norm_std = 4.5689974
    target_rate = 16000

    waveform = dl_output["waveform"]  # [1, samples]
    sampling_rate = dl_output["sampling_rate"]
    log_mel_spec_hifigan = dl_output["log_mel_spec"]

    if sampling_rate == target_rate:
        waveform_16k = waveform
    else:
        waveform_16k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=target_rate
        )

    centred = waveform_16k - waveform_16k.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        centred,
        htk_compat=True,
        sample_frequency=target_rate,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    # cut and pad so the frame count matches the mel spectrogram
    target_len = log_mel_spec_hifigan.size(0)
    missing_frames = target_len - fbank.shape[0]
    if missing_frames > 0:
        fbank = torch.nn.ZeroPad2d((0, 0, 0, missing_frames))(fbank)
    elif missing_frames < 0:
        fbank = fbank[:target_len, :]

    fbank = (fbank - norm_mean) / (norm_std * 2)

    return {"ta_kaldi_fbank": fbank}  # [1024, 128]
386
+
387
+
388
def extract_kaldi_fbank_feature_32k(config, dl_output, metadata):
    """Compute a normalised 128-bin Kaldi filterbank at 32 kHz.

    The waveform is resampled to 32 kHz when necessary, mean-centred, and
    converted with ``torchaudio.compliance.kaldi.fbank``. The result is
    zero-padded or cut along its first dimension to match
    ``dl_output["log_mel_spec"].size(0)`` and then normalised with fixed
    mean/std constants.

    Args:
        config: Unused here; kept for the shared add-on signature.
        dl_output: Mapping with ``"waveform"`` (``[1, samples]`` per the
            original comment), ``"sampling_rate"`` and ``"log_mel_spec"``.
        metadata: Unused here.

    Returns:
        ``{"ta_kaldi_fbank": tensor}`` (``[1024, 128]`` per the original
        comment — the first dimension follows ``log_mel_spec``).
    """
    norm_mean = -4.2677393
    norm_std = 4.5689974
    target_rate = 32000

    waveform = dl_output["waveform"]  # [1, samples]
    sampling_rate = dl_output["sampling_rate"]
    log_mel_spec_hifigan = dl_output["log_mel_spec"]

    if sampling_rate == target_rate:
        waveform_32k = waveform
    else:
        waveform_32k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=target_rate
        )

    centred = waveform_32k - waveform_32k.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        centred,
        htk_compat=True,
        sample_frequency=target_rate,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    # cut and pad so the frame count matches the mel spectrogram
    target_len = log_mel_spec_hifigan.size(0)
    missing_frames = target_len - fbank.shape[0]
    if missing_frames > 0:
        fbank = torch.nn.ZeroPad2d((0, 0, 0, missing_frames))(fbank)
    elif missing_frames < 0:
        fbank = fbank[:target_len, :]

    fbank = (fbank - norm_mean) / (norm_std * 2)

    return {"ta_kaldi_fbank": fbank}  # [1024, 128]
429
+
430
+
431
+ # Use the beat and downbeat information as music conditions
432
def extract_drum_beat(config, dl_output, metadata):
    """Rasterise beat and downbeat positions onto the latent time axis.

    Beat/downbeat sample positions from ``metadata`` are shifted by the
    segment's start offset, filtered to those inside the segment, and mapped
    to rows of a ``(latent_t_size, latent_f_size)`` tensor. Each beat subtracts
    0.5 from its whole row and each downbeat adds 1.0, so a row reads
    0 (none), -0.5 (beat), 1.0 (downbeat) or 0.5 (both).

    Args:
        config: Must provide ``config["model"]["params"]["latent_t_size"]`` and
            ``["latent_f_size"]``.
        dl_output: Mapping with ``"duration"`` and
            ``"random_start_sample_in_original_audio_file"``.
        metadata: Mapping with ``"sample_rate"``, ``"beat"`` and ``"downbeat"``.

    Returns:
        ``{"cond_beat_downbeat": tensor}``.

    Raises:
        AssertionError: If a required key is missing from ``metadata``.
    """

    def visualization(conditional_signal, mel_spectrogram, filename):
        # Debug helper (its only call site below is commented out): dumps the
        # segment audio and plots the condition next to the mel spectrogram.
        import soundfile as sf

        sf.write(
            os.path.basename(dl_output["fname"]),
            np.array(dl_output["waveform"])[0],
            dl_output["sampling_rate"],
        )
        plt.figure(figsize=(10, 10))

        for position, (image, title) in enumerate(
            (
                (conditional_signal, "Conditional Signal"),
                (mel_spectrogram, "Mel Spectrogram"),
            )
        ):
            plt.subplot(211 + position)
            plt.imshow(np.array(image).T, aspect="auto")
            plt.title(title)

        plt.savefig(filename)
        plt.close()

    assert "sample_rate" in metadata and "beat" in metadata and "downbeat" in metadata

    sampling_rate = metadata["sample_rate"]
    duration = dl_output["duration"]
    # The dataloader segment length before performing torch resampling
    segment_length = int(sampling_rate * duration)

    segment_start = int(dl_output["random_start_sample_in_original_audio_file"])

    def _inside_segment(sample_positions):
        # Positions relative to the segment start, keeping both boundaries.
        return [
            position - segment_start
            for position in sample_positions
            if 0 <= position - segment_start <= segment_length
        ]

    beat = _inside_segment(metadata["beat"])
    downbeat = _inside_segment(metadata["downbeat"])

    latent_t_size = config["model"]["params"]["latent_t_size"]
    latent_f_size = config["model"]["params"]["latent_f_size"]
    conditional_signal = torch.zeros((latent_t_size, latent_f_size))
    last_row = conditional_signal.size(0) - 1

    def _mark_rows(relative_positions, amount):
        for position in relative_positions:
            row = int((position / segment_length) * latent_t_size)
            # A position exactly at the segment end maps one past the last row.
            row = min(row, last_row)
            conditional_signal[row, :] += amount

    # 0: none; -0.5: beat; 1.0: downbeat; 0.5: downbeat+beat
    _mark_rows(beat, -0.5)
    _mark_rows(downbeat, 1.0)

    # visualization(conditional_signal, dl_output["log_mel_spec"], filename = os.path.basename(dl_output["fname"])+".png")

    return {"cond_beat_downbeat": conditional_signal}
audioldm_train/losses/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .contperceptual import LPIPSWithDiscriminator
audioldm_train/losses/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (203 Bytes). View file
 
audioldm_train/losses/__pycache__/contperceptual.cpython-310.pyc ADDED
Binary file (3.66 kB). View file
 
audioldm_train/losses/contperceptual.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ import sys
5
+ sys.path.append("/train20/intern/permanent/changli7/dataset_ptm")
6
+ from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
7
+
8
+
9
class LPIPSWithDiscriminator(nn.Module):
    """Reconstruction + LPIPS + KL loss with an adversarial discriminator term.

    ``forward`` serves two optimisers: index 0 returns the generator
    (autoencoder) loss, index 1 returns the discriminator loss. ``LPIPS``,
    ``NLayerDiscriminator``, ``weights_init``, ``hinge_d_loss``,
    ``vanilla_d_loss`` and ``adopt_weight`` come from the file's
    ``taming.modules.losses.vqperceptual`` star import.
    """

    def __init__(
        self,
        disc_start,
        logvar_init=0.0,
        kl_weight=1.0,
        pixelloss_weight=1.0,
        disc_num_layers=3,
        disc_in_channels=3,
        disc_factor=1.0,
        disc_weight=1.0,
        perceptual_weight=1.0,
        use_actnorm=False,
        disc_conditional=False,
        disc_loss="hinge",
    ):
        """Store loss weights and build the perceptual net and discriminator.

        Args:
            disc_start: Global step passed to ``adopt_weight`` as threshold.
            logvar_init: Initial value of the learnable log-variance scalar.
            kl_weight: Multiplier of the KL term in the generator loss.
            pixelloss_weight: Stored as ``pixel_weight`` (not read in this class).
            disc_num_layers: ``n_layers`` of the discriminator.
            disc_in_channels: ``input_nc`` of the discriminator.
            disc_factor: Base factor handed to ``adopt_weight``.
            disc_weight: Multiplier applied to the adaptive weight.
            perceptual_weight: Multiplier of the LPIPS term.
            use_actnorm: Forwarded to the discriminator.
            disc_conditional: Whether ``cond`` is expected in ``forward``.
            disc_loss: ``"hinge"`` or ``"vanilla"``.
        """
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        discriminator = NLayerDiscriminator(
            input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm
        )
        self.discriminator = discriminator.apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        else:
            self.disc_loss = vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Balance the adversarial term against the reconstruction term.

        The weight is the ratio of the gradient norms of ``nll_loss`` and
        ``g_loss`` with respect to ``last_layer``, clamped to ``[0, 1e4]``,
        detached, and scaled by ``discriminator_weight``.
        """
        # NOTE(review): `self.last_layer` is not assigned anywhere in this
        # class, so the fallback only works if something else sets it.
        if last_layer is not None:
            wrt = last_layer
        else:
            wrt = self.last_layer[0]

        nll_grads = torch.autograd.grad(nll_loss, wrt, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, wrt, retain_graph=True)[0]

        ratio = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        ratio = torch.clamp(ratio, 0.0, 1e4).detach()
        return ratio * self.discriminator_weight

    def forward(
        self,
        inputs,
        reconstructions,
        posteriors,
        optimizer_idx,
        global_step,
        waveform=None,
        rec_waveform=None,
        last_layer=None,
        cond=None,
        split="train",
        weights=None,
    ):
        """Return ``(loss, log_dict)`` for the requested optimiser.

        ``optimizer_idx == 0`` gives the generator loss and
        ``optimizer_idx == 1`` the discriminator loss. Any other value falls
        through and returns ``None``. ``waveform`` and ``rec_waveform`` are
        accepted but not used here.
        """

        def _sum_over_batch(tensor):
            # Sum of all elements divided by the size of the first dimension.
            return torch.sum(tensor) / tensor.shape[0]

        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())

        # Always true
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(
                inputs.contiguous(), reconstructions.contiguous()
            )
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss if weights is None else weights * nll_loss
        weighted_nll_loss = _sum_over_batch(weighted_nll_loss)
        nll_loss = _sum_over_batch(nll_loss)
        kl_loss = _sum_over_batch(posteriors.kl())

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                fake_input = reconstructions.contiguous()
            else:
                assert self.disc_conditional
                fake_input = torch.cat((reconstructions.contiguous(), cond), dim=1)
            logits_fake = self.discriminator(fake_input)
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(
                        nll_loss, g_loss, last_layer=last_layer
                    )
                except RuntimeError:
                    # Gradients are unavailable outside training.
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )
            loss = (
                weighted_nll_loss
                + self.kl_weight * kl_loss
                + d_weight * disc_factor * g_loss
            )

            log = {
                "{}/total_loss".format(split): loss.clone().detach().mean(),
                "{}/logvar".format(split): self.logvar.detach(),
                "{}/kl_loss".format(split): kl_loss.detach().mean(),
                "{}/nll_loss".format(split): nll_loss.detach().mean(),
                "{}/rec_loss".format(split): rec_loss.detach().mean(),
                "{}/d_weight".format(split): d_weight.detach(),
                "{}/disc_factor".format(split): torch.tensor(disc_factor),
                "{}/g_loss".format(split): g_loss.detach().mean(),
            }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update
            real_input = inputs.contiguous().detach()
            fake_input = reconstructions.contiguous().detach()
            if cond is not None:
                real_input = torch.cat((real_input, cond), dim=1)
                fake_input = torch.cat((fake_input, cond), dim=1)
            logits_real = self.discriminator(real_input)
            logits_fake = self.discriminator(fake_input)

            disc_factor = adopt_weight(
                self.disc_factor, global_step, threshold=self.discriminator_iter_start
            )
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {
                "{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                "{}/logits_real".format(split): logits_real.detach().mean(),
                "{}/logits_fake".format(split): logits_fake.detach().mean(),
            }
            return d_loss, log
audioldm_train/modules/.DS_Store ADDED
Binary file (8.2 kB). View file
 
audioldm_train/modules/__init__.py ADDED
File without changes
audioldm_train/modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (143 Bytes). View file
 
audioldm_train/modules/audiomae/AudioMAE.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Reference Repo: https://github.com/facebookresearch/AudioMAE
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from timm.models.layers import to_2tuple
8
+ import audioldm_train.modules.audiomae.models_vit as models_vit
9
+ import audioldm_train.modules.audiomae.models_mae as models_mae
10
+
11
+ # model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))
12
+
13
+
14
class PatchEmbed_new(nn.Module):
    """Flexible Image to Patch Embedding.

    Projects an image-like input to patch tokens with a strided convolution.
    The stride may be smaller than the patch size, which gives overlapping
    patches.
    """

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
    ):
        """Build the projection and record the resulting patch grid.

        Args:
            img_size: Input height/width (int or pair).
            patch_size: Convolution kernel size (int or pair).
            in_chans: Number of input channels.
            embed_dim: Number of output channels per patch token.
            stride: Convolution stride (int or pair).
        """
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)

        self.img_size = img_size
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride
        )  # with overlapped patches

        _, _, h, w = self.get_output_shape(img_size)  # n, emb_dim, h, w
        self.patch_hw = (h, w)
        self.num_patches = h * w

    def get_output_shape(self, img_size):
        """Return the ``(n, embed_dim, h, w)`` shape produced for ``img_size``.

        The probe uses the projection's own channel count. The original
        hard-coded a single channel, so any ``in_chans != 1`` (including the
        default of 3) failed inside the constructor.
        """
        # torch.randn is kept (rather than zeros) so the amount of global RNG
        # state consumed for the existing single-channel use is unchanged.
        with torch.no_grad():
            probe = torch.randn(1, self.proj.in_channels, img_size[0], img_size[1])
            return self.proj(probe).shape

    def forward(self, x):
        """Embed ``x`` of shape ``[B, C, H, W]`` into ``[B, num_patches, embed_dim]``."""
        # Unpacking doubles as a check that the input is 4-D.
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x
51
+
52
+
53
class AudioMAE(nn.Module):
    """Audio Masked Autoencoder (MAE) pre-trained and finetuned on AudioSet (for SoundCLIP)"""

    def __init__(
        self,
        checkpoint_path="/mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth",
    ):
        """Build the ViT classifier and load finetuned weights.

        Args:
            checkpoint_path: File holding a dict with a ``"model"`` state dict.
                Defaults to the location that was previously hard-coded.
        """
        super().__init__()
        model = models_vit.__dict__["vit_base_patch16"](
            num_classes=527,
            drop_path_rate=0.1,
            global_pool=True,
            mask_2d=True,
            use_custom_patch=False,
        )

        img_size = (1024, 128)
        emb_dim = 768

        # Replace the patch embedding with a single-channel one.
        model.patch_embed = PatchEmbed_new(
            img_size=img_size,
            patch_size=(16, 16),
            in_chans=1,
            embed_dim=emb_dim,
            stride=16,
        )
        num_patches = model.patch_embed.num_patches
        # num_patches = 512 # assume audioset, 1024//16=64, 128//16=8, 512=64x8
        model.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False
        )  # fixed sin-cos embedding

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # strict=False: keys that do not match are skipped without an error.
        model.load_state_dict(checkpoint["model"], strict=False)

        self.model = model

    def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0):
        """
        x: mel fbank [Batch, 1, T, F]
        mask_t_prob: 'T masking ratio (percentage of removed patches).'
        mask_f_prob: 'F masking ratio (percentage of removed patches).'
        """
        return self.model(x=x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob)
100
+
101
+
102
class Vanilla_AudioMAE(nn.Module):
    """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM)"""

    def __init__(
        self,
        checkpoint_path="data/checkpoints/audiomae_16k_128bins.ckpt",
    ):
        """Build the MAE encoder and load pre-trained weights.

        Args:
            checkpoint_path: File holding a dict with a ``"model"`` state dict.
                Defaults to the location that was previously hard-coded.
        """
        super().__init__()
        model = models_mae.__dict__["mae_vit_base_patch16"](
            in_chans=1, audio_exp=True, img_size=(1024, 128)
        )

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # Skip the missing keys of decoder modules (not required)
        model.load_state_dict(checkpoint["model"], strict=False)

        self.model = model.eval()

    def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False):
        """
        x: mel fbank [Batch, 1, 1024 (T), 128 (F)]
        mask_ratio: accepted for compatibility; the masked path is deprecated.
        no_mask: must be True; any other value raises RuntimeError.
        no_average: must be False; True raises RuntimeError.

        Returns the output of ``forward_encoder_no_mask`` computed without
        gradients.
        """
        # Only the unmasked, averaged path is still supported. The statements
        # that followed each `raise` in the original were unreachable and have
        # been removed.
        if not no_mask or no_average:
            raise RuntimeError("This function is deprecated")

        with torch.no_grad():
            return self.model.forward_encoder_no_mask(x)
141
+
142
+
143
if __name__ == "__main__":
    # Smoke test: encode a random batch twice on the GPU.
    model = Vanilla_AudioMAE().cuda()
    # Named `fbank` rather than `input` to avoid shadowing the builtin.
    fbank = torch.randn(4, 1, 1024, 128).cuda()
    print("The first run")
    embed = model(fbank, mask_ratio=0.0, no_mask=True)
    print(embed)
    print("The second run")
    # The masked path (no_mask=False) always raises "This function is
    # deprecated", so the second run uses the supported path as well.
    embed = model(fbank, mask_ratio=0.0, no_mask=True)
    print(embed)
audioldm_train/modules/audiomae/README.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # A simple use of Audio Masked AutoEncoder (AudioMAE)
2
+ Reference code: https://github.com/facebookresearch/AudioMAE
3
+
4
+ Paper: https://arxiv.org/abs/2207.06405
5
+
6
+ Install the required python packages:
7
+ ```
8
+ pip install -r requirments.txt
9
+ ```
10
+
11
+
12
+ See the usage in example.py
13
+
14
+
15
+
16
+ ```
17
+ python example.py
18
+
19
+ """
20
+ Load AudioMAE from /mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth / message: <All keys matched successfully>
21
+ Start evaluation on AudioSet ...
22
+ mAP: 0.463003
23
+ """
24
+ ```
audioldm_train/modules/audiomae/__init__.py ADDED
File without changes
audioldm_train/modules/audiomae/__pycache__/AudioMAE.cpython-310.pyc ADDED
Binary file (4.48 kB). View file
 
audioldm_train/modules/audiomae/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (152 Bytes). View file
 
audioldm_train/modules/audiomae/__pycache__/models_mae.cpython-310.pyc ADDED
Binary file (12.2 kB). View file
 
audioldm_train/modules/audiomae/__pycache__/models_vit.cpython-310.pyc ADDED
Binary file (5.18 kB). View file
 
audioldm_train/modules/audiomae/audiovisual_dataset.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ from tqdm import tqdm
4
+ import torch
5
+ import decord
6
+
7
+ decord.bridge.set_bridge("torch")
8
+ import torchaudio
9
+ from math import ceil
10
+ from torch.utils.data import Dataset, DataLoader
11
+ import pandas as pd
12
+ import numpy as np
13
+
14
+
15
class AudioVisualDataset(Dataset):
    """Can sample data from audio-visual databases
    Params:
        datafiles: JSON files, each with a top-level "data" list
        min_video_frames: used to drop short video clips
        video_resize: resize for CLIP processing
        sampling_rate: audio sampling rate
        sample_av_clip: whether to crop a random clip from long recordings
        max_clip_len: max length (seconds) of audiovisual clip to be sampled
        num_sample_frames: number of image frames to be uniformly sampled from video
        freqm, timem: SpecAug mask sizes (0 disables the respective mask)
        return_label: whether to return multi-hot AudioSet labels
    """

    def __init__(
        self,
        datafiles=None,
        min_video_frames=30,
        video_resize=None,
        sampling_rate=16000,
        sample_av_clip=True,
        max_clip_len=10,
        num_sample_frames=10,
        # hyparameters used for SpecAug
        freqm=48,
        timem=192,
        return_label=False,
    ):
        # None stands in for the former mutable list defaults; the effective
        # default values are unchanged.
        if datafiles is None:
            datafiles = [
                "/mnt/bn/data-xubo/dataset/audioset_videos/datafiles/audioset_balanced_train.json",
            ]
        if video_resize is None:
            video_resize = [224, 224]

        all_data_json = []
        for datafile in datafiles:
            with open(datafile, "r") as fp:
                data_json = json.load(fp)["data"]
                all_data_json.extend(data_json)

        # drop short video clips
        self.all_data_json = [
            data
            for data in all_data_json
            if int(data["video_shape"][0]) >= min_video_frames
        ]

        self.max_clip_len = max_clip_len
        self.video_resize = video_resize
        self.sampling_rate = sampling_rate
        self.sample_av_clip = sample_av_clip
        self.num_sample_frames = num_sample_frames
        self.corresponding_audio_len = self.sampling_rate * self.max_clip_len

        # hyparameters used for AudioMAE
        self.freqm = freqm
        self.timem = timem
        self.norm_mean = -4.2677393
        self.norm_std = 4.5689974
        self.melbins = 128
        self.TARGET_LEN = 1024

        self.return_label = return_label
        if self.return_label:
            self.audioset_label2idx = self._prepare_audioset()

    def __len__(self):
        return len(self.all_data_json)

    def _read_audio_video(self, index):
        """Open the audio and video readers for ``index``.

        On any failure the error is printed and a random other item is tried
        instead (best effort, as in the original).
        """
        # Bound before the try so the error message can always be formatted,
        # even when the lookup of the path itself is what failed.
        video_path = None
        try:
            video_path = self.all_data_json[index]["mp4"]
            # read audio
            ar = decord.AudioReader(
                video_path, sample_rate=self.sampling_rate, mono=True
            )
            # read video frames
            vr = decord.VideoReader(
                video_path,
                height=self.video_resize[0],
                width=self.video_resize[1],
            )

            labels = self.all_data_json[index]["labels"]
            return vr, ar, labels

        except Exception as e:
            print(f"error: {e} occurs, when loading {video_path}")
            random_index = random.randint(0, len(self.all_data_json) - 1)
            return self._read_audio_video(index=random_index)

    def _prepare_audioset(self):
        """Map AudioSet label codes (CSV column 1) to class indices (column 0)."""
        df1 = pd.read_csv(
            "/mnt/bn/lqhaoheliu/datasets/audioset/metadata/class_labels_indices.csv",
            delimiter=",",
            skiprows=0,
        )
        label_set = df1.to_numpy()
        return {row[1]: row[0] for row in label_set}

    def __getitem__(self, index):
        """Return a dict with "images", "fbank" and, optionally, "labels"."""
        # read audio and video
        vr, ar, labels = self._read_audio_video(index)

        # create a audio tensor
        audio_data = ar[:]  # [1, samples]
        audio_len = audio_data.shape[1] / self.sampling_rate
        audio_data = audio_data.squeeze(0)  # [samples]

        # create a video tensor
        full_vid_length = len(vr)
        video_rate = ceil(vr.get_avg_fps())
        samples_per_frame = float(self.sampling_rate) / video_rate
        start_frame = 0

        # Long recordings are cropped to a random clip of max_clip_len seconds.
        crop_clip = audio_len > self.max_clip_len and self.sample_av_clip

        # sample video clip
        if crop_clip:
            start_frame = random.randint(
                0, max(full_vid_length - video_rate * self.max_clip_len, 0)
            )
        end_frame = min(start_frame + video_rate * self.max_clip_len, full_vid_length)
        video_data = vr.get_batch(range(start_frame, end_frame))

        # sample audio clip
        if crop_clip:
            corresponding_audio_start = int(start_frame * samples_per_frame)
            audio_data = audio_data[corresponding_audio_start:]

        # cut or pad audio clip with respect to the sampled video clip
        if audio_data.shape[0] < self.corresponding_audio_len:
            zero_data = torch.zeros(self.corresponding_audio_len)
            zero_data[: audio_data.shape[0]] = audio_data
            audio_data = zero_data
        elif audio_data.shape[0] > self.corresponding_audio_len:
            audio_data = audio_data[: self.corresponding_audio_len]

        # uniformly sample image frames from video [tentative solution]
        # A clip with fewer frames than num_sample_frames would give a step of
        # 0 and a "slice step cannot be zero" error; clamp so the assertion
        # below reports the real problem instead.
        interval = max(video_data.shape[0] // self.num_sample_frames, 1)
        video_data = video_data[::interval][: self.num_sample_frames]

        assert (
            video_data.shape[0] == self.num_sample_frames
        ), f"number of sampled image frames is {video_data.shape[0]}"

        assert (
            audio_data.shape[0] == self.corresponding_audio_len
        ), f"number of audio samples is {audio_data.shape[0]}"

        # video transformation
        video_data = video_data / 255.0
        video_data = video_data.permute(0, 3, 1, 2)  # [N, H, W, C] -> [N, C, H, W]

        # calculate mel fbank of waveform for audio encoder
        audio_data = audio_data.unsqueeze(0)  # [1, samples]
        audio_data = audio_data - audio_data.mean()
        fbank = torchaudio.compliance.kaldi.fbank(
            audio_data,
            htk_compat=True,
            sample_frequency=self.sampling_rate,
            use_energy=False,
            window_type="hanning",
            num_mel_bins=self.melbins,
            dither=0.0,
            frame_shift=10,
        )
        # cut and pad
        n_frames = fbank.shape[0]
        p = self.TARGET_LEN - n_frames
        if p > 0:
            m = torch.nn.ZeroPad2d((0, 0, 0, p))
            fbank = m(fbank)
        elif p < 0:
            fbank = fbank[0 : self.TARGET_LEN, :]

        # SpecAug for training (not for eval)
        fbank = fbank.transpose(0, 1).unsqueeze(0)  # 1, 128, 1024 (...,freq,time)
        if self.freqm != 0:
            fbank = torchaudio.transforms.FrequencyMasking(self.freqm)(fbank)
        if self.timem != 0:
            fbank = torchaudio.transforms.TimeMasking(self.timem)(fbank)
        fbank = torch.transpose(fbank.squeeze(), 0, 1)  # time, freq
        fbank = (fbank - self.norm_mean) / (self.norm_std * 2)
        fbank = fbank.unsqueeze(0)

        data_dict = {
            "images": video_data,
            "fbank": fbank,
            # 'modality': 'audio_visual'
        }

        if self.return_label:
            # get audioset label indexes (multi-hot over 527 classes)
            label_indices = np.zeros(527)
            for label_str in labels.split(","):
                label_indices[int(self.audioset_label2idx[label_str])] = 1.0
            data_dict["labels"] = torch.FloatTensor(label_indices)

        return data_dict
221
+
222
+
223
def collate_fn(list_data_dict):
    r"""Collate per-sample dicts into one dict of batched tensors.

    Args:
        list_data_dict: Non-empty list of dicts that share the keys of the
            first element; each value is a tensor, e.g.
            [{'images': ..., 'fbank': ...}, {'images': ..., 'fbank': ...}, ...]

    Returns:
        data_dict: One entry per key of the first sample, holding the
            ``torch.stack`` of that key's tensors, so each value gains a
            leading batch dimension.
    """
    keys = list_data_dict[0].keys()
    return {
        key: torch.stack([sample[key] for sample in list_data_dict]) for key in keys
    }
audioldm_train/modules/audiomae/example.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ from timm.models.layers import to_2tuple
5
+ import models_vit
6
+ from audiovisual_dataset import AudioVisualDataset, collate_fn
7
+ from torch.utils.data import DataLoader
8
+ from util.stat import calculate_stats
9
+ from tqdm import tqdm
10
+ from AudioMAE import AudioMAE
11
+
12
if __name__ == "__main__":
    # Example: evaluate a pretrained AudioMAE classifier on the AudioSet
    # evaluation split and report the mean average precision.

    # Fall back to CPU so the example still runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dataset = AudioVisualDataset(
        datafiles=[
            "/mnt/bn/data-xubo/dataset/audioset_videos/datafiles/audioset_eval.json"
        ],
        # disable SpecAug during evaluation
        freqm=0,
        timem=0,
        return_label=True,
    )

    model = AudioMAE().to(device)
    model.eval()

    outputs = []
    targets = []

    dataloader = DataLoader(
        dataset, batch_size=64, num_workers=8, shuffle=False, collate_fn=collate_fn
    )

    print("Start evaluation on AudioSet ...")
    with torch.no_grad():
        for data in tqdm(dataloader):
            fbank = data["fbank"].to(device)  # [B, 1, T, F]
            output = model(fbank, mask_t_prob=0.0, mask_f_prob=0.0)
            # Move each batch off the accelerator right away so the whole
            # evaluation set is never resident in device memory at once.
            outputs.append(output.cpu())
            targets.append(data["labels"])

    outputs = torch.cat(outputs).numpy()
    targets = torch.cat(targets).cpu().numpy()
    stats = calculate_stats(outputs, targets)

    # Per-class average precision, then its mean over classes.
    AP = [stat["AP"] for stat in stats]
    mAP = np.mean(AP)
    print("Done ... mAP: {:.6f}".format(mAP))

    # mAP: 0.463003
audioldm_train/modules/audiomae/models_mae.py ADDED
@@ -0,0 +1,615 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ # DeiT: https://github.com/facebookresearch/deit
10
+ # --------------------------------------------------------
11
+
12
+ from functools import partial
13
+ from json import encoder
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+ from timm.models.vision_transformer import Block
19
+ from audioldm_train.modules.audiomae.util.pos_embed import (
20
+ get_2d_sincos_pos_embed,
21
+ get_2d_sincos_pos_embed_flexible,
22
+ get_1d_sincos_pos_embed_from_grid,
23
+ )
24
+ from audioldm_train.modules.audiomae.util.patch_embed import (
25
+ PatchEmbed_new,
26
+ PatchEmbed_org,
27
+ )
28
+
29
+
30
class MaskedAutoencoderViT(nn.Module):
    """Masked Autoencoder with VisionTransformer backbone.

    The encoder embeds the input into patches, drops a random subset of them
    and runs the remaining ones (plus a class token) through ``depth``
    transformer blocks. The decoder projects the encoder output to
    ``decoder_embed_dim``, re-inserts a learned mask token at every dropped
    position and predicts the pixel values of every patch.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        stride=10,
        in_chans=3,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4.0,
        norm_layer=nn.LayerNorm,
        norm_pix_loss=False,
        audio_exp=False,
        alpha=0.0,
        temperature=0.2,
        mode=0,
        contextual_depth=8,
        use_custom_patch=False,
        split_pos=False,
        pos_trainable=False,
        use_nce=False,
        beta=4.0,
        decoder_mode=0,
        mask_t_prob=0.6,
        mask_f_prob=0.5,
        mask_2d=False,
        epoch=0,
        no_shift=False,
    ):
        """Build the encoder, the decoder and their positional embeddings.

        Args:
            img_size: input size forwarded to the patch-embedding module.
            patch_size: side length of a (square) patch.
            stride: patch stride, only used when ``use_custom_patch`` is set.
            in_chans: number of input channels.
            embed_dim, depth, num_heads: encoder width, block count and heads.
            decoder_embed_dim, decoder_depth, decoder_num_heads: same for the
                plain transformer decoder (``decoder_mode != 1``).
            mlp_ratio: MLP expansion ratio of every block.
            norm_layer: normalisation layer factory.
            norm_pix_loss: normalise target patches before the loss.
            audio_exp: use the flexible (non-square) sin-cos position
                embedding and single-channel patchify path.
            contextual_depth: blocks with index greater than this contribute
                to ``forward_encoder_no_mask``.
            use_custom_patch: use overlapping patches (``PatchEmbed_new``).
            split_pos: accepted for compatibility; not used.
            pos_trainable: whether the position embeddings receive gradients.
            decoder_mode: 1 selects the windowed (Swin) decoder, anything else
                the plain transformer decoder.
            mask_t_prob, mask_f_prob: time/frequency masking ratios used by
                the 2-D masking strategy.
            mask_2d: use 2-D (time/frequency) masking in ``forward``.
            no_shift: disable window shifting in the Swin decoder.
            alpha, temperature, mode, use_nce, beta, epoch: stored on the
                instance; not otherwise used inside this class.
        """
        super().__init__()

        self.audio_exp = audio_exp
        self.embed_dim = embed_dim
        self.decoder_embed_dim = decoder_embed_dim
        # --------------------------------------------------------------------------
        # MAE encoder specifics
        if use_custom_patch:
            print(
                f"Use custom patch_emb with patch size: {patch_size}, stride: {stride}"
            )
            self.patch_embed = PatchEmbed_new(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
                stride=stride,
            )
        else:
            self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim)
        self.use_custom_patch = use_custom_patch
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Filled with a sin-cos table in initialize_weights(); index 0 is the
        # class token position.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim), requires_grad=pos_trainable
        )

        self.encoder_depth = depth
        self.contextual_depth = contextual_depth
        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for _ in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, decoder_embed_dim),
            requires_grad=pos_trainable,
        )  # sin-cos embedding, see initialize_weights()

        self.no_shift = no_shift

        self.decoder_mode = decoder_mode
        if self.use_custom_patch:
            # overlapped patches as in AST. Similar performance yet compute heavy
            window_size = (6, 6)
            feat_size = (102, 12)
        else:
            window_size = (4, 4)
            feat_size = (64, 8)
        if self.decoder_mode == 1:
            # NOTE(review): SwinTransformerBlock is not among the imports
            # visible at the top of this file; unless it is brought into scope
            # elsewhere, decoder_mode=1 raises NameError here -- confirm.
            decoder_modules = []
            for index in range(16):
                # Alternate between unshifted and time-shifted windows unless
                # shifting is disabled.
                if self.no_shift or (index % 2) == 0:
                    shift_size = (0, 0)
                else:
                    shift_size = (2, 0)
                decoder_modules.append(
                    SwinTransformerBlock(
                        dim=decoder_embed_dim,
                        num_heads=16,
                        feat_size=feat_size,
                        window_size=window_size,
                        shift_size=shift_size,
                        mlp_ratio=mlp_ratio,
                        drop=0.0,
                        drop_attn=0.0,
                        drop_path=0.0,
                        extra_norm=False,
                        sequential_attn=False,
                        norm_layer=norm_layer,
                    )
                )
            self.decoder_blocks = nn.ModuleList(decoder_modules)
        else:
            # Plain transformer decoder
            self.decoder_blocks = nn.ModuleList(
                [
                    Block(
                        decoder_embed_dim,
                        decoder_num_heads,
                        mlp_ratio,
                        qkv_bias=True,
                        norm_layer=norm_layer,
                    )
                    for _ in range(decoder_depth)
                ]
            )

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(
            decoder_embed_dim, patch_size**2 * in_chans, bias=True
        )  # decoder to patch

        # --------------------------------------------------------------------------

        self.norm_pix_loss = norm_pix_loss

        self.patch_size = patch_size
        self.stride = stride

        # audio exps
        self.alpha = alpha
        self.T = temperature
        self.mode = mode
        self.use_nce = use_nce
        self.beta = beta

        self.log_softmax = nn.LogSoftmax(dim=-1)

        self.mask_t_prob = mask_t_prob
        self.mask_f_prob = mask_f_prob
        self.mask_2d = mask_2d

        self.epoch = epoch

        self.initialize_weights()

    def initialize_weights(self):
        """Fill the position embeddings and initialise all learnable weights."""
        # initialize pos_embed by sin-cos embedding
        if self.audio_exp:
            pos_embed = get_2d_sincos_pos_embed_flexible(
                self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True
            )
        else:
            pos_embed = get_2d_sincos_pos_embed(
                self.pos_embed.shape[-1],
                int(self.patch_embed.num_patches**0.5),
                cls_token=True,
            )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        if self.audio_exp:
            decoder_pos_embed = get_2d_sincos_pos_embed_flexible(
                self.decoder_pos_embed.shape[-1],
                self.patch_embed.patch_hw,
                cls_token=True,
            )
        else:
            decoder_pos_embed = get_2d_sincos_pos_embed(
                self.decoder_pos_embed.shape[-1],
                int(self.patch_embed.num_patches**0.5),
                cls_token=True,
            )
        self.decoder_pos_embed.data.copy_(
            torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
        )

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        torch.nn.init.normal_(self.cls_token, std=0.02)
        torch.nn.init.normal_(self.mask_token, std=0.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser applied through ``nn.Module.apply``."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """Cut a batch of images/spectrograms into flattened patches.

        imgs: (N, C, H, W) with C == 1 when ``audio_exp`` is set, else C == 3
        x: (N, L, patch_size**2 * C)
        L = (H/p)*(W/p) for non-overlapping patches
        """
        p = self.patch_embed.patch_size[0]

        if self.audio_exp:
            if self.use_custom_patch:  # overlapped patch
                h, w = self.patch_embed.patch_hw
                # todo: fixed h/w patch size and stride size. Make hw custom in the future
                x = imgs.unfold(2, self.patch_size, self.stride).unfold(
                    3, self.patch_size, self.stride
                )  # n,1,H,W -> n,1,h,w,p,p
                x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
            else:
                h = imgs.shape[2] // p
                w = imgs.shape[3] // p
                x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
                x = torch.einsum("nchpwq->nhwpqc", x)
                x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
        else:
            h = w = imgs.shape[2] // p
            x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
            x = torch.einsum("nchpwq->nhwpqc", x)
            x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))

        return x

    def unpatchify(self, x):
        """Reassemble single-channel patches into a spectrogram.

        x: (N, L, patch_size**2)
        specs: (N, 1, H, W)

        The output size is hard-coded to H=1024, W=128.
        """
        p = self.patch_embed.patch_size[0]
        h = 1024 // p
        w = 128 // p
        x = x.reshape(shape=(x.shape[0], h, w, p, p, 1))
        x = torch.einsum("nhwpqc->nchpwq", x)
        specs = x.reshape(shape=(x.shape[0], 1, h * p, w * p))
        return specs

    def random_masking(self, x, mask_ratio):
        """Perform per-sample random masking by per-sample shuffling.

        Per-sample shuffling is done by argsort of random noise.

        x: [N, L, D], sequence
        Returns ``(x_masked, mask, ids_restore)`` where ``x_masked`` holds the
        ``int(L * (1 - mask_ratio))`` kept tokens, ``mask`` is [N, L] with 0
        for kept and 1 for removed tokens, and ``ids_restore`` undoes the
        shuffle.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample; ascend: small is keep, large is remove
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
        """Mask whole time frames and whole frequency bands of a patch grid.

        A patch is removed when its time index or its frequency index was
        drawn for masking. The grid is fixed to 101x12 patches with custom
        (overlapping) patches and 64x8 otherwise.

        x: [N, L, D], sequence with L == T * F in time-major order
        Returns ``(x_masked, mask, ids_restore)`` with the same meaning as in
        ``random_masking``; kept patches appear in their original order.
        """
        N, _, D = x.shape  # batch, length, dim
        if self.use_custom_patch:  # overlapped patch
            T = 101
            F = 12
        else:
            T = 64
            F = 8
        len_keep_t = int(T * (1 - mask_t_prob))
        len_keep_f = int(F * (1 - mask_f_prob))

        # noise for mask in time; ascend: small is keep, large is remove
        noise_t = torch.rand(N, T, device=x.device)  # noise in [0, 1]
        ids_shuffle_t = torch.argsort(noise_t, dim=1)
        ids_restore_t = torch.argsort(ids_shuffle_t, dim=1)
        # noise for mask in freq
        noise_f = torch.rand(N, F, device=x.device)  # noise in [0, 1]
        ids_shuffle_f = torch.argsort(noise_f, dim=1)
        ids_restore_f = torch.argsort(ids_shuffle_f, dim=1)

        # generate the binary mask: 0 is keep, 1 is remove
        # mask in freq
        mask_f = torch.ones(N, F, device=x.device)
        mask_f[:, :len_keep_f] = 0
        mask_f = (
            torch.gather(mask_f, dim=1, index=ids_restore_f)
            .unsqueeze(1)
            .repeat(1, T, 1)
        )  # N,T,F
        # mask in time
        mask_t = torch.ones(N, T, device=x.device)
        mask_t[:, :len_keep_t] = 0
        mask_t = (
            torch.gather(mask_t, dim=1, index=ids_restore_t)
            .unsqueeze(1)
            .repeat(1, F, 1)
            .permute(0, 2, 1)
        )  # N,T,F
        mask = 1 - (1 - mask_t) * (1 - mask_f)  # N, T, F
        mask = mask.flatten(start_dim=1)  # N, T*F

        # Sort patches so that kept ones come first, each group in original
        # order. The offset added to masked patches must exceed the largest
        # patch index; a fixed constant smaller than T*F (e.g. 999 with the
        # 101x12 grid) would let masked patches slip into the kept set.
        num_patches = T * F
        patch_ids = torch.arange(num_patches, device=x.device, dtype=mask.dtype)
        sort_keys = patch_ids.unsqueeze(0) + num_patches * mask
        ids_sorted = torch.argsort(sort_keys, dim=1)

        # get masked x
        ids_keep = ids_sorted[:, : len_keep_f * len_keep_t]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        ids_restore = torch.argsort(ids_sorted, dim=1)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio, mask_2d=False):
        """Encode the visible patches of ``x``.

        Returns ``(latent, mask, ids_restore, None)``; ``latent`` contains the
        class token followed by the kept patch tokens.
        """
        # embed patches
        x = self.patch_embed(x)
        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        if mask_2d:
            x, mask, ids_restore = self.random_masking_2d(
                x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob
            )
        else:
            x, mask, ids_restore = self.random_masking(x, mask_ratio)

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore, None

    def forward_encoder_no_random_mask_no_average(self, x):
        """Encode all patches (no masking) and return the final normed tokens."""
        # embed patches
        x = self.patch_embed(x)
        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x

    def forward_encoder_no_mask(self, x):
        """Encode all patches and average the normed outputs of late blocks.

        Only blocks whose index is strictly greater than
        ``self.contextual_depth`` contribute to the average.
        """
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        x = x + self.pos_embed[:, 1:, :]

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        contextual_embs = []
        for n, blk in enumerate(self.blocks):
            x = blk(x)
            if n > self.contextual_depth:
                contextual_embs.append(self.norm(x))
        contextual_emb = torch.stack(contextual_embs, dim=0).mean(dim=0)

        return contextual_emb

    def forward_decoder(self, x, ids_restore):
        """Predict the pixels of every patch from the encoder output.

        Returns ``(pred, None, None)`` where ``pred`` has one row per patch
        (the class token is not included).
        """
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
        )
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        x_ = torch.gather(
            x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        if self.decoder_mode != 0:
            B, L, D = x.shape
            # windowed decoders operate on the patch grid only: drop cls token
            x = x[:, 1:, :]
            if self.use_custom_patch:
                x = x.reshape(B, 101, 12, D)
                # hack: duplicate the last time step so that the grid becomes
                # 102x12 (the feat_size given to the windowed decoder blocks)
                x = torch.cat([x, x[:, -1, :].unsqueeze(1)], dim=1)
                x = x.reshape(B, 1224, D)
        if self.decoder_mode > 3:  # mvit
            x = self.decoder_blocks(x)
        else:
            # apply Transformer blocks
            for blk in self.decoder_blocks:
                x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        pred = self.decoder_pred(x)

        if self.decoder_mode != 0:
            if self.use_custom_patch:
                # undo the padding hack applied before the decoder blocks
                pred = pred.reshape(B, 102, 12, 256)
                pred = pred[:, :101, :, :]
                pred = pred.reshape(B, 1212, 256)
        else:
            # remove cls token
            pred = pred[:, 1:, :]
        return pred, None, None  # emb, emb_pixel

    def forward_loss(self, imgs, pred, mask, norm_pix_loss=False):
        """Mean squared error over the removed patches.

        imgs: [N, C, H, W]
        pred: [N, L, p*p*C]
        mask: [N, L], 0 is keep, 1 is remove
        """
        target = self.patchify(imgs)
        if norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
        return loss

    def forward(self, imgs, mask_ratio=0.8):
        """Run masking, encoding, decoding and the reconstruction loss.

        Returns ``(loss_recon, pred, mask, loss_contrastive)``; the
        contrastive loss is a constant zero placeholder of shape ``(1,)``.
        """
        emb_enc, mask, ids_restore, _ = self.forward_encoder(
            imgs, mask_ratio, mask_2d=self.mask_2d
        )
        pred, _, _ = self.forward_decoder(emb_enc, ids_restore)  # [N, L, p*p*C]
        loss_recon = self.forward_loss(
            imgs, pred, mask, norm_pix_loss=self.norm_pix_loss
        )
        # Keep the placeholder on the same device as the real loss instead of
        # forcing CUDA, so the model also runs on CPU.
        loss_contrastive = torch.zeros(
            1, dtype=torch.float32, device=loss_recon.device
        )
        return loss_recon, pred, mask, loss_contrastive
549
+
550
+
551
def mae_vit_small_patch16_dec512d8b(**kwargs):
    """Build the small MAE (384-d, 12 blocks, 6 heads; 512-d decoder).

    Extra keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        decoder_embed_dim=512,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs,
    )
564
+
565
+
566
def mae_vit_base_patch16_dec512d8b(**kwargs):
    """Build the base MAE (768-d, 12 blocks, 12 heads; 512-d decoder).

    Extra keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_embed_dim=512,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs,
    )
579
+
580
+
581
def mae_vit_large_patch16_dec512d8b(**kwargs):
    """Build the large MAE (1024-d, 24 blocks, 16 heads; 512-d decoder).

    Extra keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs,
    )
594
+
595
+
596
def mae_vit_huge_patch14_dec512d8b(**kwargs):
    """Build the huge MAE (patch 14, 1280-d, 32 blocks, 16 heads; 512-d decoder).

    Extra keyword arguments are forwarded to ``MaskedAutoencoderViT``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return MaskedAutoencoderViT(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        decoder_embed_dim=512,
        decoder_num_heads=16,
        mlp_ratio=4,
        norm_layer=layer_norm,
        **kwargs,
    )
609
+
610
+
611
# Short aliases for the recommended architectures.
# Every variant uses the same decoder: 512 dim, 8 blocks.
mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b
audioldm_train/modules/audiomae/models_vit.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ # DeiT: https://github.com/facebookresearch/deit
10
+ # --------------------------------------------------------
11
+
12
+ from functools import partial
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import numpy as np
17
+ import timm.models.vision_transformer
18
+ from timm.models.vision_transformer import PatchEmbed, Block
19
+ from audioldm_train.modules.audiomae.util.patch_embed import (
20
+ PatchEmbed_new,
21
+ PatchEmbed3D_new,
22
+ )
23
+
24
+
25
class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """Vision Transformer with support for global average pooling.

    Extends timm's ``VisionTransformer`` with optional mean pooling over the
    patch tokens and with time/frequency masking of the patch sequence.
    """

    def __init__(
        self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs
    ):
        """Create the model.

        Args:
            global_pool: pool the patch tokens (class token excluded) and
                normalise with ``fc_norm`` instead of using the class token.
                Requires ``norm_layer`` and ``embed_dim`` in ``kwargs``.
            mask_2d: use 2-D (time/frequency) masking in
                ``forward_features_mask``; otherwise 1-D random masking.
            use_custom_patch: selects the 101x12 patch grid (instead of 64x8)
                for 2-D masking.
            **kwargs: forwarded to timm's ``VisionTransformer``.
        """
        super(VisionTransformer, self).__init__(**kwargs)

        self.global_pool = global_pool
        if self.global_pool:
            norm_layer = kwargs["norm_layer"]
            embed_dim = kwargs["embed_dim"]
            self.fc_norm = norm_layer(embed_dim)
            # The pooled path normalises with fc_norm, so the original norm is
            # removed. It is kept otherwise because the class-token path in
            # forward_features still calls self.norm.
            del self.norm
        self.mask_2d = mask_2d
        self.use_custom_patch = use_custom_patch

    def forward_features(self, x):
        """Return one feature vector per sample without any masking."""
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x + self.pos_embed[:, 1:, :]
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if self.global_pool:
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]

        return outcome

    def random_masking(self, x, mask_ratio):
        """Perform per-sample random masking by per-sample shuffling.

        Per-sample shuffling is done by argsort of random noise.

        x: [N, L, D], sequence
        Returns ``(x_masked, mask, ids_restore)``; ``mask`` is [N, L] with 0
        for kept and 1 for removed tokens.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample; ascend: small is keep, large is remove
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
        """Keep a random subset of time frames, then of frequency bands.

        The patch grid is fixed to 101x12 with custom patches and 64x8
        otherwise (AudioSet settings; the original comments list 50x12 / 32x8
        for ESC and 12x12 / 8x8 for SPC).

        x: [N, L, D], sequence with L == T * F
        Returns ``(x_masked, None, None)``; no mask or restore indices are
        produced by this variant.
        """
        N, L, D = x.shape  # batch, length, dim
        if self.use_custom_patch:
            T = 101
            F = 12
        else:
            T = 64
            F = 8

        # mask T
        x = x.reshape(N, T, F, D)
        len_keep_T = int(T * (1 - mask_t_prob))
        noise = torch.rand(N, T, device=x.device)  # noise in [0, 1]
        # sort noise for each sample; ascend: small is keep, large is remove
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_keep = ids_shuffle[:, :len_keep_T]
        index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D)
        x = torch.gather(x, dim=1, index=index)  # N, len_keep_T(T'), F, D

        # mask F
        x = x.permute(0, 2, 1, 3)  # N T' F D => N F T' D
        len_keep_F = int(F * (1 - mask_f_prob))
        noise = torch.rand(N, F, device=x.device)  # noise in [0, 1]
        # sort noise for each sample; ascend: small is keep, large is remove
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_keep = ids_shuffle[:, :len_keep_F]
        index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D)
        x_masked = torch.gather(x, dim=1, index=index)
        x_masked = x_masked.permute(0, 2, 1, 3)  # N F' T' D => N T' F' D
        x_masked = x_masked.reshape(N, len_keep_F * len_keep_T, D)

        return x_masked, None, None

    def forward_features_mask(self, x, mask_t_prob, mask_f_prob):
        """Return one feature vector per sample, masking patches first.

        With ``mask_2d`` the time and frequency ratios are applied
        separately; otherwise ``mask_t_prob`` is used as the ratio for 1-D
        random masking and ``mask_f_prob`` is ignored.
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        x = x + self.pos_embed[:, 1:, :]
        # Select the strategy from the configured flag. Testing the bound
        # method self.random_masking_2d is always truthy and would make the
        # 1-D branch unreachable.
        if self.mask_2d:
            x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob)
        else:
            x, mask, ids_restore = self.random_masking(x, mask_t_prob)
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.pos_drop(x)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)

        if self.global_pool:
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]

        return outcome

    # overwrite original timm
    def forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0):
        """Classify ``x``; masking is applied when either ratio is positive.

        ``v`` is accepted for interface compatibility and is not used.
        """
        if mask_t_prob > 0.0 or mask_f_prob > 0.0:
            x = self.forward_features_mask(
                x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob
            )
        else:
            x = self.forward_features(x)
        x = self.head(x)
        return x
197
+
198
+
199
def vit_small_patch16(**kwargs):
    """Build the small ViT (384-d, 12 blocks, 6 heads, patch 16).

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs
    )
211
+
212
+
213
def vit_base_patch16(**kwargs):
    """Build the base ViT (768-d, 12 blocks, 12 heads, patch 16).

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs
    )
225
+
226
+
227
def vit_large_patch16(**kwargs):
    """Build the large ViT (1024-d, 24 blocks, 16 heads, patch 16).

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs
    )
239
+
240
+
241
def vit_huge_patch14(**kwargs):
    """Build the huge ViT (1280-d, 32 blocks, 16 heads, patch 14).

    Extra keyword arguments are forwarded to ``VisionTransformer``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    return VisionTransformer(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        **kwargs
    )
audioldm_train/modules/audiomae/sequence_gen/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .sequence_input import Sequence2AudioMAE
2
+ from .model import CLAP2AudioMAE
audioldm_train/modules/audiomae/sequence_gen/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (264 Bytes). View file
 
audioldm_train/modules/audiomae/sequence_gen/__pycache__/model.cpython-310.pyc ADDED
Binary file (7.18 kB). View file
 
audioldm_train/modules/audiomae/sequence_gen/__pycache__/sequence_input.cpython-310.pyc ADDED
Binary file (13.6 kB). View file
 
audioldm_train/modules/audiomae/sequence_gen/model.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import pytorch_lightning as pl
4
+ from audioldm_train.utilities.model_util import (
5
+ exists,
6
+ default,
7
+ mean_flat,
8
+ count_params,
9
+ instantiate_from_config,
10
+ )
11
+
12
+ from transformers import GPT2Config, GPT2Model
13
+ import torch.optim.lr_scheduler as lr_scheduler
14
+
15
+
16
class Prenet(nn.Module):
    """Stack of ``Linear -> ReLU -> Dropout`` layers applied in sequence.

    Args:
        in_dim: Size of the last dimension of the input.
        sizes: Output width of each successive linear layer. Any sequence of
            ints is accepted (list or tuple); it is copied and never mutated.
        dropout_rate: Probability passed to ``nn.Dropout``, which is applied
            after every layer.
    """

    def __init__(self, in_dim, sizes=(256, 128), dropout_rate=0.5):
        super().__init__()
        # Copy into a list: avoids a shared mutable default argument and lets
        # callers pass a tuple (``[in_dim] + tuple`` would raise TypeError).
        sizes = list(sizes)
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [
                nn.Linear(in_size, out_size)
                for in_size, out_size in zip(in_sizes, sizes)
            ]
        )
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, inputs):
        """Run ``inputs`` through every layer and return the final activation."""
        for linear in self.layers:
            inputs = self.dropout(self.relu(linear(inputs)))
        return inputs
33
+
34
+
35
class CLAP2AudioMAE(pl.LightningModule):
    """GPT-2 sequence model that regresses the ``crossattn_audiomae_pooled``
    token sequence from the ``film_clap_cond1`` embedding.

    Training is teacher-forced: the projected ``film_clap_cond1`` embedding is
    prepended to the ground-truth target tokens and each GPT-2 output is
    compared (MSE) with the next ground-truth token. ``generate`` runs the
    same model autoregressively, feeding its own outputs back as inputs.
    """

    def __init__(
        self,
        sequence_gen_length,
        base_learning_rate,
        cond_stage_config,
        use_audiomae_linear=False,
        **kwargs
    ):
        """
        Args:
            sequence_gen_length: Number of tokens produced by ``generate``.
            base_learning_rate: Learning rate handed to AdamW.
            cond_stage_config: Mapping from conditioning key to a config that
                ``instantiate_from_config`` can build; each entry must also
                provide ``cond_stage_key`` and ``conditioning_key``.
            use_audiomae_linear: Must be falsy; the extra projection it would
                enable is not implemented.
            **kwargs: Accepted for config compatibility and ignored.
        """
        super().__init__()
        # The AudioMAE projection layer does not exist, so reject the flag early.
        assert use_audiomae_linear == False  # noqa: E712 (strict comparison kept on purpose)
        self.learning_rate = base_learning_rate
        self.cond_stage_config = cond_stage_config
        self.use_audiomae_linear = use_audiomae_linear

        self.mae_token_num = sequence_gen_length  # 4*4 pooling of the audiomae latent

        self.cond_stage_models = nn.ModuleList([])
        self.instantiate_cond_stage(cond_stage_config)

        self.model = GPT2Model.from_pretrained("gpt2")

        # Projects the 512-d conditioning embedding to GPT-2's 768-d hidden size.
        self.linear_clap = nn.Linear(512, 768)

        if use_audiomae_linear:
            # Unreachable while the assert above holds. TODO remove linear_audiomae
            self.linear_audiomae = None

        self.loss_fn = nn.MSELoss()

        self.logger_save_dir = None
        self.logger_exp_name = None
        self.logger_exp_group_name = None
        self.logger_version = None

    def set_log_dir(self, save_dir, exp_group_name, exp_name):
        """Store the logging directory and experiment names on the module."""
        self.logger_save_dir = save_dir
        self.logger_exp_group_name = exp_group_name
        self.logger_exp_name = exp_name

    def cfg_uncond(self, batch_size):
        """Return the unconditional conditioning of every cond-stage model.

        The ``crossattn_audiomae_pooled`` entry is additionally exposed under
        ``crossattn_clap_to_audiomae_feature``.
        """
        unconditional_conditioning = {}
        for key, metadata in self.cond_stage_model_metadata.items():
            cond_model = self.cond_stage_models[metadata["model_idx"]]
            unconditional_conditioning[key] = cond_model.get_unconditional_condition(
                batch_size
            )
        assert (
            "crossattn_audiomae_pooled" in unconditional_conditioning
        ), "The module is not initialized with AudioMAE"
        unconditional_conditioning[
            "crossattn_clap_to_audiomae_feature"
        ] = unconditional_conditioning["crossattn_audiomae_pooled"]
        return unconditional_conditioning

    def configure_optimizers(self):
        """Create AdamW over GPT-2 and the input projection, with StepLR decay."""
        lr = float(self.learning_rate)
        params = list(self.model.parameters()) + list(self.linear_clap.parameters())

        if self.use_audiomae_linear:
            params += list(self.linear_audiomae.parameters())

        opt = torch.optim.AdamW(params, lr=lr)
        scheduler = lr_scheduler.StepLR(opt, step_size=1, gamma=0.9)
        return [opt], [scheduler]

    def _teacher_forced_loss(self, cond_dict):
        """Compute the teacher-forced loss shared by training and validation.

        Returns:
            ``(loss, target_embeds)`` where ``target_embeds`` is the sequence
            the model output was compared against.
        """
        input_embeds = cond_dict["film_clap_cond1"]
        target_embeds = cond_dict["crossattn_audiomae_pooled"][0]

        # With a random pooling factor the default target length may vary, so a
        # dedicated fixed-pooling entry takes precedence when it is present.
        if "crossattn_audiomae_pooled_44" in cond_dict:
            target_embeds = cond_dict["crossattn_audiomae_pooled_44"][0]

        prefix = self.linear_clap(input_embeds)
        if self.use_audiomae_linear:
            known_tokens = self.linear_audiomae(target_embeds)
        else:
            known_tokens = target_embeds
        model_input = torch.cat([prefix, known_tokens], dim=1)

        output_embeds = self.model(inputs_embeds=model_input)["last_hidden_state"]

        # Output at position i predicts input i + 1, so the last output has no
        # target and is dropped.
        loss = self.loss_fn(output_embeds[:, :-1], target_embeds)
        return loss, target_embeds

    def training_step(self, batch, batch_idx=None, cond_dict=None):
        """Log and return the teacher-forced loss for one batch."""
        if cond_dict is None:
            cond_dict = self.get_input(batch)

        loss, _ = self._teacher_forced_loss(cond_dict)

        step_log_options = dict(
            prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True
        )
        self.log("train/loss_clap_2_audiomae", loss, **step_log_options)
        self.log("global_step_audiomae", float(self.global_step), **step_log_options)

        return loss

    def _autoregressive_rollout(self, input_embeds):
        """Append ``mae_token_num`` predicted tokens to the projected prefix."""
        model_input = self.linear_clap(input_embeds)
        for _ in range(self.mae_token_num):
            output = self.model(inputs_embeds=model_input)["last_hidden_state"]
            model_input = torch.cat([model_input, output[:, -1:, :]], dim=1)
        return model_input

    def generate(self, batch, cond_dict=None, no_grad=False):
        """Autoregressively generate ``mae_token_num`` tokens.

        Args:
            batch: Raw batch; only used when ``cond_dict`` is ``None``.
            cond_dict: Pre-computed conditioning, to skip ``get_input``.
            no_grad: Run the rollout under ``torch.no_grad()`` when true;
                otherwise the ambient gradient mode is left untouched.

        Returns:
            ``(generated, cond_dict)``. The first time step is dropped from the
            rollout, which assumes the conditioning prefix is one step long
            (TODO confirm against the cond-stage model).
        """
        if cond_dict is None:
            cond_dict = self.get_input(batch)
        input_embeds = cond_dict["film_clap_cond1"]

        if no_grad:
            with torch.no_grad():
                sequence = self._autoregressive_rollout(input_embeds)
        else:
            sequence = self._autoregressive_rollout(input_embeds)

        return sequence[:, 1:], cond_dict

    def validation_step(self, batch, batch_idx):
        """Log the teacher-forced loss and the free-running generation loss."""
        cond_dict = self.get_input(batch)

        loss, target = self._teacher_forced_loss(cond_dict)

        log_options = dict(
            prog_bar=True, logger=True, on_step=True, sync_dist=True, on_epoch=True
        )
        self.log("val/loss", loss, **log_options)

        # Reuse the conditioning computed above: this avoids running every
        # cond-stage model a second time and keeps the generation conditioned
        # on the same inputs the target was derived from.
        generation_output, _ = self.generate(batch, cond_dict=cond_dict)
        ar_gen_loss = self.loss_fn(generation_output, target)

        self.log("val/ar_gen_loss", ar_gen_loss, **log_options)

        return {"loss": loss, "ar_gen_loss": ar_gen_loss}

    def get_input_item(self, batch, k):
        """Return entry ``k`` of the normalised batch.

        ``batch`` must contain ``fname``, ``text``, ``waveform``, ``stft`` and
        ``log_mel_spec``. ``log_mel_spec`` is exposed as ``fbank`` with an
        extra dimension inserted at index 1; every other batch key is passed
        through unchanged.
        """
        ret = {
            "fbank": batch["log_mel_spec"]
            .unsqueeze(1)
            .to(memory_format=torch.contiguous_format)
            .float(),
            "stft": batch["stft"].to(memory_format=torch.contiguous_format).float(),
            "waveform": batch["waveform"]
            .to(memory_format=torch.contiguous_format)
            .float(),
            "text": list(batch["text"]),
            "fname": batch["fname"],
        }

        for key in batch:
            if key not in ret:
                ret[key] = batch[key]

        return ret[k]

    def get_input(self, batch):
        """Run every cond-stage model on its input and collect the outputs.

        Returns:
            Dict mapping each conditioning key to that model's output.
        """
        cond_dict = {}
        for cond_model_key, metadata in self.cond_stage_model_metadata.items():
            # The original data for conditioning
            xc = self.get_input_item(batch, metadata["cond_stage_key"])
            if isinstance(xc, torch.Tensor):
                xc = xc.to(self.device)

            cond_dict[cond_model_key] = self.get_learned_conditioning(
                xc, key=cond_model_key, unconditional_cfg=False
            )

        return cond_dict

    def instantiate_cond_stage(self, config):
        """Build every conditioning model in ``config`` and record its metadata."""
        self.cond_stage_model_metadata = {}

        for i, cond_model_key in enumerate(config.keys()):
            model = instantiate_from_config(config[cond_model_key])
            self.cond_stage_models.append(model)
            self.cond_stage_model_metadata[cond_model_key] = {
                "model_idx": i,
                "cond_stage_key": config[cond_model_key]["cond_stage_key"],
                "conditioning_key": config[cond_model_key]["conditioning_key"],
            }

    def get_learned_conditioning(self, c, key, unconditional_cfg):
        """Encode ``c`` with the model registered under ``key``.

        When ``unconditional_cfg`` is true the model's unconditional condition
        is returned instead, sized to the batch dimension of ``c``.

        Raises:
            NotImplementedError: If the batch size of ``c`` cannot be inferred
                (``c`` is neither a tensor nor a list).
        """
        assert key in self.cond_stage_model_metadata.keys()
        cond_model = self.cond_stage_models[
            self.cond_stage_model_metadata[key]["model_idx"]
        ]

        # Classifier-free guidance
        if not unconditional_cfg:
            return cond_model(c)

        if isinstance(c, torch.Tensor):
            batchsize = c.size(0)
        elif isinstance(c, list):
            batchsize = len(c)
        else:
            raise NotImplementedError()
        return cond_model.get_unconditional_condition(batchsize)
audioldm_train/modules/audiomae/sequence_gen/sequence_input.py ADDED
@@ -0,0 +1,737 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import pytorch_lightning as pl
5
+ from audioldm_train.utilities.model_util import (
6
+ exists,
7
+ default,
8
+ mean_flat,
9
+ count_params,
10
+ instantiate_from_config,
11
+ )
12
+ from torch.optim import *
13
+
14
+ from transformers import GPT2Config, GPT2Model, GPTJConfig, GPTJModel
15
+ import torch.optim.lr_scheduler as lr_scheduler
16
+
17
+
18
+ class Sequence2AudioMAE(pl.LightningModule):
19
def __init__(
    self,
    base_learning_rate,
    sequence_gen_length,
    sequence_input_key,
    sequence_input_embed_dim,
    cond_stage_config,
    optimizer_type="AdamW",
    use_warmup=True,
    use_ar_gen_loss=False,
    use_audiomae_linear=False,
    target_tokens_mask_ratio=0.0,
    random_mask_ratio=False,
    **kwargs
):
    """Set up the conditioning models, input projections and GPT-2 backbone.

    Args:
        base_learning_rate: Learning rate used by ``configure_optimizers``
            and as the warm-up target in ``warmup_step``.
        sequence_gen_length: Number of tokens produced by ``generate``;
            ``training_step`` asserts the target has this many time steps.
        sequence_input_key: Ordered keys of the conditioning dict that are
            concatenated to form the model's input prefix.
        sequence_input_embed_dim: Embedding width of each input key, in the
            same order; one ``nn.Linear(dim, 768)`` is created per entry.
        cond_stage_config: Passed to ``instantiate_cond_stage``.
        optimizer_type: Optimizer name resolved in ``configure_optimizers``.
        use_warmup: Whether ``training_step`` calls ``warmup_step``.
        use_ar_gen_loss: Whether ``training_step`` adds the loss from
            ``calculate_ahead_k_step_loss``.
        use_audiomae_linear: Must compare equal to ``False``.
        target_tokens_mask_ratio: Masking ratio read by
            ``mask_target_sequence``.
        random_mask_ratio: When true, ``mask_target_sequence`` scales the
            ratio by a random factor on each call.
        **kwargs: Accepted and ignored.
    """

    super().__init__()
    assert use_audiomae_linear == False
    self.random_mask_ratio = random_mask_ratio
    self.learning_rate = base_learning_rate
    self.cond_stage_config = cond_stage_config
    self.use_audiomae_linear = use_audiomae_linear
    self.optimizer_type = optimizer_type
    self.use_warmup = use_warmup
    self.use_ar_gen_loss = use_ar_gen_loss
    # Even though the LDM can be conditioned on multiple pooling rates,
    # our model always predicts the highest pooling rate.

    self.mae_token_num = sequence_gen_length
    self.sequence_input_key = sequence_input_key
    self.sequence_input_embed_dim = sequence_input_embed_dim
    self.target_tokens_mask_ratio = target_tokens_mask_ratio

    # One learned start/end marker per input stream; add_sos_eos_tokens looks
    # them up by the stream's position in sequence_input_key, so the tables
    # hold 32 entries each.
    self.start_of_sequence_tokens = nn.Embedding(32, 768)
    self.end_of_sequence_tokens = nn.Embedding(32, 768)

    self.input_sequence_embed_linear = nn.ModuleList([])
    # Filled in lazily by warmup_step on its first call.
    self.initial_learning_rate = None

    # Project every input stream to the 768-d width fed to the GPT-2 model.
    for dim in self.sequence_input_embed_dim:
        self.input_sequence_embed_linear.append(nn.Linear(dim, 768))

    self.cond_stage_models = nn.ModuleList([])
    self.instantiate_cond_stage(cond_stage_config)
    # NOTE(review): initialize_param_check_toolkit is defined outside this view.
    self.initialize_param_check_toolkit()

    # Counts training_step calls; used there to throttle console prints.
    self.private_training_step = 0

    # Alternative backbones that were tried and left disabled:
    # configuration = GPT2Config(n_layer=1)  # TODO
    # self.model=GPT2Model(configuration)
    ###################
    # self.model=nn.Linear(768,768, bias=False)  # TODO change the model
    # with torch.no_grad():
    #     self.model.weight.copy_(torch.eye(768))
    ###################
    self.model = GPT2Model.from_pretrained("gpt2")
    ###################
    # self.model = nn.LSTM(input_size=768, hidden_size=768, num_layers=1,bias=False)  # TODO

    # self.loss_fn = nn.MSELoss()
    self.loss_fn = nn.L1Loss()

    self.logger_save_dir = None
    self.logger_exp_name = None
    self.logger_exp_group_name = None
    self.logger_version = None
85
+
86
def set_log_dir(self, save_dir, exp_group_name, exp_name):
    """Record where logs go and which experiment this run belongs to.

    Args:
        save_dir: Stored as ``logger_save_dir``.
        exp_group_name: Stored as ``logger_exp_group_name``.
        exp_name: Stored as ``logger_exp_name``.
    """
    self.logger_exp_name = exp_name
    self.logger_exp_group_name = exp_group_name
    self.logger_save_dir = save_dir
90
+
91
def cfg_uncond(self, batch_size):
    """Collect the unconditional conditioning of every cond-stage model.

    Args:
        batch_size: Forwarded to each model's ``get_unconditional_condition``.

    Returns:
        Dict keyed like ``cond_stage_model_metadata``, plus
        ``crossattn_clap_to_audiomae_feature`` which refers to the same object
        as the ``crossattn_audiomae_pooled`` entry.

    Raises:
        AssertionError: If no model is registered under
            ``crossattn_audiomae_pooled``.
    """
    unconditional_conditioning = {
        cond_key: self.cond_stage_models[
            metadata["model_idx"]
        ].get_unconditional_condition(batch_size)
        for cond_key, metadata in self.cond_stage_model_metadata.items()
    }
    assert (
        "crossattn_audiomae_pooled" in unconditional_conditioning
    ), "The module is not initialized with AudioMAE"
    audiomae_uncond = unconditional_conditioning["crossattn_audiomae_pooled"]
    unconditional_conditioning["crossattn_clap_to_audiomae_feature"] = audiomae_uncond
    return unconditional_conditioning
105
+
106
def configure_optimizers(self):
    """Build the optimizer named by ``optimizer_type`` and its LR scheduler.

    ``optimizer_type`` is looked up as a class in ``torch.optim`` instead of
    being passed to ``eval``, so a config value can no longer execute
    arbitrary code. Both ``"AdamW"`` and a dotted spelling such as
    ``"torch.optim.AdamW"`` are accepted.

    Returns:
        ``([optimizer], [scheduler])``, where the scheduler is a ``StepLR``
        with ``step_size=10`` and ``gamma=0.8``.

    Raises:
        ValueError: If ``optimizer_type`` does not name a ``torch.optim``
            optimizer class.
    """
    lr = float(self.learning_rate)
    params = list(self.parameters())

    # Keep only the last dotted component so fully qualified names still work.
    optimizer_name = str(self.optimizer_type).rsplit(".", 1)[-1]
    optimizer_cls = getattr(torch.optim, optimizer_name, None)
    if not (
        isinstance(optimizer_cls, type)
        and issubclass(optimizer_cls, torch.optim.Optimizer)
    ):
        raise ValueError(
            "Unknown optimizer_type %r: expected the name of a torch.optim "
            "optimizer class" % (self.optimizer_type,)
        )

    opt = optimizer_cls(params, lr=lr)
    scheduler = lr_scheduler.StepLR(opt, step_size=10, gamma=0.8)
    return [opt], [scheduler]
115
+
116
def add_sos_eos_tokens(self, _id, sequence, attn_mask):
    """Wrap ``sequence`` in a learned start token and a learned end token.

    Args:
        _id: Row of the start/end embedding tables to use.
        sequence: Embeddings indexed as ``(batch, time, dim)``.
        attn_mask: Mask indexed as ``(batch, time)``.

    Returns:
        ``(new_sequence, new_attn_mask)``, each two time steps longer than
        the input; the two added mask positions are set to one.
    """
    batch = sequence.size(0)
    device = sequence.device

    token_index = torch.tensor([_id]).to(device)
    edge_mask = torch.ones((batch, 1)).to(device)

    # Same embedding row for every item in the batch.
    sos = self.start_of_sequence_tokens(token_index).expand(batch, 1, -1)
    eos = self.end_of_sequence_tokens(token_index).expand(batch, 1, -1)

    extended_sequence = torch.cat([sos, sequence, eos], dim=1)
    extended_mask = torch.cat([edge_mask, attn_mask, edge_mask], dim=1)
    return extended_sequence, extended_mask
132
+
133
def truncate_sequence_and_mask(self, sequence, mask, max_len=512):
    """Clip ``sequence`` and ``mask`` to at most ``max_len`` time steps.

    Inputs that already fit are returned untouched. When clipping happens the
    original length is printed first.

    Returns:
        ``(sequence, mask)``, sliced along dimension 1 if needed.
    """
    if not sequence.size(1) > max_len:
        return sequence, mask

    print(
        "The input sequence length to GPT-2 model is too long:",
        sequence.size(1),
    )
    clipped_sequence = sequence[:, :max_len]
    clipped_mask = mask[:, :max_len]
    return clipped_sequence, clipped_mask
142
+
143
def get_input_sequence_and_mask(self, cond_dict):
    """Assemble the conditioning prefix that is fed to the sequence model.

    Each key in ``sequence_input_key`` is read from ``cond_dict``, projected
    by its own linear layer, wrapped in start/end tokens and concatenated
    along the time dimension. An entry is either a bare tensor (mask of ones
    is created for it) or a two-element list ``[embeds, attn_mask]``.

    Returns:
        ``(input_embeds, input_embeds_attn_mask, cond_sequence_end_time_idx)``
        where the last item is the prefix length after truncation to
        ``1024 - mae_token_num`` steps, i.e. the index at which generated
        tokens start.
    """
    embed_parts = []
    mask_parts = []

    for _id, sequence_key in enumerate(self.sequence_input_key):
        assert sequence_key in cond_dict.keys(), (
            "Invalid sequence key %s" % sequence_key
        )
        cond_embed = cond_dict[sequence_key]

        if isinstance(cond_embed, list):
            assert (
                len(cond_embed) == 2
            ), "The crossattn returned list should have length 2, including embed and attn_mask"
            raw_embeds, stream_mask = cond_embed
            projected = self.input_sequence_embed_linear[_id](raw_embeds)
        else:
            assert isinstance(cond_embed, torch.Tensor)
            projected = self.input_sequence_embed_linear[_id](cond_embed)
            # A bare tensor carries no mask, so every step is attended to.
            stream_mask = torch.ones((projected.size(0), projected.size(1))).to(
                projected.device
            )

        wrapped_embeds, wrapped_mask = self.add_sos_eos_tokens(
            _id, projected, stream_mask
        )
        embed_parts.append(wrapped_embeds)
        mask_parts.append(wrapped_mask)

    assert embed_parts and mask_parts

    # Dimension 1 is the time axis.
    if len(embed_parts) == 1:
        input_embeds, input_embeds_attn_mask = embed_parts[0], mask_parts[0]
    else:
        input_embeds = torch.cat(embed_parts, dim=1)
        input_embeds_attn_mask = torch.cat(mask_parts, dim=1)

    input_embeds, input_embeds_attn_mask = self.truncate_sequence_and_mask(
        input_embeds, input_embeds_attn_mask, int(1024 - self.mae_token_num)
    )

    # The index that we start to collect the output embeds
    cond_sequence_end_time_idx = input_embeds.size(1)

    return input_embeds, input_embeds_attn_mask, cond_sequence_end_time_idx
208
+
209
def warmup_step(self):
    """Linearly ramp the first param group's learning rate over 1000 steps.

    On the first call the configured learning rate is cached as the ramp
    target. While ``global_step <= 1000`` the rate is set to
    ``global_step / 1000`` of the target; afterwards it is pinned to the
    target itself. Only the first optimizer's first param group is touched.
    """
    if self.initial_learning_rate is None:
        self.initial_learning_rate = float(self.learning_rate)

    if self.global_step <= 1000:
        if self.global_step == 0:
            print(
                "Warming up learning rate start with %s"
                % self.initial_learning_rate
            )
        new_lr = (self.global_step / 1000) * self.initial_learning_rate
    else:
        # TODO set learning rate here
        new_lr = self.initial_learning_rate

    # Only the first parameter group
    self.trainer.optimizers[0].param_groups[0]["lr"] = new_lr
228
+
229
def mask_target_sequence(self, target_embeds, target_embeds_attn_mask):
    """Randomly zero out time steps of the target sequence and its mask.

    Nothing happens unless ``target_tokens_mask_ratio`` exceeds ``1e-4``.
    Otherwise each ``(batch, time)`` position is kept when a uniform random
    draw is greater than the mask ratio; with ``random_mask_ratio`` set, the
    ratio itself is first scaled by another uniform draw.

    Returns:
        ``(target_embeds, target_embeds_attn_mask, time_seq_mask)``. The last
        item is the boolean keep-mask, or ``None`` when masking is disabled.
    """
    if not self.target_tokens_mask_ratio > 1e-4:
        return target_embeds, target_embeds_attn_mask, None

    batchsize, _, _ = target_embeds.size()
    _, time_seq_len = target_embeds_attn_mask.size()

    mask_ratio = self.target_tokens_mask_ratio
    if self.random_mask_ratio:
        mask_ratio = torch.rand(1).item() * self.target_tokens_mask_ratio

    # Drawn on the default (CPU) generator, then moved next to the targets.
    keep_mask = (torch.rand((batchsize, time_seq_len)) > mask_ratio).to(
        target_embeds.device
    )

    masked_embeds = target_embeds * keep_mask.unsqueeze(-1)
    masked_attn_mask = target_embeds_attn_mask * keep_mask
    return masked_embeds, masked_attn_mask, keep_mask
247
+
248
def training_step(self, batch, batch_idx=None, cond_dict=None, return_output=False):
    """Run one teacher-forced training step and log its metrics.

    The conditioning prefix and the ground-truth target tokens are
    concatenated, passed through the model, and each output is compared with
    the next ground-truth token.

    Args:
        batch: Raw batch; only used when ``cond_dict`` is ``None``.
        batch_idx: Forwarded to ``calculate_ahead_k_step_loss``.
        cond_dict: Pre-computed conditioning, to skip ``get_input``.
        return_output: Also return the model outputs aligned with the target.

    Returns:
        ``loss + ar_gen_loss``, or ``(loss + ar_gen_loss, output)`` when
        ``return_output`` is true. With ``use_ar_gen_loss`` disabled
        ``ar_gen_loss`` is the same tensor as ``loss``, so the returned value
        is twice the teacher-forced loss.
    """
    # cond_dict['film_clap_cond1']: [2,1,512]
    # cond_dict['crossattn_audiomae_pooled']: [2, 128, 768]

    if self.use_warmup:
        self.warmup_step()

    if cond_dict is None:
        cond_dict = self.get_input(batch)

    target_embeds, target_embeds_attn_mask = (
        cond_dict["crossattn_audiomae_pooled"][0],
        cond_dict["crossattn_audiomae_pooled"][1],
    )

    (
        input_embeds,
        input_embeds_attn_mask,
        cond_sequence_end_time_idx,
    ) = self.get_input_sequence_and_mask(cond_dict)

    # With a random pooling factor the length of crossattn_audiomae_pooled is
    # not necessarily fixed, so a dedicated entry takes precedence if present.
    # NOTE(review): the attention mask is still the one from
    # crossattn_audiomae_pooled; confirm the two always have equal length.
    if "crossattn_audiomae_pooled_44" in cond_dict:
        target_embeds = cond_dict["crossattn_audiomae_pooled_44"][0]

    final_input_embeds = torch.cat([input_embeds, target_embeds], dim=1)
    final_input_embeds_attn_mask = torch.cat(
        [input_embeds_attn_mask, target_embeds_attn_mask], dim=1
    )

    output_embeds = self.model(
        inputs_embeds=final_input_embeds,
        attention_mask=final_input_embeds_attn_mask,
    )["last_hidden_state"]

    target = target_embeds
    # The output at position i predicts input i + 1: start one step before the
    # first target token and drop the final output, which has no target.
    output = output_embeds[:, cond_sequence_end_time_idx - 1 : -1]

    assert target.size(1) == self.mae_token_num

    loss = self.loss_fn(output, target)

    if self.use_ar_gen_loss:
        ar_gen_loss = self.calculate_ahead_k_step_loss(batch, batch_idx, cond_dict)
    else:
        ar_gen_loss = loss

    if self.private_training_step % 500 == 0:
        print(
            "AudioMAE prediction module:", "loss", loss, "ar_gen_loss", ar_gen_loss
        )

    step_log_options = dict(
        prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True
    )

    # Learning-rate logging is best effort (e.g. no trainer or optimizer is
    # attached yet). Catch Exception rather than using a bare except so that
    # KeyboardInterrupt / SystemExit are never swallowed.
    try:
        learning_rate = self.trainer.optimizers[0].param_groups[0]["lr"]
        self.log("train/lr_audiomae_pred", learning_rate, **step_log_options)
    except Exception:
        pass

    self.log("train/loss_clap_2_audiomae", loss, **step_log_options)
    self.log("train/loss_ar_gen_loss", ar_gen_loss, **step_log_options)
    self.log("global_step_audiomae", float(self.global_step), **step_log_options)

    self.private_training_step += 1
    if return_output:
        return loss + ar_gen_loss, output
    else:
        return loss + ar_gen_loss
361
+
362
def calculate_ahead_k_step_loss(self, batch, batch_idx=None, cond_dict=None):
    """Loss of a short free-running rollout started at a random target offset.

    A rollout length between 2 and 8 steps and a start offset are drawn at
    random. The model is primed with the conditioning prefix plus the
    ground-truth tokens before the offset, then generates the chosen number
    of steps from its own outputs. Those are compared with the matching
    ground-truth slice.

    Args:
        batch: Raw batch; only used when ``cond_dict`` is ``None``.
        batch_idx: Unused.
        cond_dict: Pre-computed conditioning, to skip ``get_input``.

    Returns:
        ``loss_fn(generated, ground_truth_slice)``.

    Raises:
        AssertionError: If the target attention mask has any entry below 0.1.
    """
    if cond_dict is None:
        cond_dict = self.get_input(batch)

    target_embeds = cond_dict["crossattn_audiomae_pooled"][0]
    target_embeds_attn_mask = cond_dict["crossattn_audiomae_pooled"][1]

    assert (
        torch.sum(target_embeds_attn_mask < 0.1) < 1
    ), "This function only works for AudioMAE prediction, which should have all one atten_mask"

    input_embeds, input_embeds_attn_mask, _ = self.get_input_sequence_and_mask(
        cond_dict
    )

    target_total_time_steps = target_embeds.size(1)

    # First random draw: rollout length, at least 2 steps.
    steps = max(2, min(round(torch.rand(1).item() * 8), target_total_time_steps))
    # Second random draw: where in the target the rollout begins.
    start_idx = max(
        0, round(torch.rand(1).item() * (target_total_time_steps - steps)) - 1
    )

    rollout_input = input_embeds
    rollout_mask = input_embeds_attn_mask
    if start_idx > 0:
        # Prime the model with the ground-truth tokens before the offset.
        rollout_input = torch.cat(
            [input_embeds, target_embeds[:, :start_idx, :]], dim=1
        )
        known_steps_mask = torch.ones((rollout_mask.size(0), start_idx)).to(
            rollout_input.device
        )
        rollout_mask = torch.cat([input_embeds_attn_mask, known_steps_mask], dim=1)

    predicted_steps = []
    for _ in range(steps):
        hidden = self.model(
            inputs_embeds=rollout_input, attention_mask=rollout_mask
        )["last_hidden_state"]
        newest = hidden[:, -1:, :]
        predicted_steps.append(newest)
        # Feed the prediction back in and extend the mask by one step.
        rollout_input = torch.cat([rollout_input, newest], dim=1)
        new_step_mask = torch.ones((rollout_mask.size(0), 1)).to(
            rollout_input.device
        )
        rollout_mask = torch.cat([rollout_mask, new_step_mask], dim=1)

    generation = torch.cat(predicted_steps, dim=1)
    expected = target_embeds[:, start_idx : start_idx + steps, :]
    return self.loss_fn(generation, expected)
426
+
427
def generate_partial(self, batch, cond_dict=None, no_grad=False):
    """Continue a sequence from the first quarter of the ground-truth target.

    The conditioning prefix is extended with the first
    ``target_time_steps // 4`` ground-truth tokens, after which the model
    generates ``3 * mae_token_num // 4`` more tokens from its own outputs.

    Args:
        batch: Raw batch; only used when ``cond_dict`` is ``None``.
        cond_dict: Pre-computed conditioning, to skip ``get_input``.
        no_grad: Accepted for signature parity with ``generate``; unused.

    Returns:
        ``(output, cond_dict)`` where ``output`` is everything after the
        conditioning prefix (prompt tokens followed by generated tokens).
    """
    if cond_dict is None:
        cond_dict = self.get_input(batch)

    print("Generate partially prompted audio with in-context learning")

    target_embeds = cond_dict["crossattn_audiomae_pooled"][0]
    target_embeds_attn_mask = cond_dict["crossattn_audiomae_pooled"][1]

    target_time_steps = target_embeds.size(1)

    input_embeds, input_embeds_attn_mask, prefix_len = (
        self.get_input_sequence_and_mask(cond_dict)
    )

    prompt_len = target_time_steps // 4
    sequence = torch.cat([input_embeds, target_embeds[:, :prompt_len, :]], dim=1)
    sequence_mask = torch.cat(
        [input_embeds_attn_mask, target_embeds_attn_mask[:, :prompt_len]],
        dim=1,
    )

    steps = self.mae_token_num

    for _ in range(3 * steps // 4):
        hidden = self.model(inputs_embeds=sequence, attention_mask=sequence_mask)[
            "last_hidden_state"
        ]
        # Append the newest prediction and extend the mask by one step.
        sequence = torch.cat([sequence, hidden[:, -1:, :]], dim=1)
        new_step_mask = torch.ones((sequence_mask.size(0), 1)).to(sequence.device)
        sequence_mask = torch.cat([sequence_mask, new_step_mask], dim=1)

    return sequence[:, prefix_len:], cond_dict
478
+
479
def generate(self, batch, cond_dict=None, no_grad=False):
    """Autoregressively generate ``mae_token_num`` tokens after the prefix.

    Args:
        batch: Raw batch; only used when ``cond_dict`` is ``None``.
        cond_dict: Pre-computed conditioning, to skip ``get_input``.
        no_grad: Accepted for interface compatibility; unused here.

    Returns:
        ``(generated, cond_dict)`` where ``generated`` holds only the newly
        produced time steps (the conditioning prefix is sliced off).
    """
    if cond_dict is None:
        cond_dict = self.get_input(batch)

    sequence, sequence_mask, prefix_len = self.get_input_sequence_and_mask(
        cond_dict
    )

    for _ in range(self.mae_token_num):
        hidden = self.model(inputs_embeds=sequence, attention_mask=sequence_mask)[
            "last_hidden_state"
        ]
        # The last hidden state becomes the next input token.
        sequence = torch.cat([sequence, hidden[:, -1:, :]], dim=1)
        # Keep the mask the same length as the sequence.
        new_step_mask = torch.ones((sequence_mask.size(0), 1)).to(sequence.device)
        sequence_mask = torch.cat([sequence_mask, new_step_mask], dim=1)

    return sequence[:, prefix_len:], cond_dict
511
+
512
+ # def on_validation_epoch_start(self) -> None:
513
+ # # Use text as condition during validation
514
+ # for key in self.cond_stage_model_metadata.keys():
515
+ # metadata = self.cond_stage_model_metadata[key]
516
+ # model_idx, cond_stage_key, conditioning_key = metadata["model_idx"], metadata["cond_stage_key"], metadata["conditioning_key"]
517
+
518
+ # # If we use CLAP as condition, we might use audio for training, but we also must use text for evaluation
519
+ # # if(isinstance(self.cond_stage_models[model_idx], CLAPAudioEmbeddingClassifierFreev2)):
520
+ # # self.cond_stage_model_metadata[key]["cond_stage_key_orig"] = self.cond_stage_model_metadata[key]["cond_stage_key"]
521
+ # # self.cond_stage_model_metadata[key]["embed_mode_orig"] = self.cond_stage_models[model_idx].embed_mode
522
+ # # print("Change the model original cond_keyand embed_mode %s, %s to text during evaluation" % (self.cond_stage_model_metadata[key]["cond_stage_key_orig"], self.cond_stage_model_metadata[key]["embed_mode_orig"]))
523
+ # # self.cond_stage_model_metadata[key]["cond_stage_key"] = "text"
524
+ # # self.cond_stage_models[model_idx].embed_mode = "text"
525
+
526
+ # return super().on_validation_epoch_start()
527
+
528
def validation_step(self, batch, batch_idx):
    """Log the teacher-forced loss and the free-running generation loss.

    The conditioning is computed once and reused for the autoregressive
    rollout. This avoids running every cond-stage model a second time and
    keeps the rollout conditioned on the same inputs the target came from.

    Returns:
        Dict with ``loss`` (teacher-forced) and ``ar_gen_loss`` (output of
        ``generate`` compared with the target).
    """
    cond_dict = self.get_input(batch)
    # cond_dict['film_clap_cond1']: [2,1,512]
    # cond_dict['crossattn_audiomae_pooled']: [2, 128, 768]

    target_embeds, target_embeds_attn_mask = (
        cond_dict["crossattn_audiomae_pooled"][0],
        cond_dict["crossattn_audiomae_pooled"][1],
    )

    (
        input_embeds,
        input_embeds_attn_mask,
        cond_sequence_end_time_idx,
    ) = self.get_input_sequence_and_mask(cond_dict)

    # With a random pooling factor the length of crossattn_audiomae_pooled is
    # not necessarily fixed, so a dedicated entry takes precedence if present.
    if "crossattn_audiomae_pooled_44" in cond_dict:
        target_embeds = cond_dict["crossattn_audiomae_pooled_44"][0]

    final_input_embeds = torch.cat([input_embeds, target_embeds], dim=1)
    final_input_embeds_attn_mask = torch.cat(
        [input_embeds_attn_mask, target_embeds_attn_mask], dim=1
    )

    output_embeds = self.model(
        inputs_embeds=final_input_embeds,
        attention_mask=final_input_embeds_attn_mask,
    )["last_hidden_state"]

    target = target_embeds
    # The output at position i predicts input i + 1: start one step before the
    # first target token and drop the final output, which has no target.
    output = output_embeds[:, cond_sequence_end_time_idx - 1 : -1]

    loss = self.loss_fn(output, target)

    log_options = dict(
        prog_bar=True, logger=True, on_step=True, sync_dist=True, on_epoch=True
    )
    self.log("val/loss", loss, **log_options)

    generation_output, _ = self.generate(batch, cond_dict=cond_dict)
    ar_gen_loss = self.loss_fn(generation_output, target)

    self.log("val/ar_gen_loss", ar_gen_loss, **log_options)

    return {"loss": loss, "ar_gen_loss": ar_gen_loss}
587
+
588
def get_input_item(self, batch, k):
    """Return entry ``k`` of ``batch`` after normalising the known fields.

    The keys "fname", "text", "label_vector", "waveform", "stft" and
    "log_mel_spec" must all be present in ``batch`` (a missing one raises
    ``KeyError``). Any other key of ``batch`` is passed through unchanged.

    Args:
        batch: Mapping produced by the data loader.
        k: Name of the item to return; "fbank" refers to the converted
            "log_mel_spec" entry.

    Returns:
        The requested item.
    """
    required = ("fname", "text", "label_vector", "waveform", "stft", "log_mel_spec")
    # The label vector is read (so its absence is detected) but not converted.
    fname, text, _, waveform, stft, fbank = (batch[name] for name in required)

    def _as_float(tensor):
        return tensor.to(memory_format=torch.contiguous_format).float()

    items = {
        # The mel spectrogram gains an extra axis at position 1.
        "fbank": _as_float(fbank.unsqueeze(1)),
        "stft": _as_float(stft),
        "waveform": _as_float(waveform),
        "text": list(text),
        "fname": fname,
    }

    # Everything not handled above is forwarded as-is.
    for key, value in batch.items():
        items.setdefault(key, value)

    return items[k]
613
+
614
def get_input(self, batch):
    """Build the conditioning dict for ``batch``.

    For every registered conditioning model the raw item named by its
    "cond_stage_key" is fetched from the batch, moved to ``self.device`` when
    it is a plain tensor, and encoded with ``get_learned_conditioning``.

    Args:
        batch: Raw batch from the data loader.

    Returns:
        Dict mapping each conditioning-model key to its encoded condition;
        empty when no conditioning model is registered.
    """
    cond_dict = {}

    for cond_model_key, metadata in self.cond_stage_model_metadata.items():
        # The original data used for conditioning.
        raw_condition = self.get_input_item(batch, metadata["cond_stage_key"])
        if type(raw_condition) is torch.Tensor:
            raw_condition = raw_condition.to(self.device)

        # Conditions are always computed here (never the unconditional one).
        cond_dict[cond_model_key] = self.get_learned_conditioning(
            raw_condition, key=cond_model_key, unconditional_cfg=False
        )

    return cond_dict
639
+
640
def instantiate_cond_stage(self, config):
    """Instantiate every conditioning model described by ``config``.

    Each model is appended to ``self.cond_stage_models`` and a metadata record
    (its index in that list plus its "cond_stage_key" and "conditioning_key")
    is stored in ``self.cond_stage_model_metadata`` under the config key.

    Args:
        config: Mapping from conditioning-model key to its configuration.
    """
    self.cond_stage_model_metadata = {}

    for model_idx, cond_model_key in enumerate(config.keys()):
        model_config = config[cond_model_key]
        self.cond_stage_models.append(instantiate_from_config(model_config))
        self.cond_stage_model_metadata[cond_model_key] = dict(
            model_idx=model_idx,
            cond_stage_key=model_config["cond_stage_key"],
            conditioning_key=model_config["conditioning_key"],
        )
651
+
652
def get_learned_conditioning(self, c, key, unconditional_cfg):
    """Encode ``c`` with the conditioning model registered under ``key``.

    Args:
        c: Raw conditioning input (a tensor or a list).
        key: Name of the conditioning model; must be registered.
        unconditional_cfg: When true, return the model's unconditional
            condition (classifier-free guidance) for a batch of the same size
            as ``c`` instead of encoding ``c``.

    Returns:
        The output of the conditioning model.

    Raises:
        NotImplementedError: ``unconditional_cfg`` is set and ``c`` is neither
            a tensor nor a list, so its batch size cannot be determined.
    """
    assert key in self.cond_stage_model_metadata.keys()

    if unconditional_cfg:
        # Classifier-free guidance: only the batch size of ``c`` matters.
        if isinstance(c, torch.Tensor):
            batchsize = c.size(0)
        elif isinstance(c, list):
            batchsize = len(c)
        else:
            raise NotImplementedError()
        model_idx = self.cond_stage_model_metadata[key]["model_idx"]
        return self.cond_stage_models[model_idx].get_unconditional_condition(
            batchsize
        )

    model_idx = self.cond_stage_model_metadata[key]["model_idx"]
    return self.cond_stage_models[model_idx](c)
672
+
673
def initialize_param_check_toolkit(self):
    """Reset the state used by ``check_module_param_update``."""
    # Step counter and the per-module parameter snapshots.
    self.tracked_steps, self.param_dict = 0, {}
676
+
677
def statistic_require_grad_tensor_number(self, module, name=None):
    """Print how many parameter tensors of ``module`` are trainable.

    Args:
        module: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        name: Label used in the printed summary.

    Returns:
        The first parameter with ``requires_grad`` set, or ``None`` when the
        module has no trainable parameter (including no parameters at all).
    """
    requires_grad_num = 0
    total_num = 0
    require_grad_tensor = None
    for p in module.parameters():
        total_num += 1
        if p.requires_grad:
            requires_grad_num += 1
            if require_grad_tensor is None:
                require_grad_tensor = p

    # A module without parameters used to raise ZeroDivisionError here;
    # report a ratio of zero instead.
    ratio = requires_grad_num / total_num if total_num else 0.0
    print(
        "Module: [%s] have %s trainable parameters out of %s total parameters (%.2f)"
        % (name, requires_grad_num, total_num, ratio)
    )
    return require_grad_tensor
692
+
693
def check_module_param_update(self):
    """Debug aid: report how much each child module's parameters have moved.

    On the first call one trainable tensor per child module is snapshotted
    into ``self.param_dict``. Whenever ``self.tracked_steps`` is a multiple of
    5000 (which includes the first call) the summed absolute difference
    between the snapshot and the current tensor is printed. The step counter
    is incremented on every call.
    """

    def _visit_children(on_trainable, frozen_template):
        # Shared walk over the child modules; failures are reported, not raised.
        print("Sequence2AudioMAE")
        for name, module in self.named_children():
            try:
                tensor = self.statistic_require_grad_tensor_number(
                    module, name=name
                )
                if tensor is None:
                    print(frozen_template % name)
                else:
                    on_trainable(name, tensor)
            except Exception as e:
                print("%s does not have trainable parameters: %s" % (name, e))

    def _snapshot(name, tensor):
        self.param_dict[name] = tensor.clone()

    def _report_diff(name, tensor):
        print(
            "===> Param diff %s: %s; Size: %s"
            % (
                name,
                torch.sum(torch.abs(self.param_dict[name] - tensor)),
                tensor.size(),
            )
        )

    if self.tracked_steps == 0:
        _visit_children(_snapshot, "==> %s does not requires grad")

    if self.tracked_steps % 5000 == 0:
        _visit_children(_report_diff, "%s does not requires grad")

    self.tracked_steps += 1
audioldm_train/modules/audiomae/util/__pycache__/patch_embed.cpython-310.pyc ADDED
Binary file (3.42 kB). View file
 
audioldm_train/modules/audiomae/util/__pycache__/pos_embed.cpython-310.pyc ADDED
Binary file (4.33 kB). View file
 
audioldm_train/modules/audiomae/util/crop.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import torch
10
+
11
+ from torchvision import transforms
12
+ from torchvision.transforms import functional as F
13
+
14
+
15
class RandomResizedCrop(transforms.RandomResizedCrop):
    """
    RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
    This may lead to results different with torchvision's version.
    Following BYOL's TF code:
    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
    """

    @staticmethod
    def get_params(img, scale, ratio):
        """Sample a crop box for ``img``.

        Args:
            img: Image accepted by torchvision's image-size helper.
            scale: ``(min, max)`` fraction of the image area to keep.
            ratio: ``(min, max)`` aspect ratio (width / height) of the crop.

        Returns:
            Tuple ``(i, j, h, w)``: top row, left column, height and width.
        """
        # ``_get_image_size`` is a private torchvision helper; use the public
        # ``get_image_size`` when the installed version provides it and only
        # fall back to the private name otherwise.
        size_fn = getattr(F, "get_image_size", None) or getattr(F, "_get_image_size")
        width, height = size_fn(img)
        area = height * width

        # Area is sampled uniformly, the aspect ratio log-uniformly.
        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
        log_ratio = torch.log(torch.tensor(ratio))
        aspect_ratio = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
        ).item()

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))

        # Single draw, no rejection loop: clamp the box to the image instead.
        w = min(w, width)
        h = min(h, height)

        i = torch.randint(0, height - h + 1, size=(1,)).item()
        j = torch.randint(0, width - w + 1, size=(1,)).item()

        return i, j, h, w
audioldm_train/modules/audiomae/util/datasets.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # --------------------------------------------------------
10
+
11
+ import os
12
+ import PIL
13
+
14
+ from torchvision import datasets, transforms
15
+
16
+ from timm.data import create_transform
17
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
18
+
19
+
20
def build_dataset(is_train, args):
    """Create the ``ImageFolder`` dataset for the train or validation split.

    Args:
        is_train: Selects the "train" sub-directory when true, "val" otherwise.
        args: Options object; ``args.data_path`` is the dataset root and the
            whole object is forwarded to ``build_transform``.

    Returns:
        The constructed dataset (also printed).
    """
    transform = build_transform(is_train, args)

    split = "train" if is_train else "val"
    dataset = datasets.ImageFolder(
        os.path.join(args.data_path, split), transform=transform
    )

    print(dataset)
    return dataset
29
+
30
+
31
def build_transform(is_train, args):
    """Build the image transform for training or evaluation.

    Args:
        is_train: When true, return timm's training transform configured from
            ``args``; otherwise return a resize / center-crop pipeline.
        args: Options object. Reads ``input_size`` and, for training,
            ``color_jitter``, ``aa``, ``reprob``, ``remode`` and ``recount``.

    Returns:
        The transform to pass to the dataset.
    """
    mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

    if is_train:
        # this should always dispatch to transforms_imagenet_train
        return create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation="bicubic",
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )

    # Evaluation: resize so that the center crop keeps the same ratio as for
    # 224-pixel images; larger inputs are not cropped.
    crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
    resize_to = int(args.input_size / crop_pct)
    return transforms.Compose(
        [
            transforms.Resize(resize_to, interpolation=PIL.Image.BICUBIC),
            transforms.CenterCrop(args.input_size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
    )
audioldm_train/modules/audiomae/util/lars.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # LARS optimizer, implementation from MoCo v3:
8
+ # https://github.com/facebookresearch/moco-v3
9
+ # --------------------------------------------------------
10
+
11
+ import torch
12
+
13
+
14
class LARS(torch.optim.Optimizer):
    """
    LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
    """

    def __init__(
        self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001
    ):
        """Store the hyper-parameters as per-group defaults.

        Args:
            params: Parameters or parameter groups to optimise.
            lr: Learning rate.
            weight_decay: Weight-decay coefficient (multi-dimensional
                parameters only).
            momentum: Momentum factor.
            trust_coefficient: Scale of the layer-wise trust ratio.
        """
        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum,
            trust_coefficient=trust_coefficient,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimisation step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss, as accepted by ``torch.optim.Optimizer.step``.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            # The method runs under no_grad; the closure needs gradients.
            with torch.enable_grad():
                loss = closure()

        for g in self.param_groups:
            for p in g["params"]:
                dp = p.grad

                if dp is None:
                    continue

                if p.ndim > 1:  # if not normalization gamma/beta or bias
                    dp = dp.add(p, alpha=g["weight_decay"])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # Trust ratio; falls back to 1 when either norm is zero.
                    q = torch.where(
                        param_norm > 0.0,
                        torch.where(
                            update_norm > 0,
                            (g["trust_coefficient"] * param_norm / update_norm),
                            one,
                        ),
                        one,
                    )
                    dp = dp.mul(q)

                param_state = self.state[p]
                if "mu" not in param_state:
                    param_state["mu"] = torch.zeros_like(p)
                mu = param_state["mu"]
                mu.mul_(g["momentum"]).add_(dp)
                p.add_(mu, alpha=-g["lr"])

        return loss
audioldm_train/modules/audiomae/util/lr_decay.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # ELECTRA https://github.com/google-research/electra
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+
12
+ import json
13
+
14
+
15
def param_groups_lrd(
    model, weight_decay=0.05, no_weight_decay_list=(), layer_decay=0.75
):
    """
    Parameter groups for layer-wise lr decay
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58

    Args:
        model: Model exposing ``blocks`` and ``named_parameters()``.
        weight_decay: Weight decay for parameters that are decayed.
        no_weight_decay_list: Parameter names excluded from weight decay
            (any container supporting ``in``).
        layer_decay: Per-layer multiplicative decay of the lr scale.

    Returns:
        List of dicts with keys "lr_scale", "weight_decay" and "params", one
        per (layer id, decay / no-decay) combination that has trainable
        parameters, in order of first appearance.
    """
    num_layers = len(model.blocks) + 1

    # Layer id 0 gets the smallest scale, id ``num_layers`` gets scale 1.
    layer_scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]

    param_groups = {}
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue

        # no decay: all 1D parameters and model specific ones
        if p.ndim == 1 or n in no_weight_decay_list:
            g_decay, this_decay = "no_decay", 0.0
        else:
            g_decay, this_decay = "decay", weight_decay

        layer_id = get_layer_id_for_vit(n, num_layers)
        group_name = "layer_%d_%s" % (layer_id, g_decay)

        group = param_groups.get(group_name)
        if group is None:
            group = param_groups[group_name] = {
                "lr_scale": layer_scales[layer_id],
                "weight_decay": this_decay,
                "params": [],
            }
        group["params"].append(p)

    return list(param_groups.values())
64
+
65
+
66
def get_layer_id_for_vit(name, num_layers):
    """
    Map a parameter name to its layer id
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33

    "cls_token", "pos_embed" and everything under "patch_embed" belong to
    layer 0, "blocks.<i>..." to layer ``i + 1``, and any other parameter to
    ``num_layers``.
    """
    if name in ("cls_token", "pos_embed") or name.startswith("patch_embed"):
        return 0
    if name.startswith("blocks"):
        block_index = int(name.split(".")[1])
        return block_index + 1
    return num_layers
audioldm_train/modules/audiomae/util/lr_sched.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+
10
def adjust_learning_rate(optimizer, epoch, args):
    """Set the learning rate: linear warmup, then half-cycle cosine decay.

    Args:
        optimizer: Object whose ``param_groups`` dicts receive the new "lr";
            a group's optional "lr_scale" entry multiplies the rate.
        epoch: Current (possibly fractional) epoch.
        args: Options with ``lr``, ``min_lr``, ``warmup_epochs`` and ``epochs``.

    Returns:
        The unscaled learning rate for this epoch.
    """
    warmup = args.warmup_epochs
    if epoch < warmup:
        lr = args.lr * epoch / warmup
    else:
        # Goes from 0 at the end of warmup to pi at the final epoch.
        phase = math.pi * (epoch - warmup) / (args.epochs - warmup)
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (1.0 + math.cos(phase))

    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
    return lr
audioldm_train/modules/audiomae/util/misc.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+
12
+ import builtins
13
+ import datetime
14
+ import os
15
+ import time
16
+ from collections import defaultdict, deque
17
+ from pathlib import Path
18
+
19
+ import torch
20
+ import torch.distributed as dist
21
+ from torch._six import inf
22
+
23
+
24
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        """Create an empty tracker.

        Args:
            window_size: Number of most recent values kept for the windowed
                statistics.
            fmt: Format string used by ``str()``; may reference ``median``,
                ``avg``, ``global_avg``, ``max`` and ``value``.
        """
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        """Record ``value``, counting it ``n`` times in the global average."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across processes.
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        stats = torch.tensor(
            [self.count, self.total], dtype=torch.float64, device="cuda"
        )
        dist.barrier()
        dist.all_reduce(stats)
        count, total = stats.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        """Median of the values currently in the window."""
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Mean over every value recorded so far (weighted by ``n``)."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = dict(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
        return self.fmt.format(**stats)
85
+
86
+
87
class MetricLogger(object):
    """Collection of named ``SmoothedValue`` meters with periodic printing."""

    def __init__(self, delimiter="\t"):
        """Create an empty logger.

        Args:
            delimiter: String placed between the fields of a log line.
        """
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword; ``None`` values are skipped.

        Tensors are converted with ``.item()``; every value must then be a
        float or an int.
        """
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only invoked when normal attribute lookup fails. ``meters`` is read
        # through ``__dict__`` so that an instance whose ``__init__`` has not
        # run (e.g. during copying or unpickling) raises AttributeError
        # instead of recursing without bound.
        meters = self.__dict__.get("meters", {})
        if attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self):
        return self.delimiter.join(
            "{}: {}".format(name, str(meter)) for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        """Synchronize every meter across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while printing progress.

        A line with ETA, the meters, iteration time and data-loading time (and
        peak CUDA memory when CUDA is available) is printed every
        ``print_freq`` items and for the last item; a total-time summary is
        printed at the end.

        Args:
            iterable: Sized iterable to loop over.
            print_freq: Number of iterations between progress lines.
            header: Optional prefix of every printed line.
        """
        if not header:
            header = ""
        total = len(iterable)
        has_cuda = torch.cuda.is_available()
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        # Pad the iteration counter to the width of the total count.
        space_fmt = ":" + str(len(str(total))) + "d"
        log_parts = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if has_cuda:
            log_parts.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(log_parts)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            # Time spent waiting for the item vs. the whole iteration.
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == total - 1:
                eta_seconds = iter_time.global_avg * (total - i)
                fields = dict(
                    eta=str(datetime.timedelta(seconds=int(eta_seconds))),
                    meters=str(self),
                    time=str(iter_time),
                    data=str(data_time),
                )
                if has_cuda:
                    fields["memory"] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, total, **fields))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # ``max(total, 1)`` keeps the summary from dividing by zero when the
        # iterable is empty.
        print(
            "{} Total time: {} ({:.4f} s / it)".format(
                header, total_time_str, total_time / max(total, 1)
            )
        )
183
+
184
+
185
def setup_for_distributed(is_master):
    """
    Replace ``builtins.print`` with a version that is silent on non-master
    processes.

    The replacement prints (prefixed with the current time) when ``is_master``
    is true, when called with ``force=True``, or when the world size is
    greater than 8.
    """
    original_print = builtins.print

    def timestamped_print(*args, **kwargs):
        forced = kwargs.pop("force", False) or (get_world_size() > 8)
        if not (is_master or forced):
            return
        now = datetime.datetime.now().time()
        original_print("[{}] ".format(now), end="")  # print with time stamp
        original_print(*args, **kwargs)

    builtins.print = timestamped_print
200
+
201
+
202
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and initialized."""
    if dist.is_available() and dist.is_initialized():
        return True
    return False
208
+
209
+
210
def get_world_size():
    """Return the distributed world size, or 1 outside distributed mode."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
214
+
215
+
216
def get_rank():
    """Return this process's distributed rank, or 0 outside distributed mode."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
220
+
221
+
222
def is_main_process():
    """Return True when this process has rank 0."""
    rank = get_rank()
    return rank == 0
224
+
225
+
226
def save_on_master(*args, **kwargs):
    """Call ``torch.save`` with the given arguments on the main process only."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
229
+
230
+
231
def init_distributed_mode(args):
    """Initialise torch.distributed from the environment and fill in ``args``.

    Three launch styles are recognised, in this order: ``args.dist_on_itp``
    (OpenMPI variables), the RANK / WORLD_SIZE / LOCAL_RANK variables, and
    SLURM_PROCID. When none applies, ``args.distributed`` is set to False and
    nothing is initialised. Otherwise the CUDA device is selected, the NCCL
    process group is created and printing is restricted to rank 0.

    Args:
        args: Options object; reads ``dist_on_itp`` and (depending on the
            launch style) ``dist_url`` and ``world_size``; sets ``rank``,
            ``gpu``, ``distributed``, ``dist_backend`` and possibly
            ``world_size`` / ``dist_url``.
    """
    env = os.environ

    if args.dist_on_itp:
        args.rank = int(env["OMPI_COMM_WORLD_RANK"])
        args.world_size = int(env["OMPI_COMM_WORLD_SIZE"])
        args.gpu = int(env["OMPI_COMM_WORLD_LOCAL_RANK"])
        args.dist_url = "tcp://%s:%s" % (env["MASTER_ADDR"], env["MASTER_PORT"])
        # Export the values under the names the other launch style uses:
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
        env["LOCAL_RANK"] = str(args.gpu)
        env["RANK"] = str(args.rank)
        env["WORLD_SIZE"] = str(args.world_size)
    elif "RANK" in env and "WORLD_SIZE" in env:
        args.rank = int(env["RANK"])
        args.world_size = int(env["WORLD_SIZE"])
        args.gpu = int(env["LOCAL_RANK"])
    elif "SLURM_PROCID" in env:
        args.rank = int(env["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        setup_for_distributed(is_master=True)  # hack
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(
        "| distributed init (rank {}): {}, gpu {}".format(
            args.rank, args.dist_url, args.gpu
        ),
        flush=True,
    )
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    # Wait for every process before silencing the non-master ones.
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
275
+
276
+
277
class NativeScalerWithGradNormCount:
    """AMP loss scaler that also reports the gradient norm of each step."""

    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(
        self,
        loss,
        optimizer,
        clip_grad=None,
        parameters=None,
        create_graph=False,
        update_grad=True,
    ):
        """Back-propagate the scaled loss and optionally step the optimizer.

        Args:
            loss: Loss tensor to back-propagate.
            optimizer: Optimizer to unscale and step.
            clip_grad: Maximum gradient norm; ``None`` disables clipping.
            parameters: Parameters whose gradients are clipped / measured;
                required when ``clip_grad`` is given.
            create_graph: Forwarded to ``backward``.
            update_grad: When false only the backward pass is run (e.g. for
                gradient accumulation).

        Returns:
            The gradient norm, or ``None`` when ``update_grad`` is false.
        """
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if not update_grad:
            return None

        clipping = clip_grad is not None
        if clipping:
            assert parameters is not None
        # unscale the gradients of optimizer's assigned params in-place
        self._scaler.unscale_(optimizer)
        if clipping:
            norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        else:
            norm = get_grad_norm_(parameters)
        self._scaler.step(optimizer)
        self._scaler.update()
        return norm

    def state_dict(self):
        """Return the state of the wrapped scaler."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Restore the state of the wrapped scaler."""
        self._scaler.load_state_dict(state_dict)
314
+
315
+
316
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Return the total norm of the gradients of ``parameters``.

    Args:
        parameters: A single tensor or an iterable of tensors; entries without
            a gradient are ignored.
        norm_type: Order of the norm; ``float("inf")`` selects the maximum
            absolute gradient entry.

    Returns:
        Scalar tensor holding the norm (``0.0`` when nothing has a gradient).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        return torch.tensor(0.0)

    # Per-tensor results are gathered on the first gradient's device.
    device = grads[0].device
    # ``float("inf")`` is used instead of ``inf`` from ``torch._six`` so the
    # function does not depend on that module, which PyTorch 2.x removed.
    if norm_type == float("inf"):
        return max(g.abs().max().to(device) for g in grads)
    return torch.norm(
        torch.stack([torch.norm(g, norm_type).to(device) for g in grads]),
        norm_type,
    )
334
+
335
+
336
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
    """Write a checkpoint for ``epoch``.

    With a loss scaler, a dict holding the model, optimizer and scaler states,
    the epoch and ``args`` is saved (on the main process) to
    ``<output_dir>/checkpoint-<epoch>.pth``. Without one, the model's own
    ``save_checkpoint`` method is called instead.

    Args:
        args: Options object; ``args.output_dir`` is the target directory.
        epoch: Epoch number, used in the checkpoint name.
        model: Model providing ``save_checkpoint`` (used without a scaler).
        model_without_ddp: Model whose ``state_dict`` is stored.
        optimizer: Optimizer whose ``state_dict`` is stored.
        loss_scaler: Loss scaler whose ``state_dict`` is stored, or ``None``.
    """
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)

    if loss_scaler is None:
        model.save_checkpoint(
            save_dir=args.output_dir,
            tag="checkpoint-%s" % epoch_name,
            client_state={"epoch": epoch},
        )
        return

    checkpoint_path = output_dir / ("checkpoint-%s.pth" % epoch_name)
    to_save = {
        "model": model_without_ddp.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "scaler": loss_scaler.state_dict(),
        "args": args,
    }
    save_on_master(to_save, checkpoint_path)
358
+
359
+
360
def load_model(args, model_without_ddp, optimizer, loss_scaler):
    """Restore training state from ``args.resume`` when it is set.

    The model weights are always loaded. The optimizer state, the starting
    epoch (``args.start_epoch``) and, when stored, the loss-scaler state are
    restored only if the checkpoint contains both "optimizer" and "epoch" and
    ``args.eval`` is not set.

    Args:
        args: Options object; ``args.resume`` is a path or an "https" URL.
        model_without_ddp: Model receiving the "model" state dict.
        optimizer: Optimizer receiving the "optimizer" state dict.
        loss_scaler: Scaler receiving the "scaler" state dict.
    """
    if not args.resume:
        return

    if args.resume.startswith("https"):
        checkpoint = torch.hub.load_state_dict_from_url(
            args.resume, map_location="cpu", check_hash=True
        )
    else:
        # NOTE(review): torch.load unpickles the file; only resume from
        # checkpoints you trust.
        checkpoint = torch.load(args.resume, map_location="cpu")

    model_without_ddp.load_state_dict(checkpoint["model"])
    print("Resume checkpoint %s" % args.resume)

    if "optimizer" not in checkpoint or "epoch" not in checkpoint:
        return
    if getattr(args, "eval", False):
        return

    optimizer.load_state_dict(checkpoint["optimizer"])
    args.start_epoch = checkpoint["epoch"] + 1
    if "scaler" in checkpoint:
        loss_scaler.load_state_dict(checkpoint["scaler"])
    print("With optim & sched!")
380
+
381
+
382
def all_reduce_mean(x):
    """Return the mean of ``x`` over all processes.

    Outside distributed mode (world size of 1) ``x`` is returned unchanged;
    otherwise the value is summed across processes on the GPU, divided by the
    world size and returned as a Python number.
    """
    world_size = get_world_size()
    if world_size <= 1:
        return x

    reduced = torch.tensor(x).cuda()
    dist.all_reduce(reduced)
    reduced /= world_size
    return reduced.item()
391
+
392
+
393
+ # utils
394
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather ``tensor`` from every process and concatenate along dim 0.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()
    # One receive buffer per process, shaped like the local tensor.
    gathered = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)
    return torch.cat(gathered, dim=0)
407
+
408
+
409
def merge_vmae_to_avmae(avmae_state_dict, vmae_ckpt):
    """Overwrite the audio entries of ``vmae_ckpt`` with its video ("_v") ones.

    ``vmae_ckpt`` is modified in place and nothing is returned: the class and
    mask tokens, the positional embedding (resampled to a 64 x 8 grid), the
    patch-embedding projection (summed down to one input channel), the final
    norm and every "blocks." entry are replaced by values derived from the
    corresponding "_v" entries.

    Args:
        avmae_state_dict: Not used by this function.
        vmae_ckpt: Checkpoint dict holding both the plain and the "_v" keys.
    """
    vmae_ckpt["cls_token"] = vmae_ckpt["cls_token_v"]
    vmae_ckpt["mask_token"] = vmae_ckpt["mask_token_v"]

    # Positional embedding: split off the class-token slot, then sum the two
    # 14 x 14 x 768 groups of patch positions.
    pos_embed_v = vmae_ckpt["pos_embed_v"]
    cls_embed = pos_embed_v[:, 0, :].unsqueeze(1)
    patch_pos = pos_embed_v[:, 1:, :].reshape(1, 2, 14, 14, 768).sum(dim=1)
    print("Position interpolate from 14,14 to 64,8")
    # interpolate() expects channels first: 1,14,14,768 -> 1,768,14,14
    patch_pos = torch.nn.functional.interpolate(
        patch_pos.permute(0, 3, 1, 2),
        size=(64, 8),
        mode="bicubic",
        align_corners=False,
    )
    # Back to channels last and flatten the 64 x 8 grid into one axis.
    patch_pos = patch_pos.permute(0, 2, 3, 1).flatten(1, 2)
    pos_embed = torch.cat((cls_embed, patch_pos), dim=1)
    assert vmae_ckpt["pos_embed"].shape == pos_embed.shape
    vmae_ckpt["pos_embed"] = pos_embed

    # Patch embedding: aggregate the 3 channels (and the size-2 axis after
    # them) of the video-rgb weights into 1 channel for audio.
    v_weight = vmae_ckpt["patch_embed_v.proj.weight"]
    new_proj_weight = torch.nn.Parameter(v_weight.sum(dim=2).sum(dim=1).unsqueeze(1))
    assert new_proj_weight.shape == vmae_ckpt["patch_embed.proj.weight"].shape
    vmae_ckpt["patch_embed.proj.weight"] = new_proj_weight
    vmae_ckpt["patch_embed.proj.bias"] = vmae_ckpt["patch_embed_v.proj.bias"]

    # hack
    vmae_ckpt["norm.weight"] = vmae_ckpt["norm_v.weight"]
    vmae_ckpt["norm.bias"] = vmae_ckpt["norm_v.bias"]

    # Replace the transformer encoder blocks. Only existing keys are
    # reassigned, so the dict does not change size while it is iterated.
    for k in vmae_ckpt:
        if k.startswith("blocks."):
            vmae_ckpt[k] = vmae_ckpt[k.replace("blocks.", "blocks_v.")]
        elif not k.startswith("blocks_v."):
            print(k)
        # NOTE(review): the nesting of this debug print was lost when the
        # source was extracted; it is assumed to run once per key - confirm.
        print(k)
audioldm_train/modules/audiomae/util/patch_embed.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from timm.models.layers import to_2tuple
4
+
5
+
6
class PatchEmbed_org(nn.Module):
    """Patch embedding built from a single non-overlapping ``nn.Conv2d``.

    The convolution uses ``patch_size`` for both kernel and stride. Scalar
    ``img_size`` / ``patch_size`` values are expanded with ``to_2tuple``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        # Patch counts along the second and the first image dimension.
        count_dim1 = img_size[1] // patch_size[1]
        count_dim0 = img_size[0] // patch_size[0]

        # Stored with the second-dimension count first.
        self.patch_hw = (count_dim1, count_dim0)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = count_dim1 * count_dim0

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """Project ``x`` and return the patches as a ``(B, N, embed_dim)`` sequence."""
        # The unpacking only checks that the input is 4-D; the input size is
        # deliberately not validated against ``self.img_size``.
        _batch, _chans, _height, _width = x.shape
        projected = self.proj(x)
        return projected.flatten(2).transpose(1, 2)
31
+
32
+
33
class PatchEmbed_new(nn.Module):
    """Patch embedding whose stride is independent of the patch size.

    A stride smaller than ``patch_size`` yields overlapping patches. The patch
    grid (``patch_hw``) and ``num_patches`` are measured by pushing a dummy
    input of size ``img_size`` through the projection.

    Args:
        img_size: Input size; a scalar is expanded with ``to_2tuple``.
        patch_size: Convolution kernel size; a scalar is expanded likewise.
        in_chans: Number of input channels.
        embed_dim: Number of output channels of the projection.
        stride: Convolution stride; a scalar is expanded likewise.
    """

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)

        self.img_size = img_size
        self.patch_size = patch_size

        # Kernel and stride are separate, so patches may overlap.
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride
        )

        _, _, h, w = self.get_output_shape(img_size)  # n, embed_dim, h, w
        self.patch_hw = (h, w)
        self.num_patches = h * w

    def get_output_shape(self, img_size):
        """Return the shape of ``self.proj`` applied to one dummy input.

        The dummy tensor takes its channel count, device and dtype from the
        projection itself. A fixed single channel would make construction fail
        for every ``in_chans`` other than 1.
        """
        weight = self.proj.weight
        probe = torch.randn(
            1,
            self.proj.in_channels,
            img_size[0],
            img_size[1],
            device=weight.device,
            dtype=weight.dtype,
        )
        # Only the shape is needed, so no autograd graph is recorded.
        with torch.no_grad():
            return self.proj(probe).shape

    def forward(self, x):
        """Project ``x`` and return the patches as a ``(B, N, embed_dim)`` sequence."""
        # The unpacking only checks that the input is 4-D; the input size is
        # deliberately not validated against ``self.img_size``.
        _batch, _chans, _height, _width = x.shape
        x = self.proj(x)  # e.g. 32, 1, 1024, 128 -> 32, 768, 101, 12
        x = x.flatten(2)  # 32, 768, 101, 12 -> 32, 768, 1212
        x = x.transpose(1, 2)  # 32, 768, 1212 -> 32, 1212, 768
        return x
72
+
73
+
74
class PatchEmbed3D_new(nn.Module):
    """Video-to-patch embedding built from a single ``nn.Conv3d``.

    ``patch_thw`` and ``num_patches`` are measured by running a dummy clip of
    size ``video_size`` through the projection.
    """

    def __init__(
        self,
        video_size=(16, 224, 224),
        patch_size=(2, 16, 16),
        in_chans=3,
        embed_dim=768,
        stride=(2, 16, 16),
    ):
        super().__init__()

        self.video_size = video_size
        self.patch_size = patch_size
        self.in_chans = in_chans

        self.proj = nn.Conv3d(
            in_chans, embed_dim, kernel_size=patch_size, stride=stride
        )

        probed = self.get_output_shape(video_size)  # n, embed_dim, t, h, w
        _, _, t, h, w = probed
        self.patch_thw = (t, h, w)
        self.num_patches = t * h * w

    def get_output_shape(self, video_size):
        """Return the shape of ``self.proj`` applied to one random clip."""
        dummy = torch.randn(
            1, self.in_chans, video_size[0], video_size[1], video_size[2]
        )
        return self.proj(dummy).shape

    def forward(self, x):
        """Project ``x`` and return the patches as a ``(B, N, embed_dim)`` sequence."""
        # The unpacking only checks that the input is 5-D.
        _batch, _chans, _frames, _height, _width = x.shape
        # e.g. 32, 3, 16, 224, 224 -> 32, 768, 8, 14, 14 -> 32, 768, 1568
        tokens = self.proj(x).flatten(2)
        # 32, 768, 1568 -> 32, 1568, 768
        return tokens.transpose(1, 2)
110
+
111
+
112
if __name__ == "__main__":
    # Manual smoke test: embed a batch of random clips and print the shape.
    #
    # 2-D variant kept for reference:
    # patch_emb = PatchEmbed_new(img_size=224, patch_size=16, in_chans=1, embed_dim=64, stride=(16,16))
    # input = torch.rand(8,1,1024,128)
    # output = patch_emb(input)
    # print(output.shape) # (8,512,64)

    embedder = PatchEmbed3D_new(
        video_size=(6, 224, 224),
        patch_size=(2, 16, 16),
        in_chans=3,
        embed_dim=768,
        stride=(2, 16, 16),
    )
    clips = torch.rand(8, 3, 6, 224, 224)
    tokens = embedder(clips)
    # 3 * 14 * 14 = 588 patches per clip -> torch.Size([8, 588, 768])
    print(tokens.shape)
audioldm_train/modules/audiomae/util/pos_embed.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # Position embedding utils
8
+ # --------------------------------------------------------
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+
14
+ # --------------------------------------------------------
15
+ # 2D sine-cosine position embedding
16
+ # References:
17
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
19
+ # --------------------------------------------------------
20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Build a 2-D sine-cosine position table for a square grid.

    grid_size: int, used as both the height and the width of the grid
    return:
    pos_embed: [grid_size*grid_size, embed_dim], or
    [1+grid_size*grid_size, embed_dim] with a leading all-zero row when
    cls_token is true
    """
    axis_h = np.arange(grid_size, dtype=np.float32)
    axis_w = np.arange(grid_size, dtype=np.float32)
    # meshgrid receives w first, so index 0 of the stack holds the w coordinates
    mesh = np.stack(np.meshgrid(axis_w, axis_h), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return table
    # reserve a zero row in front for the class token
    return np.concatenate([np.zeros([1, embed_dim]), table], axis=0)
36
+
37
+
38
def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
    """Build a 2-D sine-cosine position table for a rectangular grid.

    grid_size: indexable pair; grid_size[0] is the grid height and
    grid_size[1] the grid width
    return:
    pos_embed: [grid_size[0]*grid_size[1], embed_dim], or
    [1+grid_size[0]*grid_size[1], embed_dim] with a leading all-zero row when
    cls_token is true
    """
    rows = grid_size[0]
    cols = grid_size[1]
    axis_h = np.arange(rows, dtype=np.float32)
    axis_w = np.arange(cols, dtype=np.float32)
    # meshgrid receives w first, so index 0 of the stack holds the w coordinates
    mesh = np.stack(np.meshgrid(axis_w, axis_h), axis=0)
    mesh = mesh.reshape([2, 1, rows, cols])

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return table
    # reserve a zero row in front for the class token
    return np.concatenate([np.zeros([1, embed_dim]), table], axis=0)
54
+
55
+
56
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a two-plane coordinate grid into a (H*W, embed_dim) table.

    The first half of the channels encodes grid[0] and the second half
    encodes grid[1]; embed_dim must be even.
    """
    assert embed_dim % 2 == 0

    half = embed_dim // 2
    # one 1-D encoding per coordinate plane, each of shape (H*W, D/2)
    planes = [get_1d_sincos_pos_embed_from_grid(half, grid[axis]) for axis in (0, 1)]
    return np.concatenate(planes, axis=1)  # (H*W, D)
65
+
66
+
67
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Encode positions with sines and cosines.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions to be encoded; it is flattened to size (M,)
    out: (M, D), sines in the first D/2 columns and cosines in the last D/2
    """
    assert embed_dim % 2 == 0

    half = embed_dim // 2
    exponents = np.arange(half, dtype=float) / (embed_dim / 2.0)
    freqs = 1.0 / 10000**exponents  # (D/2,)

    flat = pos.reshape(-1)  # (M,)
    angles = flat[:, None] * freqs[None, :]  # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
87
+
88
+
89
+ # --------------------------------------------------------
90
+ # Interpolate position embeddings for high-resolution
91
+ # References:
92
+ # DeiT: https://github.com/facebookresearch/deit
93
+ # --------------------------------------------------------
94
def interpolate_pos_embed(model, checkpoint_model):
    """Resize a checkpoint's square position grid to the model's grid, in place.

    Does nothing when ``checkpoint_model`` has no ``"pos_embed"`` entry or the
    two grids already have the same side length. Otherwise the patch tokens are
    resized bicubically and stored back under ``"pos_embed"``; the leading
    extra tokens (class / distillation tokens) are copied unchanged.

    Both grids are treated as square: each side is the truncated square root
    of the corresponding patch count.
    """
    if "pos_embed" not in checkpoint_model:
        return

    ckpt_embed = checkpoint_model["pos_embed"]
    dim = ckpt_embed.shape[-1]
    n_patches = model.patch_embed.num_patches
    n_extra = model.pos_embed.shape[-2] - n_patches

    old_side = int((ckpt_embed.shape[-2] - n_extra) ** 0.5)
    new_side = int(n_patches**0.5)
    if old_side == new_side:
        return

    print(
        "Position interpolate from %dx%d to %dx%d"
        % (old_side, old_side, new_side, new_side)
    )
    prefix = ckpt_embed[:, :n_extra]
    # only the patch tokens are interpolated
    grid = ckpt_embed[:, n_extra:].reshape(-1, old_side, old_side, dim)
    grid = torch.nn.functional.interpolate(
        grid.permute(0, 3, 1, 2),
        size=(new_side, new_side),
        mode="bicubic",
        align_corners=False,
    )
    grid = grid.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model["pos_embed"] = torch.cat((prefix, grid), dim=1)
125
+
126
+
127
def interpolate_pos_embed_img2audio(model, checkpoint_model, orig_size, new_size):
    """Resize a checkpoint's position grid between explicit 2-D sizes, in place.

    ``orig_size`` and ``new_size`` are indexable pairs giving the checkpoint
    grid and the target grid. Does nothing when ``checkpoint_model`` has no
    ``"pos_embed"`` entry or the two sizes compare equal. Otherwise the patch
    tokens are resized bicubically and stored back under ``"pos_embed"``; the
    leading extra tokens (class / distillation tokens) are copied unchanged.
    """
    if "pos_embed" not in checkpoint_model:
        return

    ckpt_embed = checkpoint_model["pos_embed"]
    dim = ckpt_embed.shape[-1]
    n_patches = model.patch_embed.num_patches
    n_extra = model.pos_embed.shape[-2] - n_patches
    if orig_size == new_size:
        return

    print(
        "Position interpolate from %dx%d to %dx%d"
        % (orig_size[0], orig_size[1], new_size[0], new_size[1])
    )
    prefix = ckpt_embed[:, :n_extra]
    # only the patch tokens are interpolated
    grid = ckpt_embed[:, n_extra:].reshape(-1, orig_size[0], orig_size[1], dim)
    grid = torch.nn.functional.interpolate(
        grid.permute(0, 3, 1, 2),
        size=(new_size[0], new_size[1]),
        mode="bicubic",
        align_corners=False,
    )
    grid = grid.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model["pos_embed"] = torch.cat((prefix, grid), dim=1)
158
+
159
+
160
def interpolate_pos_embed_audio(model, checkpoint_model, orig_size, new_size):
    """Crop a checkpoint's position grid along its second grid axis, in place.

    Despite the name, nothing is interpolated: the patch tokens are viewed as
    an ``orig_size[0] x orig_size[1]`` grid and only the first ``new_size[1]``
    entries along the second axis are kept. ``new_size[0]`` is only used in
    the log message, so the first axis keeps its original length.

    Args:
        model: Unused; kept so existing call sites keep working.
        checkpoint_model: Mapping that may hold a ``"pos_embed"`` tensor whose
            token 0 is the class token. Updated in place.
        orig_size: Indexable pair, grid size of the checkpoint embedding.
        new_size: Indexable pair, target grid size.

    Does nothing when ``"pos_embed"`` is absent or the two sizes compare equal.
    """
    if "pos_embed" not in checkpoint_model:
        return
    if orig_size == new_size:
        return

    pos_embed_checkpoint = checkpoint_model["pos_embed"]
    embedding_size = pos_embed_checkpoint.shape[-1]

    print(
        "Position interpolate from %dx%d to %dx%d"
        % (orig_size[0], orig_size[1], new_size[0], new_size[1])
    )
    cls_token = pos_embed_checkpoint[:, 0, :].unsqueeze(1)
    # drop the class token and restore the 2-D grid layout
    pos_tokens = pos_embed_checkpoint[:, 1:, :].reshape(
        -1, orig_size[0], orig_size[1], embedding_size
    )
    # original note: "assume only time diff" -- only the second axis is cropped
    pos_tokens = pos_tokens[:, :, : new_size[1], :]
    pos_tokens = pos_tokens.flatten(1, 2)
    checkpoint_model["pos_embed"] = torch.cat((cls_token, pos_tokens), dim=1)
186
+
187
+
188
def interpolate_patch_embed_audio(
    model,
    checkpoint_model,
    orig_channel,
    new_channel=1,
    kernel_size=(16, 16),
    stride=(16, 16),
    padding=(0, 0),
):
    """Collapse the input-channel axis of a checkpoint's patch projection.

    When ``orig_channel`` differs from ``new_channel`` and the checkpoint has a
    ``"patch_embed.proj.weight"`` entry, that weight is summed over dim 1 (e.g.
    3 RGB channels -> 1 audio channel) and stored back as an ``nn.Parameter``
    with a singleton channel axis. The result always has one channel,
    whatever ``new_channel`` is.

    ``model``, ``kernel_size``, ``stride`` and ``padding`` are not used.
    """
    weight_key = "patch_embed.proj.weight"
    if orig_channel == new_channel or weight_key not in checkpoint_model:
        return

    # aggregate all input channels of the checkpoint into a single one
    collapsed = checkpoint_model[weight_key].sum(dim=1, keepdim=True)
    checkpoint_model[weight_key] = torch.nn.Parameter(collapsed)
audioldm_train/modules/audiomae/util/stat.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy import stats
3
+ from sklearn import metrics
4
+ import torch
5
+
6
+
7
def d_prime(auc):
    """Convert an AUC value to d-prime: sqrt(2) times the standard-normal quantile of ``auc``."""
    quantile = stats.norm().ppf(auc)
    return quantile * np.sqrt(2.0)
11
+
12
+
13
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather ``tensor`` from every process and concatenate the copies on dim 0.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()
    # one receive buffer per rank, shaped like the local tensor
    buffers = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(buffers, tensor, async_op=False)
    return torch.cat(buffers, dim=0)
26
+
27
+
28
def calculate_stats(output, target):
    """Calculate per-class statistics (AP, precision/recall, FPR/FNR) plus accuracy.

    Args:
      output: 2d array, (samples_num, classes_num)
      target: 2d array, (samples_num, classes_num)

    Returns:
      list with one dict per class, holding the keys "precisions", "recalls",
      "AP", "fpr", "fnr" and "acc". The curve arrays are subsampled to every
      1000th point to reduce size.
    """
    n_classes = target.shape[-1]

    # Accuracy, only meaningful for single-label classification such as
    # esc-50, not for multi-label data such as AudioSet. It is not class-wise
    # and is repeated in every entry to keep the entries uniform.
    acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1))

    keep_every = 1000  # subsampling step for the curve arrays
    per_class = []
    for k in range(n_classes):
        y_true = target[:, k]
        y_score = output[:, k]

        avg_precision = metrics.average_precision_score(
            y_true, y_score, average=None
        )
        precisions, recalls, _ = metrics.precision_recall_curve(y_true, y_score)
        fpr, tpr, _ = metrics.roc_curve(y_true, y_score)

        per_class.append(
            {
                "precisions": precisions[0::keep_every],
                "recalls": recalls[0::keep_every],
                "AP": avg_precision,
                "fpr": fpr[0::keep_every],
                "fnr": 1.0 - tpr[0::keep_every],
                "acc": acc,
            }
        )

    return per_class
audioldm_train/modules/clap/__init__.py ADDED
File without changes
audioldm_train/modules/clap/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (148 Bytes). View file
 
audioldm_train/modules/clap/open_clip/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .factory import (
2
+ list_models,
3
+ create_model,
4
+ create_model_and_transforms,
5
+ add_model_config,
6
+ )
7
+ from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
8
+ from .model import (
9
+ CLAP,
10
+ CLAPTextCfg,
11
+ CLAPVisionCfg,
12
+ CLAPAudioCfp,
13
+ convert_weights_to_fp16,
14
+ trace_model,
15
+ )
16
+ from .openai import load_openai_model, list_openai_models
17
+ from .pretrained import (
18
+ list_pretrained,
19
+ list_pretrained_tag_models,
20
+ list_pretrained_model_tags,
21
+ get_pretrained_url,
22
+ download_pretrained,
23
+ )
24
+ from .tokenizer import SimpleTokenizer, tokenize
25
+ from .transform import image_transform
audioldm_train/modules/clap/open_clip/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (954 Bytes). View file
 
audioldm_train/modules/clap/open_clip/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.01 kB). View file
 
audioldm_train/modules/clap/open_clip/__pycache__/factory.cpython-310.pyc ADDED
Binary file (6.79 kB). View file
 
audioldm_train/modules/clap/open_clip/__pycache__/factory.cpython-38.pyc ADDED
Binary file (6.82 kB). View file