lmzjms commited on
Commit
8121fee
·
1 Parent(s): 6525056

Upload 46 files

Browse files
Files changed (46) hide show
  1. audio_to_text/__init__.py +0 -0
  2. audio_to_text/__pycache__/__init__.cpython-38.pyc +0 -0
  3. audio_to_text/__pycache__/inference_waveform.cpython-38.pyc +0 -0
  4. audio_to_text/audiocaps_cntrstv_cnn14rnn_trm/config.yaml +23 -0
  5. audio_to_text/audiocaps_cntrstv_cnn14rnn_trm/swa.pth +3 -0
  6. audio_to_text/captioning/__init__.py +0 -0
  7. audio_to_text/captioning/__pycache__/__init__.cpython-38.pyc +0 -0
  8. audio_to_text/captioning/models/__init__.py +3 -0
  9. audio_to_text/captioning/models/__pycache__/__init__.cpython-38.pyc +0 -0
  10. audio_to_text/captioning/models/__pycache__/attn_model.cpython-38.pyc +0 -0
  11. audio_to_text/captioning/models/__pycache__/base_model.cpython-38.pyc +0 -0
  12. audio_to_text/captioning/models/__pycache__/decoder.cpython-38.pyc +0 -0
  13. audio_to_text/captioning/models/__pycache__/encoder.cpython-38.pyc +0 -0
  14. audio_to_text/captioning/models/__pycache__/fc_model.cpython-38.pyc +0 -0
  15. audio_to_text/captioning/models/__pycache__/rl_model.cpython-38.pyc +0 -0
  16. audio_to_text/captioning/models/__pycache__/style_model.cpython-38.pyc +0 -0
  17. audio_to_text/captioning/models/__pycache__/transformer_model.cpython-38.pyc +0 -0
  18. audio_to_text/captioning/models/__pycache__/utils.cpython-38.pyc +0 -0
  19. audio_to_text/captioning/models/base_model.py +500 -0
  20. audio_to_text/captioning/models/decoder.py +746 -0
  21. audio_to_text/captioning/models/encoder.py +686 -0
  22. audio_to_text/captioning/models/transformer_model.py +265 -0
  23. audio_to_text/captioning/models/utils.py +132 -0
  24. audio_to_text/captioning/utils/README.md +19 -0
  25. audio_to_text/captioning/utils/__init__.py +0 -0
  26. audio_to_text/captioning/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  27. audio_to_text/captioning/utils/__pycache__/train_util.cpython-38.pyc +0 -0
  28. audio_to_text/captioning/utils/bert/create_sent_embedding.py +89 -0
  29. audio_to_text/captioning/utils/bert/create_word_embedding.py +34 -0
  30. audio_to_text/captioning/utils/build_vocab.py +153 -0
  31. audio_to_text/captioning/utils/build_vocab_ltp.py +150 -0
  32. audio_to_text/captioning/utils/build_vocab_spacy.py +152 -0
  33. audio_to_text/captioning/utils/eval_round_robin.py +182 -0
  34. audio_to_text/captioning/utils/fasttext/create_word_embedding.py +50 -0
  35. audio_to_text/captioning/utils/lr_scheduler.py +128 -0
  36. audio_to_text/captioning/utils/model_eval_diff.py +110 -0
  37. audio_to_text/captioning/utils/predict_nn.py +49 -0
  38. audio_to_text/captioning/utils/remove_optimizer.py +18 -0
  39. audio_to_text/captioning/utils/report_results.py +37 -0
  40. audio_to_text/captioning/utils/tokenize_caption.py +86 -0
  41. audio_to_text/captioning/utils/train_util.py +178 -0
  42. audio_to_text/captioning/utils/word2vec/create_word_embedding.py +67 -0
  43. audio_to_text/clotho_cntrstv_cnn14rnn_trm/config.yaml +22 -0
  44. audio_to_text/clotho_cntrstv_cnn14rnn_trm/swa.pth +3 -0
  45. audio_to_text/inference_waveform.py +102 -0
  46. audio_to_text/pretrained_feature_extractors/contrastive_pretrain_cnn14_bertm.pth +3 -0
audio_to_text/__init__.py ADDED
File without changes
audio_to_text/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (157 Bytes). View file
 
audio_to_text/__pycache__/inference_waveform.cpython-38.pyc ADDED
Binary file (3.01 kB). View file
 
audio_to_text/audiocaps_cntrstv_cnn14rnn_trm/config.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ encoder:
3
+ type: Cnn14RnnEncoder
4
+ args:
5
+ sample_rate: 32000
6
+ pretrained: ./audio_to_text/pretrained_feature_extractors/contrastive_pretrain_cnn14_bertm.pth
7
+ freeze_cnn: True
8
+ freeze_cnn_bn: True
9
+ bidirectional: True
10
+ dropout: 0.5
11
+ hidden_size: 256
12
+ num_layers: 3
13
+ decoder:
14
+ type: TransformerDecoder
15
+ args:
16
+ attn_emb_dim: 512
17
+ dropout: 0.2
18
+ emb_dim: 256
19
+ fc_emb_dim: 512
20
+ nlayers: 2
21
+ type: TransformerModel
22
+ args: {}
23
+
audio_to_text/audiocaps_cntrstv_cnn14rnn_trm/swa.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d22099e1025baae0f32ce09ec02c3d5fea001e295512fbf8754b5c66db21b0ec
3
+ size 43027289
audio_to_text/captioning/__init__.py ADDED
File without changes
audio_to_text/captioning/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (168 Bytes). View file
 
audio_to_text/captioning/models/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .base_model import *
2
+ from .transformer_model import *
3
+
audio_to_text/captioning/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (231 Bytes). View file
 
audio_to_text/captioning/models/__pycache__/attn_model.cpython-38.pyc ADDED
Binary file (7.73 kB). View file
 
audio_to_text/captioning/models/__pycache__/base_model.cpython-38.pyc ADDED
Binary file (15.8 kB). View file
 
audio_to_text/captioning/models/__pycache__/decoder.cpython-38.pyc ADDED
Binary file (19.1 kB). View file
 
audio_to_text/captioning/models/__pycache__/encoder.cpython-38.pyc ADDED
Binary file (19.4 kB). View file
 
audio_to_text/captioning/models/__pycache__/fc_model.cpython-38.pyc ADDED
Binary file (3.5 kB). View file
 
audio_to_text/captioning/models/__pycache__/rl_model.cpython-38.pyc ADDED
Binary file (2.19 kB). View file
 
audio_to_text/captioning/models/__pycache__/style_model.cpython-38.pyc ADDED
Binary file (3.4 kB). View file
 
audio_to_text/captioning/models/__pycache__/transformer_model.cpython-38.pyc ADDED
Binary file (7.6 kB). View file
 
audio_to_text/captioning/models/__pycache__/utils.cpython-38.pyc ADDED
Binary file (4.16 kB). View file
 
audio_to_text/captioning/models/base_model.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Dict
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .utils import mean_with_lens, repeat_tensor
9
+
10
+
11
+ class CaptionModel(nn.Module):
12
+ """
13
+ Encoder-decoder captioning model.
14
+ """
15
+
16
+ pad_idx = 0
17
+ start_idx = 1
18
+ end_idx = 2
19
+ max_length = 20
20
+
21
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
22
+ super().__init__()
23
+ self.encoder = encoder
24
+ self.decoder = decoder
25
+ self.vocab_size = decoder.vocab_size
26
+ self.train_forward_keys = ["cap", "cap_len", "ss_ratio"]
27
+ self.inference_forward_keys = ["sample_method", "max_length", "temp"]
28
+ freeze_encoder = kwargs.get("freeze_encoder", False)
29
+ if freeze_encoder:
30
+ for param in self.encoder.parameters():
31
+ param.requires_grad = False
32
+ self.check_decoder_compatibility()
33
+
34
+ def check_decoder_compatibility(self):
35
+ compatible_decoders = [x.__class__.__name__ for x in self.compatible_decoders]
36
+ assert isinstance(self.decoder, self.compatible_decoders), \
37
+ f"{self.decoder.__class__.__name__} is incompatible with " \
38
+ f"{self.__class__.__name__}, please use decoder in {compatible_decoders} "
39
+
40
+ @classmethod
41
+ def set_index(cls, start_idx, end_idx):
42
+ cls.start_idx = start_idx
43
+ cls.end_idx = end_idx
44
+
45
+ def forward(self, input_dict: Dict):
46
+ """
47
+ input_dict: {
48
+ (required)
49
+ mode: train/inference,
50
+ spec,
51
+ spec_len,
52
+ fc,
53
+ attn,
54
+ attn_len,
55
+ [sample_method: greedy],
56
+ [temp: 1.0] (in case of no teacher forcing)
57
+
58
+ (optional, mode=train)
59
+ cap,
60
+ cap_len,
61
+ ss_ratio,
62
+
63
+ (optional, mode=inference)
64
+ sample_method: greedy/beam,
65
+ max_length,
66
+ temp,
67
+ beam_size (optional, sample_method=beam),
68
+ n_best (optional, sample_method=beam),
69
+ }
70
+ """
71
+ # encoder_input_keys = ["spec", "spec_len", "fc", "attn", "attn_len"]
72
+ # encoder_input = { key: input_dict[key] for key in encoder_input_keys }
73
+ encoder_output_dict = self.encoder(input_dict)
74
+ if input_dict["mode"] == "train":
75
+ forward_dict = {
76
+ "mode": "train", "sample_method": "greedy", "temp": 1.0
77
+ }
78
+ for key in self.train_forward_keys:
79
+ forward_dict[key] = input_dict[key]
80
+ forward_dict.update(encoder_output_dict)
81
+ output = self.train_forward(forward_dict)
82
+ elif input_dict["mode"] == "inference":
83
+ forward_dict = {"mode": "inference"}
84
+ default_args = { "sample_method": "greedy", "max_length": self.max_length, "temp": 1.0 }
85
+ for key in self.inference_forward_keys:
86
+ if key in input_dict:
87
+ forward_dict[key] = input_dict[key]
88
+ else:
89
+ forward_dict[key] = default_args[key]
90
+
91
+ if forward_dict["sample_method"] == "beam":
92
+ forward_dict["beam_size"] = input_dict.get("beam_size", 3)
93
+ forward_dict["n_best"] = input_dict.get("n_best", False)
94
+ forward_dict["n_best_size"] = input_dict.get("n_best_size", forward_dict["beam_size"])
95
+ elif forward_dict["sample_method"] == "dbs":
96
+ forward_dict["beam_size"] = input_dict.get("beam_size", 6)
97
+ forward_dict["group_size"] = input_dict.get("group_size", 3)
98
+ forward_dict["diversity_lambda"] = input_dict.get("diversity_lambda", 0.5)
99
+ forward_dict["group_nbest"] = input_dict.get("group_nbest", True)
100
+
101
+ forward_dict.update(encoder_output_dict)
102
+ output = self.inference_forward(forward_dict)
103
+ else:
104
+ raise Exception("mode should be either 'train' or 'inference'")
105
+
106
+ return output
107
+
108
+ def prepare_output(self, input_dict):
109
+ output = {}
110
+ batch_size = input_dict["fc_emb"].size(0)
111
+ if input_dict["mode"] == "train":
112
+ max_length = input_dict["cap"].size(1) - 1
113
+ elif input_dict["mode"] == "inference":
114
+ max_length = input_dict["max_length"]
115
+ else:
116
+ raise Exception("mode should be either 'train' or 'inference'")
117
+ device = input_dict["fc_emb"].device
118
+ output["seq"] = torch.full((batch_size, max_length), self.end_idx,
119
+ dtype=torch.long)
120
+ output["logit"] = torch.empty(batch_size, max_length,
121
+ self.vocab_size).to(device)
122
+ output["sampled_logprob"] = torch.zeros(batch_size, max_length)
123
+ output["embed"] = torch.empty(batch_size, max_length,
124
+ self.decoder.d_model).to(device)
125
+ return output
126
+
127
+ def train_forward(self, input_dict):
128
+ if input_dict["ss_ratio"] != 1: # scheduled sampling training
129
+ input_dict["mode"] = "train"
130
+ return self.stepwise_forward(input_dict)
131
+ output = self.seq_forward(input_dict)
132
+ self.train_process(output, input_dict)
133
+ return output
134
+
135
+ def seq_forward(self, input_dict):
136
+ raise NotImplementedError
137
+
138
+ def train_process(self, output, input_dict):
139
+ pass
140
+
141
+ def inference_forward(self, input_dict):
142
+ if input_dict["sample_method"] == "beam":
143
+ return self.beam_search(input_dict)
144
+ elif input_dict["sample_method"] == "dbs":
145
+ return self.diverse_beam_search(input_dict)
146
+ return self.stepwise_forward(input_dict)
147
+
148
+ def stepwise_forward(self, input_dict):
149
+ """Step-by-step decoding"""
150
+ output = self.prepare_output(input_dict)
151
+ max_length = output["seq"].size(1)
152
+ # start sampling
153
+ for t in range(max_length):
154
+ input_dict["t"] = t
155
+ self.decode_step(input_dict, output)
156
+ if input_dict["mode"] == "inference": # decide whether to stop when sampling
157
+ unfinished_t = output["seq"][:, t] != self.end_idx
158
+ if t == 0:
159
+ unfinished = unfinished_t
160
+ else:
161
+ unfinished *= unfinished_t
162
+ output["seq"][:, t][~unfinished] = self.end_idx
163
+ if unfinished.sum() == 0:
164
+ break
165
+ self.stepwise_process(output)
166
+ return output
167
+
168
+ def decode_step(self, input_dict, output):
169
+ """Decoding operation of timestep t"""
170
+ decoder_input = self.prepare_decoder_input(input_dict, output)
171
+ # feed to the decoder to get logit
172
+ output_t = self.decoder(decoder_input)
173
+ logit_t = output_t["logit"]
174
+ # assert logit_t.ndim == 3
175
+ if logit_t.size(1) == 1:
176
+ logit_t = logit_t.squeeze(1)
177
+ embed_t = output_t["embed"].squeeze(1)
178
+ elif logit_t.size(1) > 1:
179
+ logit_t = logit_t[:, -1, :]
180
+ embed_t = output_t["embed"][:, -1, :]
181
+ else:
182
+ raise Exception("no logit output")
183
+ # sample the next input word and get the corresponding logit
184
+ sampled = self.sample_next_word(logit_t,
185
+ method=input_dict["sample_method"],
186
+ temp=input_dict["temp"])
187
+
188
+ output_t.update(sampled)
189
+ output_t["t"] = input_dict["t"]
190
+ output_t["logit"] = logit_t
191
+ output_t["embed"] = embed_t
192
+ self.stepwise_process_step(output, output_t)
193
+
194
+ def prepare_decoder_input(self, input_dict, output):
195
+ """Prepare the inp ut dict for the decoder"""
196
+ raise NotImplementedError
197
+
198
+ def stepwise_process_step(self, output, output_t):
199
+ """Postprocessing (save output values) after each timestep t"""
200
+ t = output_t["t"]
201
+ output["logit"][:, t, :] = output_t["logit"]
202
+ output["seq"][:, t] = output_t["word"]
203
+ output["sampled_logprob"][:, t] = output_t["probs"]
204
+ output["embed"][:, t, :] = output_t["embed"]
205
+
206
+ def stepwise_process(self, output):
207
+ """Postprocessing after the whole step-by-step autoregressive decoding"""
208
+ pass
209
+
210
+ def sample_next_word(self, logit, method, temp):
211
+ """Sample the next word, given probs output by the decoder"""
212
+ logprob = torch.log_softmax(logit, dim=1)
213
+ if method == "greedy":
214
+ sampled_logprob, word = torch.max(logprob.detach(), 1)
215
+ elif method == "gumbel":
216
+ def sample_gumbel(shape, eps=1e-20):
217
+ U = torch.rand(shape).to(logprob.device)
218
+ return -torch.log(-torch.log(U + eps) + eps)
219
+ def gumbel_softmax_sample(logit, temperature):
220
+ y = logit + sample_gumbel(logit.size())
221
+ return torch.log_softmax(y / temperature, dim=-1)
222
+ _logprob = gumbel_softmax_sample(logprob, temp)
223
+ _, word = torch.max(_logprob.data, 1)
224
+ sampled_logprob = logprob.gather(1, word.unsqueeze(-1))
225
+ else:
226
+ logprob = logprob / temp
227
+ if method.startswith("top"):
228
+ top_num = float(method[3:])
229
+ if 0 < top_num < 1: # top-p sampling
230
+ probs = torch.softmax(logit, dim=1)
231
+ sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=1)
232
+ _cumsum = sorted_probs.cumsum(1)
233
+ mask = _cumsum < top_num
234
+ mask = torch.cat([torch.ones_like(mask[:,:1]), mask[:,:-1]], 1)
235
+ sorted_probs = sorted_probs * mask.to(sorted_probs)
236
+ sorted_probs = sorted_probs / sorted_probs.sum(1, keepdim=True)
237
+ logprob.scatter_(1, sorted_indices, sorted_probs.log())
238
+ else: # top-k sampling
239
+ k = int(top_num)
240
+ tmp = torch.empty_like(logprob).fill_(float('-inf'))
241
+ topk, indices = torch.topk(logprob, k, dim=1)
242
+ tmp = tmp.scatter(1, indices, topk)
243
+ logprob = tmp
244
+ word = torch.distributions.Categorical(logits=logprob.detach()).sample()
245
+ sampled_logprob = logprob.gather(1, word.unsqueeze(-1)).squeeze(1)
246
+ word = word.detach().long()
247
+ # sampled_logprob: [N,], word: [N,]
248
+ return {"word": word, "probs": sampled_logprob}
249
+
250
+ def beam_search(self, input_dict):
251
+ output = self.prepare_output(input_dict)
252
+ max_length = input_dict["max_length"]
253
+ beam_size = input_dict["beam_size"]
254
+ if input_dict["n_best"]:
255
+ n_best_size = input_dict["n_best_size"]
256
+ batch_size, max_length = output["seq"].size()
257
+ output["seq"] = torch.full((batch_size, n_best_size, max_length),
258
+ self.end_idx, dtype=torch.long)
259
+
260
+ temp = input_dict["temp"]
261
+ # instance by instance beam seach
262
+ for i in range(output["seq"].size(0)):
263
+ output_i = self.prepare_beamsearch_output(input_dict)
264
+ input_dict["sample_idx"] = i
265
+ for t in range(max_length):
266
+ input_dict["t"] = t
267
+ output_t = self.beamsearch_step(input_dict, output_i)
268
+ #######################################
269
+ # merge with previous beam and select the current max prob beam
270
+ #######################################
271
+ logit_t = output_t["logit"]
272
+ if logit_t.size(1) == 1:
273
+ logit_t = logit_t.squeeze(1)
274
+ elif logit_t.size(1) > 1:
275
+ logit_t = logit_t[:, -1, :]
276
+ else:
277
+ raise Exception("no logit output")
278
+ logprob_t = torch.log_softmax(logit_t, dim=1)
279
+ logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
280
+ logprob_t = output_i["topk_logprob"].unsqueeze(1) + logprob_t
281
+ if t == 0: # for the first step, all k seq will have the same probs
282
+ topk_logprob, topk_words = logprob_t[0].topk(
283
+ beam_size, 0, True, True)
284
+ else: # unroll and find top logprob, and their unrolled indices
285
+ topk_logprob, topk_words = logprob_t.view(-1).topk(
286
+ beam_size, 0, True, True)
287
+ topk_words = topk_words.cpu()
288
+ output_i["topk_logprob"] = topk_logprob
289
+ # output_i["prev_words_beam"] = topk_words // self.vocab_size # [beam_size,]
290
+ output_i["prev_words_beam"] = torch.div(topk_words, self.vocab_size,
291
+ rounding_mode='trunc')
292
+ output_i["next_word"] = topk_words % self.vocab_size # [beam_size,]
293
+ if t == 0:
294
+ output_i["seq"] = output_i["next_word"].unsqueeze(1)
295
+ else:
296
+ output_i["seq"] = torch.cat([
297
+ output_i["seq"][output_i["prev_words_beam"]],
298
+ output_i["next_word"].unsqueeze(1)], dim=1)
299
+
300
+ # add finished beams to results
301
+ is_end = output_i["next_word"] == self.end_idx
302
+ if t == max_length - 1:
303
+ is_end.fill_(1)
304
+
305
+ for beam_idx in range(beam_size):
306
+ if is_end[beam_idx]:
307
+ final_beam = {
308
+ "seq": output_i["seq"][beam_idx].clone(),
309
+ "score": output_i["topk_logprob"][beam_idx].item()
310
+ }
311
+ final_beam["score"] = final_beam["score"] / (t + 1)
312
+ output_i["done_beams"].append(final_beam)
313
+ output_i["topk_logprob"][is_end] -= 1000
314
+
315
+ self.beamsearch_process_step(output_i, output_t)
316
+
317
+ self.beamsearch_process(output, output_i, input_dict)
318
+ return output
319
+
320
+ def prepare_beamsearch_output(self, input_dict):
321
+ beam_size = input_dict["beam_size"]
322
+ device = input_dict["fc_emb"].device
323
+ output = {
324
+ "topk_logprob": torch.zeros(beam_size).to(device),
325
+ "seq": None,
326
+ "prev_words_beam": None,
327
+ "next_word": None,
328
+ "done_beams": [],
329
+ }
330
+ return output
331
+
332
+ def beamsearch_step(self, input_dict, output_i):
333
+ decoder_input = self.prepare_beamsearch_decoder_input(input_dict, output_i)
334
+ output_t = self.decoder(decoder_input)
335
+ output_t["t"] = input_dict["t"]
336
+ return output_t
337
+
338
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
339
+ raise NotImplementedError
340
+
341
+ def beamsearch_process_step(self, output_i, output_t):
342
+ pass
343
+
344
+ def beamsearch_process(self, output, output_i, input_dict):
345
+ i = input_dict["sample_idx"]
346
+ done_beams = sorted(output_i["done_beams"], key=lambda x: -x["score"])
347
+ if input_dict["n_best"]:
348
+ done_beams = done_beams[:input_dict["n_best_size"]]
349
+ for out_idx, done_beam in enumerate(done_beams):
350
+ seq = done_beam["seq"]
351
+ output["seq"][i][out_idx, :len(seq)] = seq
352
+ else:
353
+ seq = done_beams[0]["seq"]
354
+ output["seq"][i][:len(seq)] = seq
355
+
356
+ def diverse_beam_search(self, input_dict):
357
+
358
+ def add_diversity(seq_table, logprob, t, divm, diversity_lambda, bdash):
359
+ local_time = t - divm
360
+ unaug_logprob = logprob.clone()
361
+
362
+ if divm > 0:
363
+ change = torch.zeros(logprob.size(-1))
364
+ for prev_choice in range(divm):
365
+ prev_decisions = seq_table[prev_choice][..., local_time]
366
+ for prev_labels in range(bdash):
367
+ change.scatter_add_(0, prev_decisions[prev_labels], change.new_ones(1))
368
+
369
+ change = change.to(logprob.device)
370
+ logprob = logprob - repeat_tensor(change, bdash) * diversity_lambda
371
+
372
+ return logprob, unaug_logprob
373
+
374
+ output = self.prepare_output(input_dict)
375
+ group_size = input_dict["group_size"]
376
+ batch_size = output["seq"].size(0)
377
+ beam_size = input_dict["beam_size"]
378
+ bdash = beam_size // group_size
379
+ input_dict["bdash"] = bdash
380
+ diversity_lambda = input_dict["diversity_lambda"]
381
+ device = input_dict["fc_emb"].device
382
+ max_length = input_dict["max_length"]
383
+ temp = input_dict["temp"]
384
+ group_nbest = input_dict["group_nbest"]
385
+ batch_size, max_length = output["seq"].size()
386
+ if group_nbest:
387
+ output["seq"] = torch.full((batch_size, beam_size, max_length),
388
+ self.end_idx, dtype=torch.long)
389
+ else:
390
+ output["seq"] = torch.full((batch_size, group_size, max_length),
391
+ self.end_idx, dtype=torch.long)
392
+
393
+
394
+ for i in range(batch_size):
395
+ input_dict["sample_idx"] = i
396
+ seq_table = [torch.LongTensor(bdash, 0) for _ in range(group_size)] # group_size x [bdash, 0]
397
+ logprob_table = [torch.zeros(bdash).to(device) for _ in range(group_size)]
398
+ done_beams_table = [[] for _ in range(group_size)]
399
+
400
+ output_i = {
401
+ "prev_words_beam": [None for _ in range(group_size)],
402
+ "next_word": [None for _ in range(group_size)],
403
+ "state": [None for _ in range(group_size)]
404
+ }
405
+
406
+ for t in range(max_length + group_size - 1):
407
+ input_dict["t"] = t
408
+ for divm in range(group_size):
409
+ input_dict["divm"] = divm
410
+ if t >= divm and t <= max_length + divm - 1:
411
+ local_time = t - divm
412
+ decoder_input = self.prepare_dbs_decoder_input(input_dict, output_i)
413
+ output_t = self.decoder(decoder_input)
414
+ output_t["divm"] = divm
415
+ logit_t = output_t["logit"]
416
+ if logit_t.size(1) == 1:
417
+ logit_t = logit_t.squeeze(1)
418
+ elif logit_t.size(1) > 1:
419
+ logit_t = logit_t[:, -1, :]
420
+ else:
421
+ raise Exception("no logit output")
422
+ logprob_t = torch.log_softmax(logit_t, dim=1)
423
+ logprob_t = torch.log_softmax(logprob_t / temp, dim=1)
424
+ logprob_t, unaug_logprob_t = add_diversity(seq_table, logprob_t, t, divm, diversity_lambda, bdash)
425
+ logprob_t = logprob_table[divm].unsqueeze(-1) + logprob_t
426
+ if local_time == 0: # for the first step, all k seq will have the same probs
427
+ topk_logprob, topk_words = logprob_t[0].topk(
428
+ bdash, 0, True, True)
429
+ else: # unroll and find top logprob, and their unrolled indices
430
+ topk_logprob, topk_words = logprob_t.view(-1).topk(
431
+ bdash, 0, True, True)
432
+ topk_words = topk_words.cpu()
433
+ logprob_table[divm] = topk_logprob
434
+ output_i["prev_words_beam"][divm] = topk_words // self.vocab_size # [bdash,]
435
+ output_i["next_word"][divm] = topk_words % self.vocab_size # [bdash,]
436
+ if local_time > 0:
437
+ seq_table[divm] = seq_table[divm][output_i["prev_words_beam"][divm]]
438
+ seq_table[divm] = torch.cat([
439
+ seq_table[divm],
440
+ output_i["next_word"][divm].unsqueeze(-1)], -1)
441
+
442
+ is_end = seq_table[divm][:, t-divm] == self.end_idx
443
+ assert seq_table[divm].shape[-1] == t - divm + 1
444
+ if t == max_length + divm - 1:
445
+ is_end.fill_(1)
446
+ for beam_idx in range(bdash):
447
+ if is_end[beam_idx]:
448
+ final_beam = {
449
+ "seq": seq_table[divm][beam_idx].clone(),
450
+ "score": logprob_table[divm][beam_idx].item()
451
+ }
452
+ final_beam["score"] = final_beam["score"] / (t - divm + 1)
453
+ done_beams_table[divm].append(final_beam)
454
+ logprob_table[divm][is_end] -= 1000
455
+ self.dbs_process_step(output_i, output_t)
456
+ done_beams_table = [sorted(done_beams_table[divm], key=lambda x: -x["score"])[:bdash] for divm in range(group_size)]
457
+ if group_nbest:
458
+ done_beams = sum(done_beams_table, [])
459
+ else:
460
+ done_beams = [group_beam[0] for group_beam in done_beams_table]
461
+ for _, done_beam in enumerate(done_beams):
462
+ output["seq"][i, _, :len(done_beam["seq"])] = done_beam["seq"]
463
+
464
+ return output
465
+
466
+ def prepare_dbs_decoder_input(self, input_dict, output_i):
467
+ raise NotImplementedError
468
+
469
+ def dbs_process_step(self, output_i, output_t):
470
+ pass
471
+
472
+
473
+ class CaptionSequenceModel(nn.Module):
474
+
475
+ def __init__(self, model, seq_output_size):
476
+ super().__init__()
477
+ self.model = model
478
+ if model.decoder.d_model != seq_output_size:
479
+ self.output_transform = nn.Linear(model.decoder.d_model, seq_output_size)
480
+ else:
481
+ self.output_transform = lambda x: x
482
+
483
+ def forward(self, input_dict):
484
+ output = self.model(input_dict)
485
+
486
+ if input_dict["mode"] == "train":
487
+ lens = input_dict["cap_len"] - 1
488
+ # seq_outputs: [N, d_model]
489
+ elif input_dict["mode"] == "inference":
490
+ if "sample_method" in input_dict and input_dict["sample_method"] == "beam":
491
+ return output
492
+ seq = output["seq"]
493
+ lens = torch.where(seq == self.model.end_idx, torch.zeros_like(seq), torch.ones_like(seq)).sum(dim=1)
494
+ else:
495
+ raise Exception("mode should be either 'train' or 'inference'")
496
+ seq_output = mean_with_lens(output["embed"], lens)
497
+ seq_output = self.output_transform(seq_output)
498
+ output["seq_output"] = seq_output
499
+ return output
500
+
audio_to_text/captioning/models/decoder.py ADDED
@@ -0,0 +1,746 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import math
4
+ from functools import partial
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from .utils import generate_length_mask, init, PositionalEncoding
11
+
12
+
13
+ class BaseDecoder(nn.Module):
14
+ """
15
+ Take word/audio embeddings and output the next word probs
16
+ Base decoder, cannot be called directly
17
+ All decoders should inherit from this class
18
+ """
19
+
20
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim,
21
+ attn_emb_dim, dropout=0.2):
22
+ super().__init__()
23
+ self.emb_dim = emb_dim
24
+ self.vocab_size = vocab_size
25
+ self.fc_emb_dim = fc_emb_dim
26
+ self.attn_emb_dim = attn_emb_dim
27
+ self.word_embedding = nn.Embedding(vocab_size, emb_dim)
28
+ self.in_dropout = nn.Dropout(dropout)
29
+
30
+ def forward(self, x):
31
+ raise NotImplementedError
32
+
33
+ def load_word_embedding(self, weight, freeze=True):
34
+ embedding = np.load(weight)
35
+ assert embedding.shape[0] == self.vocab_size, "vocabulary size mismatch"
36
+ assert embedding.shape[1] == self.emb_dim, "embed size mismatch"
37
+
38
+ # embeddings = torch.as_tensor(embeddings).float()
39
+ # self.word_embeddings.weight = nn.Parameter(embeddings)
40
+ # for para in self.word_embeddings.parameters():
41
+ # para.requires_grad = tune
42
+ self.word_embedding = nn.Embedding.from_pretrained(embedding,
43
+ freeze=freeze)
44
+
45
+
46
+ class RnnDecoder(BaseDecoder):
47
+
48
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
49
+ dropout, d_model, **kwargs):
50
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
51
+ dropout,)
52
+ self.d_model = d_model
53
+ self.num_layers = kwargs.get('num_layers', 1)
54
+ self.bidirectional = kwargs.get('bidirectional', False)
55
+ self.rnn_type = kwargs.get('rnn_type', "GRU")
56
+ self.classifier = nn.Linear(
57
+ self.d_model * (self.bidirectional + 1), vocab_size)
58
+
59
+ def forward(self, x):
60
+ raise NotImplementedError
61
+
62
+ def init_hidden(self, bs, device):
63
+ num_dire = self.bidirectional + 1
64
+ n_layer = self.num_layers
65
+ hid_dim = self.d_model
66
+ if self.rnn_type == "LSTM":
67
+ return (torch.zeros(num_dire * n_layer, bs, hid_dim).to(device),
68
+ torch.zeros(num_dire * n_layer, bs, hid_dim).to(device))
69
+ else:
70
+ return torch.zeros(num_dire * n_layer, bs, hid_dim).to(device)
71
+
72
+
73
+ class RnnFcDecoder(RnnDecoder):
74
+
75
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout, d_model, **kwargs):
76
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout, d_model, **kwargs)
77
+ self.model = getattr(nn, self.rnn_type)(
78
+ input_size=self.emb_dim * 2,
79
+ hidden_size=self.d_model,
80
+ batch_first=True,
81
+ num_layers=self.num_layers,
82
+ bidirectional=self.bidirectional)
83
+ self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
84
+ self.apply(init)
85
+
86
+ def forward(self, input_dict):
87
+ word = input_dict["word"]
88
+ state = input_dict.get("state", None)
89
+ fc_emb = input_dict["fc_emb"]
90
+
91
+ word = word.to(fc_emb.device)
92
+ embed = self.in_dropout(self.word_embedding(word))
93
+
94
+ p_fc_emb = self.fc_proj(fc_emb)
95
+ # embed: [N, T, embed_size]
96
+ embed = torch.cat((embed, p_fc_emb), dim=-1)
97
+
98
+ out, state = self.model(embed, state)
99
+ # out: [N, T, hs], states: [num_layers * num_dire, N, hs]
100
+ logits = self.classifier(out)
101
+ output = {
102
+ "state": state,
103
+ "embeds": out,
104
+ "logits": logits
105
+ }
106
+
107
+ return output
108
+
109
+
110
+ class Seq2SeqAttention(nn.Module):
111
+
112
+ def __init__(self, hs_enc, hs_dec, attn_size):
113
+ """
114
+ Args:
115
+ hs_enc: encoder hidden size
116
+ hs_dec: decoder hidden size
117
+ attn_size: attention vector size
118
+ """
119
+ super(Seq2SeqAttention, self).__init__()
120
+ self.h2attn = nn.Linear(hs_enc + hs_dec, attn_size)
121
+ self.v = nn.Parameter(torch.randn(attn_size))
122
+ self.apply(init)
123
+
124
+ def forward(self, h_dec, h_enc, src_lens):
125
+ """
126
+ Args:
127
+ h_dec: decoder hidden (query), [N, hs_dec]
128
+ h_enc: encoder memory (key/value), [N, src_max_len, hs_enc]
129
+ src_lens: source (encoder memory) lengths, [N, ]
130
+ """
131
+ N = h_enc.size(0)
132
+ src_max_len = h_enc.size(1)
133
+ h_dec = h_dec.unsqueeze(1).repeat(1, src_max_len, 1) # [N, src_max_len, hs_dec]
134
+
135
+ attn_input = torch.cat((h_dec, h_enc), dim=-1)
136
+ attn_out = torch.tanh(self.h2attn(attn_input)) # [N, src_max_len, attn_size]
137
+
138
+ v = self.v.repeat(N, 1).unsqueeze(1) # [N, 1, attn_size]
139
+ score = torch.bmm(v, attn_out.transpose(1, 2)).squeeze(1) # [N, src_max_len]
140
+
141
+ idxs = torch.arange(src_max_len).repeat(N).view(N, src_max_len)
142
+ mask = (idxs < src_lens.view(-1, 1)).to(h_dec.device)
143
+
144
+ score = score.masked_fill(mask == 0, -1e10)
145
+ weights = torch.softmax(score, dim=-1) # [N, src_max_len]
146
+ ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1) # [N, hs_enc]
147
+
148
+ return ctx, weights
149
+
150
+
151
+ class AttentionProj(nn.Module):
152
+
153
+ def __init__(self, hs_enc, hs_dec, embed_dim, attn_size):
154
+ self.q_proj = nn.Linear(hs_dec, embed_dim)
155
+ self.kv_proj = nn.Linear(hs_enc, embed_dim)
156
+ self.h2attn = nn.Linear(embed_dim * 2, attn_size)
157
+ self.v = nn.Parameter(torch.randn(attn_size))
158
+ self.apply(init)
159
+
160
+ def init(self, m):
161
+ if isinstance(m, nn.Linear):
162
+ nn.init.kaiming_uniform_(m.weight)
163
+ if m.bias is not None:
164
+ nn.init.constant_(m.bias, 0)
165
+
166
+ def forward(self, h_dec, h_enc, src_lens):
167
+ """
168
+ Args:
169
+ h_dec: decoder hidden (query), [N, hs_dec]
170
+ h_enc: encoder memory (key/value), [N, src_max_len, hs_enc]
171
+ src_lens: source (encoder memory) lengths, [N, ]
172
+ """
173
+ h_enc = self.kv_proj(h_enc) # [N, src_max_len, embed_dim]
174
+ h_dec = self.q_proj(h_dec) # [N, embed_dim]
175
+ N = h_enc.size(0)
176
+ src_max_len = h_enc.size(1)
177
+ h_dec = h_dec.unsqueeze(1).repeat(1, src_max_len, 1) # [N, src_max_len, hs_dec]
178
+
179
+ attn_input = torch.cat((h_dec, h_enc), dim=-1)
180
+ attn_out = torch.tanh(self.h2attn(attn_input)) # [N, src_max_len, attn_size]
181
+
182
+ v = self.v.repeat(N, 1).unsqueeze(1) # [N, 1, attn_size]
183
+ score = torch.bmm(v, attn_out.transpose(1, 2)).squeeze(1) # [N, src_max_len]
184
+
185
+ idxs = torch.arange(src_max_len).repeat(N).view(N, src_max_len)
186
+ mask = (idxs < src_lens.view(-1, 1)).to(h_dec.device)
187
+
188
+ score = score.masked_fill(mask == 0, -1e10)
189
+ weights = torch.softmax(score, dim=-1) # [N, src_max_len]
190
+ ctx = torch.bmm(weights.unsqueeze(1), h_enc).squeeze(1) # [N, hs_enc]
191
+
192
+ return ctx, weights
193
+
194
+
195
+ class BahAttnDecoder(RnnDecoder):
196
+
197
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
198
+ dropout, d_model, **kwargs):
199
+ """
200
+ concatenate fc, attn, word to feed to the rnn
201
+ """
202
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
203
+ dropout, d_model, **kwargs)
204
+ attn_size = kwargs.get("attn_size", self.d_model)
205
+ self.model = getattr(nn, self.rnn_type)(
206
+ input_size=self.emb_dim * 3,
207
+ hidden_size=self.d_model,
208
+ batch_first=True,
209
+ num_layers=self.num_layers,
210
+ bidirectional=self.bidirectional)
211
+ self.attn = Seq2SeqAttention(self.attn_emb_dim,
212
+ self.d_model * (self.bidirectional + 1) * \
213
+ self.num_layers,
214
+ attn_size)
215
+ self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
216
+ self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
217
+ self.apply(init)
218
+
219
+ def forward(self, input_dict):
220
+ word = input_dict["word"]
221
+ state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
222
+ fc_emb = input_dict["fc_emb"]
223
+ attn_emb = input_dict["attn_emb"]
224
+ attn_emb_len = input_dict["attn_emb_len"]
225
+
226
+ word = word.to(fc_emb.device)
227
+ embed = self.in_dropout(self.word_embedding(word))
228
+
229
+ # embed: [N, 1, embed_size]
230
+ if state is None:
231
+ state = self.init_hidden(word.size(0), fc_emb.device)
232
+ if self.rnn_type == "LSTM":
233
+ query = state[0].transpose(0, 1).flatten(1)
234
+ else:
235
+ query = state.transpose(0, 1).flatten(1)
236
+ c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
237
+
238
+ p_fc_emb = self.fc_proj(fc_emb)
239
+ p_ctx = self.ctx_proj(c)
240
+ rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), p_fc_emb.unsqueeze(1)),
241
+ dim=-1)
242
+
243
+ out, state = self.model(rnn_input, state)
244
+
245
+ output = {
246
+ "state": state,
247
+ "embed": out,
248
+ "logit": self.classifier(out),
249
+ "attn_weight": attn_weight
250
+ }
251
+ return output
252
+
253
+
254
+ class BahAttnDecoder2(RnnDecoder):
255
+
256
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
257
+ dropout, d_model, **kwargs):
258
+ """
259
+ add fc, attn, word together to feed to the rnn
260
+ """
261
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
262
+ dropout, d_model, **kwargs)
263
+ attn_size = kwargs.get("attn_size", self.d_model)
264
+ self.model = getattr(nn, self.rnn_type)(
265
+ input_size=self.emb_dim,
266
+ hidden_size=self.d_model,
267
+ batch_first=True,
268
+ num_layers=self.num_layers,
269
+ bidirectional=self.bidirectional)
270
+ self.attn = Seq2SeqAttention(self.emb_dim,
271
+ self.d_model * (self.bidirectional + 1) * \
272
+ self.num_layers,
273
+ attn_size)
274
+ self.fc_proj = nn.Linear(self.fc_emb_dim, self.emb_dim)
275
+ self.attn_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
276
+ self.apply(partial(init, method="xavier"))
277
+
278
+ def forward(self, input_dict):
279
+ word = input_dict["word"]
280
+ state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
281
+ fc_emb = input_dict["fc_emb"]
282
+ attn_emb = input_dict["attn_emb"]
283
+ attn_emb_len = input_dict["attn_emb_len"]
284
+
285
+ word = word.to(fc_emb.device)
286
+ embed = self.in_dropout(self.word_embedding(word))
287
+ p_attn_emb = self.attn_proj(attn_emb)
288
+
289
+ # embed: [N, 1, embed_size]
290
+ if state is None:
291
+ state = self.init_hidden(word.size(0), fc_emb.device)
292
+ if self.rnn_type == "LSTM":
293
+ query = state[0].transpose(0, 1).flatten(1)
294
+ else:
295
+ query = state.transpose(0, 1).flatten(1)
296
+ c, attn_weight = self.attn(query, p_attn_emb, attn_emb_len)
297
+
298
+ p_fc_emb = self.fc_proj(fc_emb)
299
+ rnn_input = embed + c.unsqueeze(1) + p_fc_emb.unsqueeze(1)
300
+
301
+ out, state = self.model(rnn_input, state)
302
+
303
+ output = {
304
+ "state": state,
305
+ "embed": out,
306
+ "logit": self.classifier(out),
307
+ "attn_weight": attn_weight
308
+ }
309
+ return output
310
+
311
+
312
+ class ConditionalBahAttnDecoder(RnnDecoder):
313
+
314
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
315
+ dropout, d_model, **kwargs):
316
+ """
317
+ concatenate fc, attn, word to feed to the rnn
318
+ """
319
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
320
+ dropout, d_model, **kwargs)
321
+ attn_size = kwargs.get("attn_size", self.d_model)
322
+ self.model = getattr(nn, self.rnn_type)(
323
+ input_size=self.emb_dim * 3,
324
+ hidden_size=self.d_model,
325
+ batch_first=True,
326
+ num_layers=self.num_layers,
327
+ bidirectional=self.bidirectional)
328
+ self.attn = Seq2SeqAttention(self.attn_emb_dim,
329
+ self.d_model * (self.bidirectional + 1) * \
330
+ self.num_layers,
331
+ attn_size)
332
+ self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
333
+ self.condition_embedding = nn.Embedding(2, emb_dim)
334
+ self.apply(init)
335
+
336
+ def forward(self, input_dict):
337
+ word = input_dict["word"]
338
+ state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
339
+ fc_emb = input_dict["fc_emb"]
340
+ attn_emb = input_dict["attn_emb"]
341
+ attn_emb_len = input_dict["attn_emb_len"]
342
+ condition = input_dict["condition"]
343
+
344
+ word = word.to(fc_emb.device)
345
+ embed = self.in_dropout(self.word_embedding(word))
346
+
347
+ condition = torch.as_tensor([[1 - c, c] for c in condition]).to(fc_emb.device)
348
+ condition_emb = torch.matmul(condition, self.condition_embedding.weight)
349
+ # condition_embs: [N, emb_dim]
350
+
351
+ # embed: [N, 1, embed_size]
352
+ if state is None:
353
+ state = self.init_hidden(word.size(0), fc_emb.device)
354
+ if self.rnn_type == "LSTM":
355
+ query = state[0].transpose(0, 1).flatten(1)
356
+ else:
357
+ query = state.transpose(0, 1).flatten(1)
358
+ c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
359
+
360
+ p_ctx = self.ctx_proj(c)
361
+ rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), condition_emb.unsqueeze(1)),
362
+ dim=-1)
363
+
364
+ out, state = self.model(rnn_input, state)
365
+
366
+ output = {
367
+ "state": state,
368
+ "embed": out,
369
+ "logit": self.classifier(out),
370
+ "attn_weight": attn_weight
371
+ }
372
+ return output
373
+
374
+
375
+ class StructBahAttnDecoder(RnnDecoder):
376
+
377
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, struct_vocab_size,
378
+ attn_emb_dim, dropout, d_model, **kwargs):
379
+ """
380
+ concatenate fc, attn, word to feed to the rnn
381
+ """
382
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
383
+ dropout, d_model, **kwargs)
384
+ attn_size = kwargs.get("attn_size", self.d_model)
385
+ self.model = getattr(nn, self.rnn_type)(
386
+ input_size=self.emb_dim * 3,
387
+ hidden_size=self.d_model,
388
+ batch_first=True,
389
+ num_layers=self.num_layers,
390
+ bidirectional=self.bidirectional)
391
+ self.attn = Seq2SeqAttention(self.attn_emb_dim,
392
+ self.d_model * (self.bidirectional + 1) * \
393
+ self.num_layers,
394
+ attn_size)
395
+ self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
396
+ self.struct_embedding = nn.Embedding(struct_vocab_size, emb_dim)
397
+ self.apply(init)
398
+
399
+ def forward(self, input_dict):
400
+ word = input_dict["word"]
401
+ state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
402
+ fc_emb = input_dict["fc_emb"]
403
+ attn_emb = input_dict["attn_emb"]
404
+ attn_emb_len = input_dict["attn_emb_len"]
405
+ structure = input_dict["structure"]
406
+
407
+ word = word.to(fc_emb.device)
408
+ embed = self.in_dropout(self.word_embedding(word))
409
+
410
+ struct_emb = self.struct_embedding(structure)
411
+ # struct_embs: [N, emb_dim]
412
+
413
+ # embed: [N, 1, embed_size]
414
+ if state is None:
415
+ state = self.init_hidden(word.size(0), fc_emb.device)
416
+ if self.rnn_type == "LSTM":
417
+ query = state[0].transpose(0, 1).flatten(1)
418
+ else:
419
+ query = state.transpose(0, 1).flatten(1)
420
+ c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
421
+
422
+ p_ctx = self.ctx_proj(c)
423
+ rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), struct_emb.unsqueeze(1)), dim=-1)
424
+
425
+ out, state = self.model(rnn_input, state)
426
+
427
+ output = {
428
+ "state": state,
429
+ "embed": out,
430
+ "logit": self.classifier(out),
431
+ "attn_weight": attn_weight
432
+ }
433
+ return output
434
+
435
+
436
+ class StyleBahAttnDecoder(RnnDecoder):
437
+
438
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
439
+ dropout, d_model, **kwargs):
440
+ """
441
+ concatenate fc, attn, word to feed to the rnn
442
+ """
443
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
444
+ dropout, d_model, **kwargs)
445
+ attn_size = kwargs.get("attn_size", self.d_model)
446
+ self.model = getattr(nn, self.rnn_type)(
447
+ input_size=self.emb_dim * 3,
448
+ hidden_size=self.d_model,
449
+ batch_first=True,
450
+ num_layers=self.num_layers,
451
+ bidirectional=self.bidirectional)
452
+ self.attn = Seq2SeqAttention(self.attn_emb_dim,
453
+ self.d_model * (self.bidirectional + 1) * \
454
+ self.num_layers,
455
+ attn_size)
456
+ self.ctx_proj = nn.Linear(self.attn_emb_dim, self.emb_dim)
457
+ self.apply(init)
458
+
459
+ def forward(self, input_dict):
460
+ word = input_dict["word"]
461
+ state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
462
+ fc_emb = input_dict["fc_emb"]
463
+ attn_emb = input_dict["attn_emb"]
464
+ attn_emb_len = input_dict["attn_emb_len"]
465
+ style = input_dict["style"]
466
+
467
+ word = word.to(fc_emb.device)
468
+ embed = self.in_dropout(self.word_embedding(word))
469
+
470
+ # embed: [N, 1, embed_size]
471
+ if state is None:
472
+ state = self.init_hidden(word.size(0), fc_emb.device)
473
+ if self.rnn_type == "LSTM":
474
+ query = state[0].transpose(0, 1).flatten(1)
475
+ else:
476
+ query = state.transpose(0, 1).flatten(1)
477
+ c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
478
+
479
+ p_ctx = self.ctx_proj(c)
480
+ rnn_input = torch.cat((embed, p_ctx.unsqueeze(1), style.unsqueeze(1)),
481
+ dim=-1)
482
+
483
+ out, state = self.model(rnn_input, state)
484
+
485
+ output = {
486
+ "state": state,
487
+ "embed": out,
488
+ "logit": self.classifier(out),
489
+ "attn_weight": attn_weight
490
+ }
491
+ return output
492
+
493
+
494
+ class BahAttnDecoder3(RnnDecoder):
495
+
496
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
497
+ dropout, d_model, **kwargs):
498
+ """
499
+ concatenate fc, attn, word to feed to the rnn
500
+ """
501
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
502
+ dropout, d_model, **kwargs)
503
+ attn_size = kwargs.get("attn_size", self.d_model)
504
+ self.model = getattr(nn, self.rnn_type)(
505
+ input_size=self.emb_dim + attn_emb_dim,
506
+ hidden_size=self.d_model,
507
+ batch_first=True,
508
+ num_layers=self.num_layers,
509
+ bidirectional=self.bidirectional)
510
+ self.attn = Seq2SeqAttention(self.attn_emb_dim,
511
+ self.d_model * (self.bidirectional + 1) * \
512
+ self.num_layers,
513
+ attn_size)
514
+ self.ctx_proj = lambda x: x
515
+ self.apply(init)
516
+
517
+ def forward(self, input_dict):
518
+ word = input_dict["word"]
519
+ state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
520
+ fc_emb = input_dict["fc_emb"]
521
+ attn_emb = input_dict["attn_emb"]
522
+ attn_emb_len = input_dict["attn_emb_len"]
523
+
524
+ if word.size(-1) == self.fc_emb_dim: # fc_emb
525
+ embed = word.unsqueeze(1)
526
+ elif word.size(-1) == 1: # word
527
+ word = word.to(fc_emb.device)
528
+ embed = self.in_dropout(self.word_embedding(word))
529
+ else:
530
+ raise Exception(f"problem with word input size {word.size()}")
531
+
532
+ # embed: [N, 1, embed_size]
533
+ if state is None:
534
+ state = self.init_hidden(word.size(0), fc_emb.device)
535
+ if self.rnn_type == "LSTM":
536
+ query = state[0].transpose(0, 1).flatten(1)
537
+ else:
538
+ query = state.transpose(0, 1).flatten(1)
539
+ c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
540
+
541
+ p_ctx = self.ctx_proj(c)
542
+ rnn_input = torch.cat((embed, p_ctx.unsqueeze(1)), dim=-1)
543
+
544
+ out, state = self.model(rnn_input, state)
545
+
546
+ output = {
547
+ "state": state,
548
+ "embed": out,
549
+ "logit": self.classifier(out),
550
+ "attn_weight": attn_weight
551
+ }
552
+ return output
553
+
554
+
555
+ class SpecificityBahAttnDecoder(RnnDecoder):
556
+
557
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
558
+ dropout, d_model, **kwargs):
559
+ """
560
+ concatenate fc, attn, word to feed to the rnn
561
+ """
562
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
563
+ dropout, d_model, **kwargs)
564
+ attn_size = kwargs.get("attn_size", self.d_model)
565
+ self.model = getattr(nn, self.rnn_type)(
566
+ input_size=self.emb_dim + attn_emb_dim + 1,
567
+ hidden_size=self.d_model,
568
+ batch_first=True,
569
+ num_layers=self.num_layers,
570
+ bidirectional=self.bidirectional)
571
+ self.attn = Seq2SeqAttention(self.attn_emb_dim,
572
+ self.d_model * (self.bidirectional + 1) * \
573
+ self.num_layers,
574
+ attn_size)
575
+ self.ctx_proj = lambda x: x
576
+ self.apply(init)
577
+
578
+ def forward(self, input_dict):
579
+ word = input_dict["word"]
580
+ state = input_dict.get("state", None) # [n_layer * n_dire, bs, d_model]
581
+ fc_emb = input_dict["fc_emb"]
582
+ attn_emb = input_dict["attn_emb"]
583
+ attn_emb_len = input_dict["attn_emb_len"]
584
+ condition = input_dict["condition"] # [N,]
585
+
586
+ word = word.to(fc_emb.device)
587
+ embed = self.in_dropout(self.word_embedding(word))
588
+
589
+ # embed: [N, 1, embed_size]
590
+ if state is None:
591
+ state = self.init_hidden(word.size(0), fc_emb.device)
592
+ if self.rnn_type == "LSTM":
593
+ query = state[0].transpose(0, 1).flatten(1)
594
+ else:
595
+ query = state.transpose(0, 1).flatten(1)
596
+ c, attn_weight = self.attn(query, attn_emb, attn_emb_len)
597
+
598
+ p_ctx = self.ctx_proj(c)
599
+ rnn_input = torch.cat(
600
+ (embed, p_ctx.unsqueeze(1), condition.reshape(-1, 1, 1)),
601
+ dim=-1)
602
+
603
+ out, state = self.model(rnn_input, state)
604
+
605
+ output = {
606
+ "state": state,
607
+ "embed": out,
608
+ "logit": self.classifier(out),
609
+ "attn_weight": attn_weight
610
+ }
611
+ return output
612
+
613
+
614
+ class TransformerDecoder(BaseDecoder):
615
+
616
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim, dropout, **kwargs):
617
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
618
+ dropout=dropout,)
619
+ self.d_model = emb_dim
620
+ self.nhead = kwargs.get("nhead", self.d_model // 64)
621
+ self.nlayers = kwargs.get("nlayers", 2)
622
+ self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
623
+
624
+ self.pos_encoder = PositionalEncoding(self.d_model, dropout)
625
+ layer = nn.TransformerDecoderLayer(d_model=self.d_model,
626
+ nhead=self.nhead,
627
+ dim_feedforward=self.dim_feedforward,
628
+ dropout=dropout)
629
+ self.model = nn.TransformerDecoder(layer, self.nlayers)
630
+ self.classifier = nn.Linear(self.d_model, vocab_size)
631
+ self.attn_proj = nn.Sequential(
632
+ nn.Linear(self.attn_emb_dim, self.d_model),
633
+ nn.ReLU(),
634
+ nn.Dropout(dropout),
635
+ nn.LayerNorm(self.d_model)
636
+ )
637
+ # self.attn_proj = lambda x: x
638
+ self.init_params()
639
+
640
+ def init_params(self):
641
+ for p in self.parameters():
642
+ if p.dim() > 1:
643
+ nn.init.xavier_uniform_(p)
644
+
645
+ def generate_square_subsequent_mask(self, max_length):
646
+ mask = (torch.triu(torch.ones(max_length, max_length)) == 1).transpose(0, 1)
647
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
648
+ return mask
649
+
650
+ def forward(self, input_dict):
651
+ word = input_dict["word"]
652
+ attn_emb = input_dict["attn_emb"]
653
+ attn_emb_len = input_dict["attn_emb_len"]
654
+ cap_padding_mask = input_dict["cap_padding_mask"]
655
+
656
+ p_attn_emb = self.attn_proj(attn_emb)
657
+ p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
658
+ word = word.to(attn_emb.device)
659
+ embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
660
+ embed = embed.transpose(0, 1) # [T, N, emb_dim]
661
+ embed = self.pos_encoder(embed)
662
+
663
+ tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
664
+ memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
665
+ output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
666
+ tgt_key_padding_mask=cap_padding_mask,
667
+ memory_key_padding_mask=memory_key_padding_mask)
668
+ output = output.transpose(0, 1)
669
+ output = {
670
+ "embed": output,
671
+ "logit": self.classifier(output),
672
+ }
673
+ return output
674
+
675
+
676
+
677
+
678
+ class EventTransformerDecoder(TransformerDecoder):
679
+
680
+ def forward(self, input_dict):
681
+ word = input_dict["word"] # index of word embeddings
682
+ attn_emb = input_dict["attn_emb"]
683
+ attn_emb_len = input_dict["attn_emb_len"]
684
+ cap_padding_mask = input_dict["cap_padding_mask"]
685
+ event_emb = input_dict["event"] # [N, emb_dim]
686
+
687
+ p_attn_emb = self.attn_proj(attn_emb)
688
+ p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
689
+ word = word.to(attn_emb.device)
690
+ embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
691
+
692
+ embed = embed.transpose(0, 1) # [T, N, emb_dim]
693
+ embed += event_emb
694
+ embed = self.pos_encoder(embed)
695
+
696
+ tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
697
+ memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
698
+ output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
699
+ tgt_key_padding_mask=cap_padding_mask,
700
+ memory_key_padding_mask=memory_key_padding_mask)
701
+ output = output.transpose(0, 1)
702
+ output = {
703
+ "embed": output,
704
+ "logit": self.classifier(output),
705
+ }
706
+ return output
707
+
708
+
709
+ class KeywordProbTransformerDecoder(TransformerDecoder):
710
+
711
+ def __init__(self, emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
712
+ dropout, keyword_classes_num, **kwargs):
713
+ super().__init__(emb_dim, vocab_size, fc_emb_dim, attn_emb_dim,
714
+ dropout, **kwargs)
715
+ self.keyword_proj = nn.Linear(keyword_classes_num, self.d_model)
716
+ self.word_keyword_norm = nn.LayerNorm(self.d_model)
717
+
718
+ def forward(self, input_dict):
719
+ word = input_dict["word"] # index of word embeddings
720
+ attn_emb = input_dict["attn_emb"]
721
+ attn_emb_len = input_dict["attn_emb_len"]
722
+ cap_padding_mask = input_dict["cap_padding_mask"]
723
+ keyword = input_dict["keyword"] # [N, keyword_classes_num]
724
+
725
+ p_attn_emb = self.attn_proj(attn_emb)
726
+ p_attn_emb = p_attn_emb.transpose(0, 1) # [T_src, N, emb_dim]
727
+ word = word.to(attn_emb.device)
728
+ embed = self.in_dropout(self.word_embedding(word)) * math.sqrt(self.emb_dim) # [N, T, emb_dim]
729
+
730
+ embed = embed.transpose(0, 1) # [T, N, emb_dim]
731
+ embed += self.keyword_proj(keyword)
732
+ embed = self.word_keyword_norm(embed)
733
+
734
+ embed = self.pos_encoder(embed)
735
+
736
+ tgt_mask = self.generate_square_subsequent_mask(embed.size(0)).to(attn_emb.device)
737
+ memory_key_padding_mask = ~generate_length_mask(attn_emb_len, attn_emb.size(1)).to(attn_emb.device)
738
+ output = self.model(embed, p_attn_emb, tgt_mask=tgt_mask,
739
+ tgt_key_padding_mask=cap_padding_mask,
740
+ memory_key_padding_mask=memory_key_padding_mask)
741
+ output = output.transpose(0, 1)
742
+ output = {
743
+ "embed": output,
744
+ "logit": self.classifier(output),
745
+ }
746
+ return output
audio_to_text/captioning/models/encoder.py ADDED
@@ -0,0 +1,686 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import math
4
+ import copy
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torchaudio import transforms
10
+ from torchlibrosa.augmentation import SpecAugmentation
11
+
12
+ from .utils import mean_with_lens, max_with_lens, \
13
+ init, pack_wrapper, generate_length_mask, PositionalEncoding
14
+
15
+
16
+ def init_layer(layer):
17
+ """Initialize a Linear or Convolutional layer. """
18
+ nn.init.xavier_uniform_(layer.weight)
19
+
20
+ if hasattr(layer, 'bias'):
21
+ if layer.bias is not None:
22
+ layer.bias.data.fill_(0.)
23
+
24
+
25
+ def init_bn(bn):
26
+ """Initialize a Batchnorm layer. """
27
+ bn.bias.data.fill_(0.)
28
+ bn.weight.data.fill_(1.)
29
+
30
+
31
+ class BaseEncoder(nn.Module):
32
+
33
+ """
34
+ Encode the given audio into embedding
35
+ Base encoder class, cannot be called directly
36
+ All encoders should inherit from this class
37
+ """
38
+
39
+ def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim):
40
+ super(BaseEncoder, self).__init__()
41
+ self.spec_dim = spec_dim
42
+ self.fc_feat_dim = fc_feat_dim
43
+ self.attn_feat_dim = attn_feat_dim
44
+
45
+
46
+ def forward(self, x):
47
+ #########################
48
+ # an encoder first encodes audio feature into embedding, obtaining
49
+ # `encoded`: {
50
+ # fc_embs: [N, fc_emb_dim],
51
+ # attn_embs: [N, attn_max_len, attn_emb_dim],
52
+ # attn_emb_lens: [N,]
53
+ # }
54
+ #########################
55
+ raise NotImplementedError
56
+
57
+
58
+ class Block2D(nn.Module):
59
+
60
+ def __init__(self, cin, cout, kernel_size=3, padding=1):
61
+ super().__init__()
62
+ self.block = nn.Sequential(
63
+ nn.BatchNorm2d(cin),
64
+ nn.Conv2d(cin,
65
+ cout,
66
+ kernel_size=kernel_size,
67
+ padding=padding,
68
+ bias=False),
69
+ nn.LeakyReLU(inplace=True, negative_slope=0.1))
70
+
71
+ def forward(self, x):
72
+ return self.block(x)
73
+
74
+
75
+ class LinearSoftPool(nn.Module):
76
+ """LinearSoftPool
77
+ Linear softmax, takes logits and returns a probability, near to the actual maximum value.
78
+ Taken from the paper:
79
+ A Comparison of Five Multiple Instance Learning Pooling Functions for Sound Event Detection with Weak Labeling
80
+ https://arxiv.org/abs/1810.09050
81
+ """
82
+ def __init__(self, pooldim=1):
83
+ super().__init__()
84
+ self.pooldim = pooldim
85
+
86
+ def forward(self, logits, time_decision):
87
+ return (time_decision**2).sum(self.pooldim) / time_decision.sum(
88
+ self.pooldim)
89
+
90
+
91
+ class MeanPool(nn.Module):
92
+
93
+ def __init__(self, pooldim=1):
94
+ super().__init__()
95
+ self.pooldim = pooldim
96
+
97
+ def forward(self, logits, decision):
98
+ return torch.mean(decision, dim=self.pooldim)
99
+
100
+
101
+ class AttentionPool(nn.Module):
102
+ """docstring for AttentionPool"""
103
+ def __init__(self, inputdim, outputdim=10, pooldim=1, **kwargs):
104
+ super().__init__()
105
+ self.inputdim = inputdim
106
+ self.outputdim = outputdim
107
+ self.pooldim = pooldim
108
+ self.transform = nn.Linear(inputdim, outputdim)
109
+ self.activ = nn.Softmax(dim=self.pooldim)
110
+ self.eps = 1e-7
111
+
112
+ def forward(self, logits, decision):
113
+ # Input is (B, T, D)
114
+ # B, T, D
115
+ w = self.activ(torch.clamp(self.transform(logits), -15, 15))
116
+ detect = (decision * w).sum(
117
+ self.pooldim) / (w.sum(self.pooldim) + self.eps)
118
+ # B, T, D
119
+ return detect
120
+
121
+
122
+ class MMPool(nn.Module):
123
+
124
+ def __init__(self, dims):
125
+ super().__init__()
126
+ self.avgpool = nn.AvgPool2d(dims)
127
+ self.maxpool = nn.MaxPool2d(dims)
128
+
129
+ def forward(self, x):
130
+ return self.avgpool(x) + self.maxpool(x)
131
+
132
+
133
+ def parse_poolingfunction(poolingfunction_name='mean', **kwargs):
134
+ """parse_poolingfunction
135
+ A heler function to parse any temporal pooling
136
+ Pooling is done on dimension 1
137
+ :param poolingfunction_name:
138
+ :param **kwargs:
139
+ """
140
+ poolingfunction_name = poolingfunction_name.lower()
141
+ if poolingfunction_name == 'mean':
142
+ return MeanPool(pooldim=1)
143
+ elif poolingfunction_name == 'linear':
144
+ return LinearSoftPool(pooldim=1)
145
+ elif poolingfunction_name == 'attention':
146
+ return AttentionPool(inputdim=kwargs['inputdim'],
147
+ outputdim=kwargs['outputdim'])
148
+
149
+
150
+ def embedding_pooling(x, lens, pooling="mean"):
151
+ if pooling == "max":
152
+ fc_embs = max_with_lens(x, lens)
153
+ elif pooling == "mean":
154
+ fc_embs = mean_with_lens(x, lens)
155
+ elif pooling == "mean+max":
156
+ x_mean = mean_with_lens(x, lens)
157
+ x_max = max_with_lens(x, lens)
158
+ fc_embs = x_mean + x_max
159
+ elif pooling == "last":
160
+ indices = (lens - 1).reshape(-1, 1, 1).repeat(1, 1, x.size(-1))
161
+ # indices: [N, 1, hidden]
162
+ fc_embs = torch.gather(x, 1, indices).squeeze(1)
163
+ else:
164
+ raise Exception(f"pooling method {pooling} not support")
165
+ return fc_embs
166
+
167
+
168
+ class Cdur5Encoder(BaseEncoder):
169
+
170
+ def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean"):
171
+ super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
172
+ self.pooling = pooling
173
+ self.features = nn.Sequential(
174
+ Block2D(1, 32),
175
+ nn.LPPool2d(4, (2, 4)),
176
+ Block2D(32, 128),
177
+ Block2D(128, 128),
178
+ nn.LPPool2d(4, (2, 4)),
179
+ Block2D(128, 128),
180
+ Block2D(128, 128),
181
+ nn.LPPool2d(4, (1, 4)),
182
+ nn.Dropout(0.3),
183
+ )
184
+ with torch.no_grad():
185
+ rnn_input_dim = self.features(
186
+ torch.randn(1, 1, 500, spec_dim)).shape
187
+ rnn_input_dim = rnn_input_dim[1] * rnn_input_dim[-1]
188
+
189
+ self.gru = nn.GRU(rnn_input_dim,
190
+ 128,
191
+ bidirectional=True,
192
+ batch_first=True)
193
+ self.apply(init)
194
+
195
+ def forward(self, input_dict):
196
+ x = input_dict["spec"]
197
+ lens = input_dict["spec_len"]
198
+ if "upsample" not in input_dict:
199
+ input_dict["upsample"] = False
200
+ lens = torch.as_tensor(copy.deepcopy(lens))
201
+ N, T, _ = x.shape
202
+ x = x.unsqueeze(1)
203
+ x = self.features(x)
204
+ x = x.transpose(1, 2).contiguous().flatten(-2)
205
+ x, _ = self.gru(x)
206
+ if input_dict["upsample"]:
207
+ x = nn.functional.interpolate(
208
+ x.transpose(1, 2),
209
+ T,
210
+ mode='linear',
211
+ align_corners=False).transpose(1, 2)
212
+ else:
213
+ lens //= 4
214
+ attn_emb = x
215
+ fc_emb = embedding_pooling(x, lens, self.pooling)
216
+ return {
217
+ "attn_emb": attn_emb,
218
+ "fc_emb": fc_emb,
219
+ "attn_emb_len": lens
220
+ }
221
+
222
+
223
+ def conv_conv_block(in_channel, out_channel):
224
+ return nn.Sequential(
225
+ nn.Conv2d(in_channel,
226
+ out_channel,
227
+ kernel_size=3,
228
+ bias=False,
229
+ padding=1),
230
+ nn.BatchNorm2d(out_channel),
231
+ nn.ReLU(True),
232
+ nn.Conv2d(out_channel,
233
+ out_channel,
234
+ kernel_size=3,
235
+ bias=False,
236
+ padding=1),
237
+ nn.BatchNorm2d(out_channel),
238
+ nn.ReLU(True)
239
+ )
240
+
241
+
242
+ class Cdur8Encoder(BaseEncoder):
243
+
244
+ def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, pooling="mean"):
245
+ super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
246
+ self.pooling = pooling
247
+ self.features = nn.Sequential(
248
+ conv_conv_block(1, 64),
249
+ MMPool((2, 2)),
250
+ nn.Dropout(0.2, True),
251
+ conv_conv_block(64, 128),
252
+ MMPool((2, 2)),
253
+ nn.Dropout(0.2, True),
254
+ conv_conv_block(128, 256),
255
+ MMPool((1, 2)),
256
+ nn.Dropout(0.2, True),
257
+ conv_conv_block(256, 512),
258
+ MMPool((1, 2)),
259
+ nn.Dropout(0.2, True),
260
+ nn.AdaptiveAvgPool2d((None, 1)),
261
+ )
262
+ self.init_bn = nn.BatchNorm2d(spec_dim)
263
+ self.embedding = nn.Linear(512, 512)
264
+ self.gru = nn.GRU(512, 256, bidirectional=True, batch_first=True)
265
+ self.apply(init)
266
+
267
+ def forward(self, input_dict):
268
+ x = input_dict["spec"]
269
+ lens = input_dict["spec_len"]
270
+ lens = torch.as_tensor(copy.deepcopy(lens))
271
+ x = x.unsqueeze(1) # B x 1 x T x D
272
+ x = x.transpose(1, 3)
273
+ x = self.init_bn(x)
274
+ x = x.transpose(1, 3)
275
+ x = self.features(x)
276
+ x = x.transpose(1, 2).contiguous().flatten(-2)
277
+ x = F.dropout(x, p=0.5, training=self.training)
278
+ x = F.relu_(self.embedding(x))
279
+ x, _ = self.gru(x)
280
+ attn_emb = x
281
+ lens //= 4
282
+ fc_emb = embedding_pooling(x, lens, self.pooling)
283
+ return {
284
+ "attn_emb": attn_emb,
285
+ "fc_emb": fc_emb,
286
+ "attn_emb_len": lens
287
+ }
288
+
289
+
290
+ class Cnn10Encoder(BaseEncoder):
291
+
292
+ def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim):
293
+ super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
294
+ self.features = nn.Sequential(
295
+ conv_conv_block(1, 64),
296
+ nn.AvgPool2d((2, 2)),
297
+ nn.Dropout(0.2, True),
298
+ conv_conv_block(64, 128),
299
+ nn.AvgPool2d((2, 2)),
300
+ nn.Dropout(0.2, True),
301
+ conv_conv_block(128, 256),
302
+ nn.AvgPool2d((2, 2)),
303
+ nn.Dropout(0.2, True),
304
+ conv_conv_block(256, 512),
305
+ nn.AvgPool2d((2, 2)),
306
+ nn.Dropout(0.2, True),
307
+ nn.AdaptiveAvgPool2d((None, 1)),
308
+ )
309
+ self.init_bn = nn.BatchNorm2d(spec_dim)
310
+ self.embedding = nn.Linear(512, 512)
311
+ self.apply(init)
312
+
313
+ def forward(self, input_dict):
314
+ x = input_dict["spec"]
315
+ lens = input_dict["spec_len"]
316
+ lens = torch.as_tensor(copy.deepcopy(lens))
317
+ x = x.unsqueeze(1) # [N, 1, T, D]
318
+ x = x.transpose(1, 3)
319
+ x = self.init_bn(x)
320
+ x = x.transpose(1, 3)
321
+ x = self.features(x) # [N, 512, T/16, 1]
322
+ x = x.transpose(1, 2).contiguous().flatten(-2) # [N, T/16, 512]
323
+ attn_emb = x
324
+ lens //= 16
325
+ fc_emb = embedding_pooling(x, lens, "mean+max")
326
+ fc_emb = F.dropout(fc_emb, p=0.5, training=self.training)
327
+ fc_emb = self.embedding(fc_emb)
328
+ fc_emb = F.relu_(fc_emb)
329
+ return {
330
+ "attn_emb": attn_emb,
331
+ "fc_emb": fc_emb,
332
+ "attn_emb_len": lens
333
+ }
334
+
335
+
336
+ class ConvBlock(nn.Module):
337
+ def __init__(self, in_channels, out_channels):
338
+
339
+ super(ConvBlock, self).__init__()
340
+
341
+ self.conv1 = nn.Conv2d(in_channels=in_channels,
342
+ out_channels=out_channels,
343
+ kernel_size=(3, 3), stride=(1, 1),
344
+ padding=(1, 1), bias=False)
345
+
346
+ self.conv2 = nn.Conv2d(in_channels=out_channels,
347
+ out_channels=out_channels,
348
+ kernel_size=(3, 3), stride=(1, 1),
349
+ padding=(1, 1), bias=False)
350
+
351
+ self.bn1 = nn.BatchNorm2d(out_channels)
352
+ self.bn2 = nn.BatchNorm2d(out_channels)
353
+
354
+ self.init_weight()
355
+
356
+ def init_weight(self):
357
+ init_layer(self.conv1)
358
+ init_layer(self.conv2)
359
+ init_bn(self.bn1)
360
+ init_bn(self.bn2)
361
+
362
+
363
+ def forward(self, input, pool_size=(2, 2), pool_type='avg'):
364
+
365
+ x = input
366
+ x = F.relu_(self.bn1(self.conv1(x)))
367
+ x = F.relu_(self.bn2(self.conv2(x)))
368
+ if pool_type == 'max':
369
+ x = F.max_pool2d(x, kernel_size=pool_size)
370
+ elif pool_type == 'avg':
371
+ x = F.avg_pool2d(x, kernel_size=pool_size)
372
+ elif pool_type == 'avg+max':
373
+ x1 = F.avg_pool2d(x, kernel_size=pool_size)
374
+ x2 = F.max_pool2d(x, kernel_size=pool_size)
375
+ x = x1 + x2
376
+ else:
377
+ raise Exception('Incorrect argument!')
378
+
379
+ return x
380
+
381
+
382
+ class Cnn14Encoder(nn.Module):
383
+ def __init__(self, sample_rate=32000):
384
+ super().__init__()
385
+ sr_to_fmax = {
386
+ 32000: 14000,
387
+ 16000: 8000
388
+ }
389
+ # Logmel spectrogram extractor
390
+ self.melspec_extractor = transforms.MelSpectrogram(
391
+ sample_rate=sample_rate,
392
+ n_fft=32 * sample_rate // 1000,
393
+ win_length=32 * sample_rate // 1000,
394
+ hop_length=10 * sample_rate // 1000,
395
+ f_min=50,
396
+ f_max=sr_to_fmax[sample_rate],
397
+ n_mels=64,
398
+ norm="slaney",
399
+ mel_scale="slaney"
400
+ )
401
+ self.hop_length = 10 * sample_rate // 1000
402
+ self.db_transform = transforms.AmplitudeToDB()
403
+ # Spec augmenter
404
+ self.spec_augmenter = SpecAugmentation(time_drop_width=64,
405
+ time_stripes_num=2, freq_drop_width=8, freq_stripes_num=2)
406
+
407
+ self.bn0 = nn.BatchNorm2d(64)
408
+
409
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
410
+ self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
411
+ self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
412
+ self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
413
+ self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
414
+ self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)
415
+
416
+ self.downsample_ratio = 32
417
+
418
+ self.fc1 = nn.Linear(2048, 2048, bias=True)
419
+
420
+ self.init_weight()
421
+
422
+ def init_weight(self):
423
+ init_bn(self.bn0)
424
+ init_layer(self.fc1)
425
+
426
+ def load_pretrained(self, pretrained):
427
+ checkpoint = torch.load(pretrained, map_location="cpu")
428
+
429
+ if "model" in checkpoint:
430
+ state_keys = checkpoint["model"].keys()
431
+ backbone = False
432
+ for key in state_keys:
433
+ if key.startswith("backbone."):
434
+ backbone = True
435
+ break
436
+
437
+ if backbone: # COLA
438
+ state_dict = {}
439
+ for key, value in checkpoint["model"].items():
440
+ if key.startswith("backbone."):
441
+ model_key = key.replace("backbone.", "")
442
+ state_dict[model_key] = value
443
+ else: # PANNs
444
+ state_dict = checkpoint["model"]
445
+ elif "state_dict" in checkpoint: # CLAP
446
+ state_dict = checkpoint["state_dict"]
447
+ state_dict_keys = list(filter(
448
+ lambda x: "audio_encoder" in x, state_dict.keys()))
449
+ state_dict = {
450
+ key.replace('audio_encoder.', ''): state_dict[key]
451
+ for key in state_dict_keys
452
+ }
453
+ else:
454
+ raise Exception("Unkown checkpoint format")
455
+
456
+ model_dict = self.state_dict()
457
+ pretrained_dict = {
458
+ k: v for k, v in state_dict.items() if (k in model_dict) and (
459
+ model_dict[k].shape == v.shape)
460
+ }
461
+ model_dict.update(pretrained_dict)
462
+ self.load_state_dict(model_dict, strict=True)
463
+
464
+ def forward(self, input_dict):
465
+ """
466
+ Input: (batch_size, n_samples)"""
467
+ waveform = input_dict["wav"]
468
+ wave_length = input_dict["wav_len"]
469
+ specaug = input_dict["specaug"]
470
+ x = self.melspec_extractor(waveform)
471
+ x = self.db_transform(x) # (batch_size, mel_bins, time_steps)
472
+ x = x.transpose(1, 2)
473
+ x = x.unsqueeze(1) # (batch_size, 1, time_steps, mel_bins)
474
+
475
+ # SpecAugment
476
+ if self.training and specaug:
477
+ x = self.spec_augmenter(x)
478
+
479
+ x = x.transpose(1, 3)
480
+ x = self.bn0(x)
481
+ x = x.transpose(1, 3)
482
+
483
+ x = self.conv_block1(x, pool_size=(2, 2), pool_type='avg')
484
+ x = F.dropout(x, p=0.2, training=self.training)
485
+ x = self.conv_block2(x, pool_size=(2, 2), pool_type='avg')
486
+ x = F.dropout(x, p=0.2, training=self.training)
487
+ x = self.conv_block3(x, pool_size=(2, 2), pool_type='avg')
488
+ x = F.dropout(x, p=0.2, training=self.training)
489
+ x = self.conv_block4(x, pool_size=(2, 2), pool_type='avg')
490
+ x = F.dropout(x, p=0.2, training=self.training)
491
+ x = self.conv_block5(x, pool_size=(2, 2), pool_type='avg')
492
+ x = F.dropout(x, p=0.2, training=self.training)
493
+ x = self.conv_block6(x, pool_size=(1, 1), pool_type='avg')
494
+ x = F.dropout(x, p=0.2, training=self.training)
495
+ x = torch.mean(x, dim=3)
496
+ attn_emb = x.transpose(1, 2)
497
+
498
+ wave_length = torch.as_tensor(wave_length)
499
+ feat_length = torch.div(wave_length, self.hop_length,
500
+ rounding_mode="floor") + 1
501
+ feat_length = torch.div(feat_length, self.downsample_ratio,
502
+ rounding_mode="floor")
503
+ x_max = max_with_lens(attn_emb, feat_length)
504
+ x_mean = mean_with_lens(attn_emb, feat_length)
505
+ x = x_max + x_mean
506
+ x = F.dropout(x, p=0.5, training=self.training)
507
+ x = F.relu_(self.fc1(x))
508
+ fc_emb = F.dropout(x, p=0.5, training=self.training)
509
+
510
+ output_dict = {
511
+ 'fc_emb': fc_emb,
512
+ 'attn_emb': attn_emb,
513
+ 'attn_emb_len': feat_length
514
+ }
515
+
516
+ return output_dict
517
+
518
+
519
+ class RnnEncoder(BaseEncoder):
520
+
521
+ def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim,
522
+ pooling="mean", **kwargs):
523
+ super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
524
+ self.pooling = pooling
525
+ self.hidden_size = kwargs.get('hidden_size', 512)
526
+ self.bidirectional = kwargs.get('bidirectional', False)
527
+ self.num_layers = kwargs.get('num_layers', 1)
528
+ self.dropout = kwargs.get('dropout', 0.2)
529
+ self.rnn_type = kwargs.get('rnn_type', "GRU")
530
+ self.in_bn = kwargs.get('in_bn', False)
531
+ self.embed_dim = self.hidden_size * (self.bidirectional + 1)
532
+ self.network = getattr(nn, self.rnn_type)(
533
+ attn_feat_dim,
534
+ self.hidden_size,
535
+ num_layers=self.num_layers,
536
+ bidirectional=self.bidirectional,
537
+ dropout=self.dropout,
538
+ batch_first=True)
539
+ if self.in_bn:
540
+ self.bn = nn.BatchNorm1d(self.embed_dim)
541
+ self.apply(init)
542
+
543
+ def forward(self, input_dict):
544
+ x = input_dict["attn"]
545
+ lens = input_dict["attn_len"]
546
+ lens = torch.as_tensor(lens)
547
+ # x: [N, T, E]
548
+ if self.in_bn:
549
+ x = pack_wrapper(self.bn, x, lens)
550
+ out = pack_wrapper(self.network, x, lens)
551
+ # out: [N, T, hidden]
552
+ attn_emb = out
553
+ fc_emb = embedding_pooling(out, lens, self.pooling)
554
+ return {
555
+ "attn_emb": attn_emb,
556
+ "fc_emb": fc_emb,
557
+ "attn_emb_len": lens
558
+ }
559
+
560
+
561
+ class Cnn14RnnEncoder(nn.Module):
562
+ def __init__(self, sample_rate=32000, pretrained=None,
563
+ freeze_cnn=False, freeze_cnn_bn=False,
564
+ pooling="mean", **kwargs):
565
+ super().__init__()
566
+ self.cnn = Cnn14Encoder(sample_rate)
567
+ self.rnn = RnnEncoder(64, 2048, 2048, pooling, **kwargs)
568
+ if pretrained is not None:
569
+ self.cnn.load_pretrained(pretrained)
570
+ if freeze_cnn:
571
+ assert pretrained is not None, "cnn is not pretrained but frozen"
572
+ for param in self.cnn.parameters():
573
+ param.requires_grad = False
574
+ self.freeze_cnn_bn = freeze_cnn_bn
575
+
576
+ def train(self, mode):
577
+ super().train(mode=mode)
578
+ if self.freeze_cnn_bn:
579
+ def bn_eval(module):
580
+ class_name = module.__class__.__name__
581
+ if class_name.find("BatchNorm") != -1:
582
+ module.eval()
583
+ self.cnn.apply(bn_eval)
584
+ return self
585
+
586
+ def forward(self, input_dict):
587
+ output_dict = self.cnn(input_dict)
588
+ output_dict["attn"] = output_dict["attn_emb"]
589
+ output_dict["attn_len"] = output_dict["attn_emb_len"]
590
+ del output_dict["attn_emb"], output_dict["attn_emb_len"]
591
+ output_dict = self.rnn(output_dict)
592
+ return output_dict
593
+
594
+
595
+ class TransformerEncoder(BaseEncoder):
596
+
597
+ def __init__(self, spec_dim, fc_feat_dim, attn_feat_dim, d_model, **kwargs):
598
+ super().__init__(spec_dim, fc_feat_dim, attn_feat_dim)
599
+ self.d_model = d_model
600
+ dropout = kwargs.get("dropout", 0.2)
601
+ self.nhead = kwargs.get("nhead", self.d_model // 64)
602
+ self.nlayers = kwargs.get("nlayers", 2)
603
+ self.dim_feedforward = kwargs.get("dim_feedforward", self.d_model * 4)
604
+
605
+ self.attn_proj = nn.Sequential(
606
+ nn.Linear(attn_feat_dim, self.d_model),
607
+ nn.ReLU(),
608
+ nn.Dropout(dropout),
609
+ nn.LayerNorm(self.d_model)
610
+ )
611
+ layer = nn.TransformerEncoderLayer(d_model=self.d_model,
612
+ nhead=self.nhead,
613
+ dim_feedforward=self.dim_feedforward,
614
+ dropout=dropout)
615
+ self.model = nn.TransformerEncoder(layer, self.nlayers)
616
+ self.cls_token = nn.Parameter(torch.zeros(d_model))
617
+ self.init_params()
618
+
619
+ def init_params(self):
620
+ for p in self.parameters():
621
+ if p.dim() > 1:
622
+ nn.init.xavier_uniform_(p)
623
+
624
+ def forward(self, input_dict):
625
+ attn_feat = input_dict["attn"]
626
+ attn_feat_len = input_dict["attn_len"]
627
+ attn_feat_len = torch.as_tensor(attn_feat_len)
628
+
629
+ attn_feat = self.attn_proj(attn_feat) # [bs, T, d_model]
630
+
631
+ cls_emb = self.cls_token.reshape(1, 1, self.d_model).repeat(
632
+ attn_feat.size(0), 1, 1)
633
+ attn_feat = torch.cat((cls_emb, attn_feat), dim=1)
634
+ attn_feat = attn_feat.transpose(0, 1)
635
+
636
+ attn_feat_len += 1
637
+ src_key_padding_mask = ~generate_length_mask(
638
+ attn_feat_len, attn_feat.size(0)).to(attn_feat.device)
639
+ output = self.model(attn_feat, src_key_padding_mask=src_key_padding_mask)
640
+
641
+ attn_emb = output.transpose(0, 1)
642
+ fc_emb = attn_emb[:, 0]
643
+ return {
644
+ "attn_emb": attn_emb,
645
+ "fc_emb": fc_emb,
646
+ "attn_emb_len": attn_feat_len
647
+ }
648
+
649
+
650
+ class Cnn14TransformerEncoder(nn.Module):
651
+ def __init__(self, sample_rate=32000, pretrained=None,
652
+ freeze_cnn=False, freeze_cnn_bn=False,
653
+ d_model="mean", **kwargs):
654
+ super().__init__()
655
+ self.cnn = Cnn14Encoder(sample_rate)
656
+ self.trm = TransformerEncoder(64, 2048, 2048, d_model, **kwargs)
657
+ if pretrained is not None:
658
+ self.cnn.load_pretrained(pretrained)
659
+ if freeze_cnn:
660
+ assert pretrained is not None, "cnn is not pretrained but frozen"
661
+ for param in self.cnn.parameters():
662
+ param.requires_grad = False
663
+ self.freeze_cnn_bn = freeze_cnn_bn
664
+
665
+ def train(self, mode):
666
+ super().train(mode=mode)
667
+ if self.freeze_cnn_bn:
668
+ def bn_eval(module):
669
+ class_name = module.__class__.__name__
670
+ if class_name.find("BatchNorm") != -1:
671
+ module.eval()
672
+ self.cnn.apply(bn_eval)
673
+ return self
674
+
675
+ def forward(self, input_dict):
676
+ output_dict = self.cnn(input_dict)
677
+ output_dict["attn"] = output_dict["attn_emb"]
678
+ output_dict["attn_len"] = output_dict["attn_emb_len"]
679
+ del output_dict["attn_emb"], output_dict["attn_emb_len"]
680
+ output_dict = self.trm(output_dict)
681
+ return output_dict
682
+
683
+
684
+
685
+
686
+
audio_to_text/captioning/models/transformer_model.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import random
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .base_model import CaptionModel
7
+ from .utils import repeat_tensor
8
+ import audio_to_text.captioning.models.decoder
9
+
10
+
11
+ class TransformerModel(CaptionModel):
12
+
13
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
14
+ if not hasattr(self, "compatible_decoders"):
15
+ self.compatible_decoders = (
16
+ audio_to_text.captioning.models.decoder.TransformerDecoder,
17
+ )
18
+ super().__init__(encoder, decoder, **kwargs)
19
+
20
+ def seq_forward(self, input_dict):
21
+ cap = input_dict["cap"]
22
+ cap_padding_mask = (cap == self.pad_idx).to(cap.device)
23
+ cap_padding_mask = cap_padding_mask[:, :-1]
24
+ output = self.decoder(
25
+ {
26
+ "word": cap[:, :-1],
27
+ "attn_emb": input_dict["attn_emb"],
28
+ "attn_emb_len": input_dict["attn_emb_len"],
29
+ "cap_padding_mask": cap_padding_mask
30
+ }
31
+ )
32
+ return output
33
+
34
+ def prepare_decoder_input(self, input_dict, output):
35
+ decoder_input = {
36
+ "attn_emb": input_dict["attn_emb"],
37
+ "attn_emb_len": input_dict["attn_emb_len"]
38
+ }
39
+ t = input_dict["t"]
40
+
41
+ ###############
42
+ # determine input word
43
+ ################
44
+ if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling
45
+ word = input_dict["cap"][:, :t+1]
46
+ else:
47
+ start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
48
+ if t == 0:
49
+ word = start_word
50
+ else:
51
+ word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
52
+ # word: [N, T]
53
+ decoder_input["word"] = word
54
+
55
+ cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
56
+ decoder_input["cap_padding_mask"] = cap_padding_mask
57
+ return decoder_input
58
+
59
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
60
+ decoder_input = {}
61
+ t = input_dict["t"]
62
+ i = input_dict["sample_idx"]
63
+ beam_size = input_dict["beam_size"]
64
+ ###############
65
+ # prepare attn embeds
66
+ ################
67
+ if t == 0:
68
+ attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
69
+ attn_emb_len = repeat_tensor(input_dict["attn_emb_len"][i], beam_size)
70
+ output_i["attn_emb"] = attn_emb
71
+ output_i["attn_emb_len"] = attn_emb_len
72
+ decoder_input["attn_emb"] = output_i["attn_emb"]
73
+ decoder_input["attn_emb_len"] = output_i["attn_emb_len"]
74
+ ###############
75
+ # determine input word
76
+ ################
77
+ start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
78
+ if t == 0:
79
+ word = start_word
80
+ else:
81
+ word = torch.cat((start_word, output_i["seq"]), dim=-1)
82
+ decoder_input["word"] = word
83
+ cap_padding_mask = (word == self.pad_idx).to(input_dict["attn_emb"].device)
84
+ decoder_input["cap_padding_mask"] = cap_padding_mask
85
+
86
+ return decoder_input
87
+
88
+
89
+ class M2TransformerModel(CaptionModel):
90
+
91
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
92
+ if not hasattr(self, "compatible_decoders"):
93
+ self.compatible_decoders = (
94
+ captioning.models.decoder.M2TransformerDecoder,
95
+ )
96
+ super().__init__(encoder, decoder, **kwargs)
97
+ self.check_encoder_compatibility()
98
+
99
+ def check_encoder_compatibility(self):
100
+ assert isinstance(self.encoder, captioning.models.encoder.M2TransformerEncoder), \
101
+ f"only M2TransformerModel is compatible with {self.__class__.__name__}"
102
+
103
+
104
+ def seq_forward(self, input_dict):
105
+ cap = input_dict["cap"]
106
+ output = self.decoder(
107
+ {
108
+ "word": cap[:, :-1],
109
+ "attn_emb": input_dict["attn_emb"],
110
+ "attn_emb_mask": input_dict["attn_emb_mask"],
111
+ }
112
+ )
113
+ return output
114
+
115
+ def prepare_decoder_input(self, input_dict, output):
116
+ decoder_input = {
117
+ "attn_emb": input_dict["attn_emb"],
118
+ "attn_emb_mask": input_dict["attn_emb_mask"]
119
+ }
120
+ t = input_dict["t"]
121
+
122
+ ###############
123
+ # determine input word
124
+ ################
125
+ if input_dict["mode"] == "train" and random.random() < input_dict["ss_ratio"]: # training, scheduled sampling
126
+ word = input_dict["cap"][:, :t+1]
127
+ else:
128
+ start_word = torch.tensor([self.start_idx,] * input_dict["attn_emb"].size(0)).unsqueeze(1).long()
129
+ if t == 0:
130
+ word = start_word
131
+ else:
132
+ word = torch.cat((start_word, output["seq"][:, :t]), dim=-1)
133
+ # word: [N, T]
134
+ decoder_input["word"] = word
135
+
136
+ return decoder_input
137
+
138
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
139
+ decoder_input = {}
140
+ t = input_dict["t"]
141
+ i = input_dict["sample_idx"]
142
+ beam_size = input_dict["beam_size"]
143
+ ###############
144
+ # prepare attn embeds
145
+ ################
146
+ if t == 0:
147
+ attn_emb = repeat_tensor(input_dict["attn_emb"][i], beam_size)
148
+ attn_emb_mask = repeat_tensor(input_dict["attn_emb_mask"][i], beam_size)
149
+ output_i["attn_emb"] = attn_emb
150
+ output_i["attn_emb_mask"] = attn_emb_mask
151
+ decoder_input["attn_emb"] = output_i["attn_emb"]
152
+ decoder_input["attn_emb_mask"] = output_i["attn_emb_mask"]
153
+ ###############
154
+ # determine input word
155
+ ################
156
+ start_word = torch.tensor([self.start_idx,] * beam_size).unsqueeze(1).long()
157
+ if t == 0:
158
+ word = start_word
159
+ else:
160
+ word = torch.cat((start_word, output_i["seq"]), dim=-1)
161
+ decoder_input["word"] = word
162
+
163
+ return decoder_input
164
+
165
+
166
+ class EventEncoder(nn.Module):
167
+ """
168
+ Encode the Label information in AudioCaps and AudioSet
169
+ """
170
+ def __init__(self, emb_dim, vocab_size=527):
171
+ super(EventEncoder, self).__init__()
172
+ self.label_embedding = nn.Parameter(
173
+ torch.randn((vocab_size, emb_dim)), requires_grad=True)
174
+
175
+ def forward(self, word_idxs):
176
+ indices = word_idxs / word_idxs.sum(dim=1, keepdim=True)
177
+ embeddings = indices @ self.label_embedding
178
+ return embeddings
179
+
180
+
181
+ class EventCondTransformerModel(TransformerModel):
182
+
183
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
184
+ if not hasattr(self, "compatible_decoders"):
185
+ self.compatible_decoders = (
186
+ captioning.models.decoder.EventTransformerDecoder,
187
+ )
188
+ super().__init__(encoder, decoder, **kwargs)
189
+ self.label_encoder = EventEncoder(decoder.emb_dim, 527)
190
+ self.train_forward_keys += ["events"]
191
+ self.inference_forward_keys += ["events"]
192
+
193
+ # def seq_forward(self, input_dict):
194
+ # cap = input_dict["cap"]
195
+ # cap_padding_mask = (cap == self.pad_idx).to(cap.device)
196
+ # cap_padding_mask = cap_padding_mask[:, :-1]
197
+ # output = self.decoder(
198
+ # {
199
+ # "word": cap[:, :-1],
200
+ # "attn_emb": input_dict["attn_emb"],
201
+ # "attn_emb_len": input_dict["attn_emb_len"],
202
+ # "cap_padding_mask": cap_padding_mask
203
+ # }
204
+ # )
205
+ # return output
206
+
207
+ def prepare_decoder_input(self, input_dict, output):
208
+ decoder_input = super().prepare_decoder_input(input_dict, output)
209
+ decoder_input["events"] = self.label_encoder(input_dict["events"])
210
+ return decoder_input
211
+
212
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
213
+ decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
214
+ t = input_dict["t"]
215
+ i = input_dict["sample_idx"]
216
+ beam_size = input_dict["beam_size"]
217
+ if t == 0:
218
+ output_i["events"] = repeat_tensor(self.label_encoder(input_dict["events"])[i], beam_size)
219
+ decoder_input["events"] = output_i["events"]
220
+ return decoder_input
221
+
222
+
223
+ class KeywordCondTransformerModel(TransformerModel):
224
+
225
+ def __init__(self, encoder: nn.Module, decoder: nn.Module, **kwargs):
226
+ if not hasattr(self, "compatible_decoders"):
227
+ self.compatible_decoders = (
228
+ captioning.models.decoder.KeywordProbTransformerDecoder,
229
+ )
230
+ super().__init__(encoder, decoder, **kwargs)
231
+ self.train_forward_keys += ["keyword"]
232
+ self.inference_forward_keys += ["keyword"]
233
+
234
+ def seq_forward(self, input_dict):
235
+ cap = input_dict["cap"]
236
+ cap_padding_mask = (cap == self.pad_idx).to(cap.device)
237
+ cap_padding_mask = cap_padding_mask[:, :-1]
238
+ keyword = input_dict["keyword"]
239
+ output = self.decoder(
240
+ {
241
+ "word": cap[:, :-1],
242
+ "attn_emb": input_dict["attn_emb"],
243
+ "attn_emb_len": input_dict["attn_emb_len"],
244
+ "keyword": keyword,
245
+ "cap_padding_mask": cap_padding_mask
246
+ }
247
+ )
248
+ return output
249
+
250
+ def prepare_decoder_input(self, input_dict, output):
251
+ decoder_input = super().prepare_decoder_input(input_dict, output)
252
+ decoder_input["keyword"] = input_dict["keyword"]
253
+ return decoder_input
254
+
255
+ def prepare_beamsearch_decoder_input(self, input_dict, output_i):
256
+ decoder_input = super().prepare_beamsearch_decoder_input(input_dict, output_i)
257
+ t = input_dict["t"]
258
+ i = input_dict["sample_idx"]
259
+ beam_size = input_dict["beam_size"]
260
+ if t == 0:
261
+ output_i["keyword"] = repeat_tensor(input_dict["keyword"][i],
262
+ beam_size)
263
+ decoder_input["keyword"] = output_i["keyword"]
264
+ return decoder_input
265
+
audio_to_text/captioning/models/utils.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
8
+
9
+
10
+ def sort_pack_padded_sequence(input, lengths):
11
+ sorted_lengths, indices = torch.sort(lengths, descending=True)
12
+ tmp = pack_padded_sequence(input[indices], sorted_lengths.cpu(), batch_first=True)
13
+ inv_ix = indices.clone()
14
+ inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix)
15
+ return tmp, inv_ix
16
+
17
+ def pad_unsort_packed_sequence(input, inv_ix):
18
+ tmp, _ = pad_packed_sequence(input, batch_first=True)
19
+ tmp = tmp[inv_ix]
20
+ return tmp
21
+
22
+ def pack_wrapper(module, attn_feats, attn_feat_lens):
23
+ packed, inv_ix = sort_pack_padded_sequence(attn_feats, attn_feat_lens)
24
+ if isinstance(module, torch.nn.RNNBase):
25
+ return pad_unsort_packed_sequence(module(packed)[0], inv_ix)
26
+ else:
27
+ return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
28
+
29
+ def generate_length_mask(lens, max_length=None):
30
+ lens = torch.as_tensor(lens)
31
+ N = lens.size(0)
32
+ if max_length is None:
33
+ max_length = max(lens)
34
+ idxs = torch.arange(max_length).repeat(N).view(N, max_length)
35
+ idxs = idxs.to(lens.device)
36
+ mask = (idxs < lens.view(-1, 1))
37
+ return mask
38
+
39
+ def mean_with_lens(features, lens):
40
+ """
41
+ features: [N, T, ...] (assume the second dimension represents length)
42
+ lens: [N,]
43
+ """
44
+ lens = torch.as_tensor(lens)
45
+ if max(lens) != features.size(1):
46
+ max_length = features.size(1)
47
+ mask = generate_length_mask(lens, max_length)
48
+ else:
49
+ mask = generate_length_mask(lens)
50
+ mask = mask.to(features.device) # [N, T]
51
+
52
+ while mask.ndim < features.ndim:
53
+ mask = mask.unsqueeze(-1)
54
+ feature_mean = features * mask
55
+ feature_mean = feature_mean.sum(1)
56
+ while lens.ndim < feature_mean.ndim:
57
+ lens = lens.unsqueeze(1)
58
+ feature_mean = feature_mean / lens.to(features.device)
59
+ # feature_mean = features * mask.unsqueeze(-1)
60
+ # feature_mean = feature_mean.sum(1) / lens.unsqueeze(1).to(features.device)
61
+ return feature_mean
62
+
63
+ def max_with_lens(features, lens):
64
+ """
65
+ features: [N, T, ...] (assume the second dimension represents length)
66
+ lens: [N,]
67
+ """
68
+ lens = torch.as_tensor(lens)
69
+ mask = generate_length_mask(lens).to(features.device) # [N, T]
70
+
71
+ feature_max = features.clone()
72
+ feature_max[~mask] = float("-inf")
73
+ feature_max, _ = feature_max.max(1)
74
+ return feature_max
75
+
76
+ def repeat_tensor(x, n):
77
+ return x.unsqueeze(0).repeat(n, *([1] * len(x.shape)))
78
+
79
+ def init(m, method="kaiming"):
80
+ if isinstance(m, (nn.Conv2d, nn.Conv1d)):
81
+ if method == "kaiming":
82
+ nn.init.kaiming_uniform_(m.weight)
83
+ elif method == "xavier":
84
+ nn.init.xavier_uniform_(m.weight)
85
+ else:
86
+ raise Exception(f"initialization method {method} not supported")
87
+ if m.bias is not None:
88
+ nn.init.constant_(m.bias, 0)
89
+ elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
90
+ nn.init.constant_(m.weight, 1)
91
+ if m.bias is not None:
92
+ nn.init.constant_(m.bias, 0)
93
+ elif isinstance(m, nn.Linear):
94
+ if method == "kaiming":
95
+ nn.init.kaiming_uniform_(m.weight)
96
+ elif method == "xavier":
97
+ nn.init.xavier_uniform_(m.weight)
98
+ else:
99
+ raise Exception(f"initialization method {method} not supported")
100
+ if m.bias is not None:
101
+ nn.init.constant_(m.bias, 0)
102
+ elif isinstance(m, nn.Embedding):
103
+ if method == "kaiming":
104
+ nn.init.kaiming_uniform_(m.weight)
105
+ elif method == "xavier":
106
+ nn.init.xavier_uniform_(m.weight)
107
+ else:
108
+ raise Exception(f"initialization method {method} not supported")
109
+
110
+
111
+
112
+
113
+ class PositionalEncoding(nn.Module):
114
+
115
+ def __init__(self, d_model, dropout=0.1, max_len=100):
116
+ super(PositionalEncoding, self).__init__()
117
+ self.dropout = nn.Dropout(p=dropout)
118
+
119
+ pe = torch.zeros(max_len, d_model)
120
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
121
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * \
122
+ (-math.log(10000.0) / d_model))
123
+ pe[:, 0::2] = torch.sin(position * div_term)
124
+ pe[:, 1::2] = torch.cos(position * div_term)
125
+ pe = pe.unsqueeze(0).transpose(0, 1)
126
+ # self.register_buffer("pe", pe)
127
+ self.register_parameter("pe", nn.Parameter(pe, requires_grad=False))
128
+
129
+ def forward(self, x):
130
+ # x: [T, N, E]
131
+ x = x + self.pe[:x.size(0), :]
132
+ return self.dropout(x)
audio_to_text/captioning/utils/README.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Utils
2
+
3
+ Scripts in this directory are used as utility functions.
4
+
5
+ ## BERT Pretrained Embeddings
6
+
7
+ You can load pretrained word embeddings in Google [BERT](https://github.com/google-research/bert#pre-trained-models) instead of training word embeddings from scratch. The scripts in `utils/bert` need a BERT server in the background. We use BERT server from [bert-as-service](https://github.com/hanxiao/bert-as-service).
8
+
9
+ To use bert-as-service, you need to first install the repository. It is recommended that you create a new environment with Tensorflow 1.3 to run BERT server since it is incompatible with Tensorflow 2.x.
10
+
11
+ After successful installation of [bert-as-service](https://github.com/hanxiao/bert-as-service), downloading and running the BERT server needs to execute:
12
+
13
+ ```bash
14
+ bash scripts/prepare_bert_server.sh <path-to-server> <num-workers> zh
15
+ ```
16
+
17
+ By default, server based on BERT base Chinese model is running in the background. You can change to other models by changing corresponding model name and path in `scripts/prepare_bert_server.sh`.
18
+
19
+ To extract BERT word embeddings, you need to execute `utils/bert/create_word_embedding.py`.
audio_to_text/captioning/utils/__init__.py ADDED
File without changes
audio_to_text/captioning/utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (174 Bytes). View file
 
audio_to_text/captioning/utils/__pycache__/train_util.cpython-38.pyc ADDED
Binary file (5.75 kB). View file
 
audio_to_text/captioning/utils/bert/create_sent_embedding.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import fire
3
+ import numpy as np
4
+ import pandas as pd
5
+ from tqdm import tqdm
6
+
7
+
8
+ class EmbeddingExtractor(object):
9
+
10
+ def extract_sentbert(self, caption_file: str, output: str, dev: bool=True, zh: bool=False):
11
+ from sentence_transformers import SentenceTransformer
12
+ lang2model = {
13
+ "zh": "distiluse-base-multilingual-cased",
14
+ "en": "bert-base-nli-mean-tokens"
15
+ }
16
+ lang = "zh" if zh else "en"
17
+ model = SentenceTransformer(lang2model[lang])
18
+
19
+ self.extract(caption_file, model, output, dev)
20
+
21
+ def extract_originbert(self, caption_file: str, output: str, dev: bool=True, ip="localhost"):
22
+ from bert_serving.client import BertClient
23
+ client = BertClient(ip)
24
+
25
+ self.extract(caption_file, client, output, dev)
26
+
27
+ def extract(self, caption_file: str, model, output, dev: bool):
28
+ caption_df = pd.read_json(caption_file, dtype={"key": str})
29
+ embeddings = {}
30
+
31
+ if dev:
32
+ with tqdm(total=caption_df.shape[0], ascii=True) as pbar:
33
+ for idx, row in caption_df.iterrows():
34
+ caption = row["caption"]
35
+ key = row["key"]
36
+ cap_idx = row["caption_index"]
37
+ embedding = model.encode([caption])
38
+ embedding = np.array(embedding).reshape(-1)
39
+ embeddings[f"{key}_{cap_idx}"] = embedding
40
+ pbar.update()
41
+
42
+ else:
43
+ dump = {}
44
+
45
+ with tqdm(total=caption_df.shape[0], ascii=True) as pbar:
46
+ for idx, row in caption_df.iterrows():
47
+ key = row["key"]
48
+ caption = row["caption"]
49
+ value = np.array(model.encode([caption])).reshape(-1)
50
+
51
+ if key not in embeddings.keys():
52
+ embeddings[key] = [value]
53
+ else:
54
+ embeddings[key].append(value)
55
+
56
+ pbar.update()
57
+
58
+ for key in embeddings:
59
+ dump[key] = np.stack(embeddings[key])
60
+
61
+ embeddings = dump
62
+
63
+ with open(output, "wb") as f:
64
+ pickle.dump(embeddings, f)
65
+
66
+ def extract_sbert(self,
67
+ input_json: str,
68
+ output: str):
69
+ from sentence_transformers import SentenceTransformer
70
+ import json
71
+ import torch
72
+ from h5py import File
73
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74
+ model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
75
+ model = model.to(device)
76
+ model.eval()
77
+
78
+ data = json.load(open(input_json))["audios"]
79
+ with torch.no_grad(), tqdm(total=len(data), ascii=True) as pbar, File(output, "w") as store:
80
+ for sample in data:
81
+ audio_id = sample["audio_id"]
82
+ for cap in sample["captions"]:
83
+ cap_id = cap["cap_id"]
84
+ store[f"{audio_id}_{cap_id}"] = model.encode(cap["caption"])
85
+ pbar.update()
86
+
87
+
88
+ if __name__ == "__main__":
89
+ fire.Fire(EmbeddingExtractor)
audio_to_text/captioning/utils/bert/create_word_embedding.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import sys
4
+ import os
5
+
6
+ from bert_serving.client import BertClient
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+ import fire
10
+ import torch
11
+
12
+ sys.path.append(os.getcwd())
13
+ from utils.build_vocab import Vocabulary
14
+
15
+ def main(vocab_file: str, output: str, server_hostname: str):
16
+ client = BertClient(ip=server_hostname)
17
+ vocabulary = torch.load(vocab_file)
18
+ vocab_size = len(vocabulary)
19
+
20
+ fake_embedding = client.encode(["test"]).reshape(-1)
21
+ embed_size = fake_embedding.shape[0]
22
+
23
+ print("Encoding words into embeddings with size: ", embed_size)
24
+
25
+ embeddings = np.empty((vocab_size, embed_size))
26
+ for i in tqdm(range(len(embeddings)), ascii=True):
27
+ embeddings[i] = client.encode([vocabulary.idx2word[i]])
28
+ np.save(output, embeddings)
29
+
30
+
31
+ if __name__ == '__main__':
32
+ fire.Fire(main)
33
+
34
+
audio_to_text/captioning/utils/build_vocab.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from tqdm import tqdm
3
+ import logging
4
+ import pickle
5
+ from collections import Counter
6
+ import re
7
+ import fire
8
+
9
+
10
+ class Vocabulary(object):
11
+ """Simple vocabulary wrapper."""
12
+ def __init__(self):
13
+ self.word2idx = {}
14
+ self.idx2word = {}
15
+ self.idx = 0
16
+
17
+ def add_word(self, word):
18
+ if not word in self.word2idx:
19
+ self.word2idx[word] = self.idx
20
+ self.idx2word[self.idx] = word
21
+ self.idx += 1
22
+
23
+ def __call__(self, word):
24
+ if not word in self.word2idx:
25
+ return self.word2idx["<unk>"]
26
+ return self.word2idx[word]
27
+
28
+ def __getitem__(self, word_id):
29
+ return self.idx2word[word_id]
30
+
31
+ def __len__(self):
32
+ return len(self.word2idx)
33
+
34
+
35
+ def build_vocab(input_json: str,
36
+ threshold: int,
37
+ keep_punctuation: bool,
38
+ host_address: str,
39
+ character_level: bool = False,
40
+ zh: bool = True ):
41
+ """Build vocabulary from csv file with a given threshold to drop all counts < threshold
42
+
43
+ Args:
44
+ input_json(string): Preprossessed json file. Structure like this:
45
+ {
46
+ 'audios': [
47
+ {
48
+ 'audio_id': 'xxx',
49
+ 'captions': [
50
+ {
51
+ 'caption': 'xxx',
52
+ 'cap_id': 'xxx'
53
+ }
54
+ ]
55
+ },
56
+ ...
57
+ ]
58
+ }
59
+ threshold (int): Threshold to drop all words with counts < threshold
60
+ keep_punctuation (bool): Includes or excludes punctuation.
61
+
62
+ Returns:
63
+ vocab (Vocab): Object with the processed vocabulary
64
+ """
65
+ data = json.load(open(input_json, "r"))["audios"]
66
+ counter = Counter()
67
+ pretokenized = "tokens" in data[0]["captions"][0]
68
+
69
+ if zh:
70
+ from nltk.parse.corenlp import CoreNLPParser
71
+ from zhon.hanzi import punctuation
72
+ if not pretokenized:
73
+ parser = CoreNLPParser(host_address)
74
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
75
+ for cap_idx in range(len(data[audio_idx]["captions"])):
76
+ if pretokenized:
77
+ tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
78
+ else:
79
+ caption = data[audio_idx]["captions"][cap_idx]["caption"]
80
+ # Remove all punctuations
81
+ if not keep_punctuation:
82
+ caption = re.sub("[{}]".format(punctuation), "", caption)
83
+ if character_level:
84
+ tokens = list(caption)
85
+ else:
86
+ tokens = list(parser.tokenize(caption))
87
+ data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
88
+ counter.update(tokens)
89
+ else:
90
+ if pretokenized:
91
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
92
+ for cap_idx in range(len(data[audio_idx]["captions"])):
93
+ tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
94
+ counter.update(tokens)
95
+ else:
96
+ from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
97
+ captions = {}
98
+ for audio_idx in range(len(data)):
99
+ audio_id = data[audio_idx]["audio_id"]
100
+ captions[audio_id] = []
101
+ for cap_idx in range(len(data[audio_idx]["captions"])):
102
+ caption = data[audio_idx]["captions"][cap_idx]["caption"]
103
+ captions[audio_id].append({
104
+ "audio_id": audio_id,
105
+ "id": cap_idx,
106
+ "caption": caption
107
+ })
108
+ tokenizer = PTBTokenizer()
109
+ captions = tokenizer.tokenize(captions)
110
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
111
+ audio_id = data[audio_idx]["audio_id"]
112
+ for cap_idx in range(len(data[audio_idx]["captions"])):
113
+ tokens = captions[audio_id][cap_idx]
114
+ data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
115
+ counter.update(tokens.split(" "))
116
+
117
+ if not pretokenized:
118
+ json.dump({ "audios": data }, open(input_json, "w"), indent=4, ensure_ascii=not zh)
119
+ words = [word for word, cnt in counter.items() if cnt >= threshold]
120
+
121
+ # Create a vocab wrapper and add some special tokens.
122
+ vocab = Vocabulary()
123
+ vocab.add_word("<pad>")
124
+ vocab.add_word("<start>")
125
+ vocab.add_word("<end>")
126
+ vocab.add_word("<unk>")
127
+
128
+ # Add the words to the vocabulary.
129
+ for word in words:
130
+ vocab.add_word(word)
131
+ return vocab
132
+
133
+
134
+ def process(input_json: str,
135
+ output_file: str,
136
+ threshold: int = 1,
137
+ keep_punctuation: bool = False,
138
+ character_level: bool = False,
139
+ host_address: str = "http://localhost:9000",
140
+ zh: bool = False):
141
+ logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
142
+ logging.basicConfig(level=logging.INFO, format=logfmt)
143
+ logging.info("Build Vocab")
144
+ vocabulary = build_vocab(
145
+ input_json=input_json, threshold=threshold, keep_punctuation=keep_punctuation,
146
+ host_address=host_address, character_level=character_level, zh=zh)
147
+ pickle.dump(vocabulary, open(output_file, "wb"))
148
+ logging.info("Total vocabulary size: {}".format(len(vocabulary)))
149
+ logging.info("Saved vocab to '{}'".format(output_file))
150
+
151
+
152
+ if __name__ == '__main__':
153
+ fire.Fire(process)
audio_to_text/captioning/utils/build_vocab_ltp.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from tqdm import tqdm
3
+ import logging
4
+ import pickle
5
+ from collections import Counter
6
+ import re
7
+ import fire
8
+
9
+ class Vocabulary(object):
10
+ """Simple vocabulary wrapper."""
11
+ def __init__(self):
12
+ self.word2idx = {}
13
+ self.idx2word = {}
14
+ self.idx = 0
15
+
16
+ def add_word(self, word):
17
+ if not word in self.word2idx:
18
+ self.word2idx[word] = self.idx
19
+ self.idx2word[self.idx] = word
20
+ self.idx += 1
21
+
22
+ def __call__(self, word):
23
+ if not word in self.word2idx:
24
+ return self.word2idx["<unk>"]
25
+ return self.word2idx[word]
26
+
27
+ def __len__(self):
28
+ return len(self.word2idx)
29
+
30
+ def build_vocab(input_json: str,
31
+ output_json: str,
32
+ threshold: int,
33
+ keep_punctuation: bool,
34
+ character_level: bool = False,
35
+ zh: bool = True ):
36
+ """Build vocabulary from csv file with a given threshold to drop all counts < threshold
37
+
38
+ Args:
39
+ input_json(string): Preprossessed json file. Structure like this:
40
+ {
41
+ 'audios': [
42
+ {
43
+ 'audio_id': 'xxx',
44
+ 'captions': [
45
+ {
46
+ 'caption': 'xxx',
47
+ 'cap_id': 'xxx'
48
+ }
49
+ ]
50
+ },
51
+ ...
52
+ ]
53
+ }
54
+ threshold (int): Threshold to drop all words with counts < threshold
55
+ keep_punctuation (bool): Includes or excludes punctuation.
56
+
57
+ Returns:
58
+ vocab (Vocab): Object with the processed vocabulary
59
+ """
60
+ data = json.load(open(input_json, "r"))["audios"]
61
+ counter = Counter()
62
+ pretokenized = "tokens" in data[0]["captions"][0]
63
+
64
+ if zh:
65
+ from ltp import LTP
66
+ from zhon.hanzi import punctuation
67
+ if not pretokenized:
68
+ parser = LTP("base")
69
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
70
+ for cap_idx in range(len(data[audio_idx]["captions"])):
71
+ if pretokenized:
72
+ tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
73
+ else:
74
+ caption = data[audio_idx]["captions"][cap_idx]["caption"]
75
+ if character_level:
76
+ tokens = list(caption)
77
+ else:
78
+ tokens, _ = parser.seg([caption])
79
+ tokens = tokens[0]
80
+ # Remove all punctuations
81
+ if not keep_punctuation:
82
+ tokens = [token for token in tokens if token not in punctuation]
83
+ data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
84
+ counter.update(tokens)
85
+ else:
86
+ if pretokenized:
87
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
88
+ for cap_idx in range(len(data[audio_idx]["captions"])):
89
+ tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
90
+ counter.update(tokens)
91
+ else:
92
+ from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
93
+ captions = {}
94
+ for audio_idx in range(len(data)):
95
+ audio_id = data[audio_idx]["audio_id"]
96
+ captions[audio_id] = []
97
+ for cap_idx in range(len(data[audio_idx]["captions"])):
98
+ caption = data[audio_idx]["captions"][cap_idx]["caption"]
99
+ captions[audio_id].append({
100
+ "audio_id": audio_id,
101
+ "id": cap_idx,
102
+ "caption": caption
103
+ })
104
+ tokenizer = PTBTokenizer()
105
+ captions = tokenizer.tokenize(captions)
106
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
107
+ audio_id = data[audio_idx]["audio_id"]
108
+ for cap_idx in range(len(data[audio_idx]["captions"])):
109
+ tokens = captions[audio_id][cap_idx]
110
+ data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
111
+ counter.update(tokens.split(" "))
112
+
113
+ if not pretokenized:
114
+ if output_json is None:
115
+ output_json = input_json
116
+ json.dump({ "audios": data }, open(output_json, "w"), indent=4, ensure_ascii=not zh)
117
+ words = [word for word, cnt in counter.items() if cnt >= threshold]
118
+
119
+ # Create a vocab wrapper and add some special tokens.
120
+ vocab = Vocabulary()
121
+ vocab.add_word("<pad>")
122
+ vocab.add_word("<start>")
123
+ vocab.add_word("<end>")
124
+ vocab.add_word("<unk>")
125
+
126
+ # Add the words to the vocabulary.
127
+ for word in words:
128
+ vocab.add_word(word)
129
+ return vocab
130
+
131
+ def process(input_json: str,
132
+ output_file: str,
133
+ output_json: str = None,
134
+ threshold: int = 1,
135
+ keep_punctuation: bool = False,
136
+ character_level: bool = False,
137
+ zh: bool = True):
138
+ logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
139
+ logging.basicConfig(level=logging.INFO, format=logfmt)
140
+ logging.info("Build Vocab")
141
+ vocabulary = build_vocab(
142
+ input_json=input_json, output_json=output_json, threshold=threshold,
143
+ keep_punctuation=keep_punctuation, character_level=character_level, zh=zh)
144
+ pickle.dump(vocabulary, open(output_file, "wb"))
145
+ logging.info("Total vocabulary size: {}".format(len(vocabulary)))
146
+ logging.info("Saved vocab to '{}'".format(output_file))
147
+
148
+
149
+ if __name__ == '__main__':
150
+ fire.Fire(process)
audio_to_text/captioning/utils/build_vocab_spacy.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from tqdm import tqdm
3
+ import logging
4
+ import pickle
5
+ from collections import Counter
6
+ import re
7
+ import fire
8
+
9
+ class Vocabulary(object):
10
+ """Simple vocabulary wrapper."""
11
+ def __init__(self):
12
+ self.word2idx = {}
13
+ self.idx2word = {}
14
+ self.idx = 0
15
+
16
+ def add_word(self, word):
17
+ if not word in self.word2idx:
18
+ self.word2idx[word] = self.idx
19
+ self.idx2word[self.idx] = word
20
+ self.idx += 1
21
+
22
+ def __call__(self, word):
23
+ if not word in self.word2idx:
24
+ return self.word2idx["<unk>"]
25
+ return self.word2idx[word]
26
+
27
+ def __len__(self):
28
+ return len(self.word2idx)
29
+
30
+
31
+ def build_vocab(input_json: str,
32
+ output_json: str,
33
+ threshold: int,
34
+ keep_punctuation: bool,
35
+ host_address: str,
36
+ character_level: bool = False,
37
+ retokenize: bool = True,
38
+ zh: bool = True ):
39
+ """Build vocabulary from csv file with a given threshold to drop all counts < threshold
40
+
41
+ Args:
42
+ input_json(string): Preprossessed json file. Structure like this:
43
+ {
44
+ 'audios': [
45
+ {
46
+ 'audio_id': 'xxx',
47
+ 'captions': [
48
+ {
49
+ 'caption': 'xxx',
50
+ 'cap_id': 'xxx'
51
+ }
52
+ ]
53
+ },
54
+ ...
55
+ ]
56
+ }
57
+ threshold (int): Threshold to drop all words with counts < threshold
58
+ keep_punctuation (bool): Includes or excludes punctuation.
59
+
60
+ Returns:
61
+ vocab (Vocab): Object with the processed vocabulary
62
+ """
63
+ data = json.load(open(input_json, "r"))["audios"]
64
+ counter = Counter()
65
+ if retokenize:
66
+ pretokenized = False
67
+ else:
68
+ pretokenized = "tokens" in data[0]["captions"][0]
69
+
70
+ if zh:
71
+ from nltk.parse.corenlp import CoreNLPParser
72
+ from zhon.hanzi import punctuation
73
+ if not pretokenized:
74
+ parser = CoreNLPParser(host_address)
75
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
76
+ for cap_idx in range(len(data[audio_idx]["captions"])):
77
+ if pretokenized:
78
+ tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
79
+ else:
80
+ caption = data[audio_idx]["captions"][cap_idx]["caption"]
81
+ # Remove all punctuations
82
+ if not keep_punctuation:
83
+ caption = re.sub("[{}]".format(punctuation), "", caption)
84
+ if character_level:
85
+ tokens = list(caption)
86
+ else:
87
+ tokens = list(parser.tokenize(caption))
88
+ data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
89
+ counter.update(tokens)
90
+ else:
91
+ if pretokenized:
92
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
93
+ for cap_idx in range(len(data[audio_idx]["captions"])):
94
+ tokens = data[audio_idx]["captions"][cap_idx]["tokens"].split()
95
+ counter.update(tokens)
96
+ else:
97
+ import spacy
98
+ tokenizer = spacy.load("en_core_web_sm", disable=["parser", "ner"])
99
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
100
+ captions = data[audio_idx]["captions"]
101
+ for cap_idx in range(len(captions)):
102
+ caption = captions[cap_idx]["caption"]
103
+ doc = tokenizer(caption)
104
+ tokens = " ".join([str(token).lower() for token in doc])
105
+ data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
106
+ counter.update(tokens.split(" "))
107
+
108
+ if not pretokenized:
109
+ if output_json is None:
110
+ json.dump({ "audios": data }, open(input_json, "w"),
111
+ indent=4, ensure_ascii=not zh)
112
+ else:
113
+ json.dump({ "audios": data }, open(output_json, "w"),
114
+ indent=4, ensure_ascii=not zh)
115
+
116
+ words = [word for word, cnt in counter.items() if cnt >= threshold]
117
+
118
+ # Create a vocab wrapper and add some special tokens.
119
+ vocab = Vocabulary()
120
+ vocab.add_word("<pad>")
121
+ vocab.add_word("<start>")
122
+ vocab.add_word("<end>")
123
+ vocab.add_word("<unk>")
124
+
125
+ # Add the words to the vocabulary.
126
+ for word in words:
127
+ vocab.add_word(word)
128
+ return vocab
129
+
130
+ def process(input_json: str,
131
+ output_file: str,
132
+ output_json: str = None,
133
+ threshold: int = 1,
134
+ keep_punctuation: bool = False,
135
+ character_level: bool = False,
136
+ retokenize: bool = False,
137
+ host_address: str = "http://localhost:9000",
138
+ zh: bool = True):
139
+ logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
140
+ logging.basicConfig(level=logging.INFO, format=logfmt)
141
+ logging.info("Build Vocab")
142
+ vocabulary = build_vocab(
143
+ input_json=input_json, output_json=output_json, threshold=threshold,
144
+ keep_punctuation=keep_punctuation, host_address=host_address,
145
+ character_level=character_level, retokenize=retokenize, zh=zh)
146
+ pickle.dump(vocabulary, open(output_file, "wb"))
147
+ logging.info("Total vocabulary size: {}".format(len(vocabulary)))
148
+ logging.info("Saved vocab to '{}'".format(output_file))
149
+
150
+
151
+ if __name__ == '__main__':
152
+ fire.Fire(process)
audio_to_text/captioning/utils/eval_round_robin.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+
4
+ import numpy as np
5
+ import fire
6
+
7
+
8
+ def evaluate_annotation(key2refs, scorer):
9
+ if scorer.method() == "Bleu":
10
+ scores = np.array([ 0.0 for n in range(4) ])
11
+ else:
12
+ scores = 0
13
+ num_cap_per_audio = len(next(iter(key2refs.values())))
14
+
15
+ for i in range(num_cap_per_audio):
16
+ if i > 0:
17
+ for key in key2refs:
18
+ key2refs[key].insert(0, res[key][0])
19
+ res = { key: [refs.pop(),] for key, refs in key2refs.items() }
20
+ score, _ = scorer.compute_score(key2refs, res)
21
+
22
+ if scorer.method() == "Bleu":
23
+ scores += np.array(score)
24
+ else:
25
+ scores += score
26
+
27
+ score = scores / num_cap_per_audio
28
+ return score
29
+
30
+ def evaluate_prediction(key2pred, key2refs, scorer):
31
+ if scorer.method() == "Bleu":
32
+ scores = np.array([ 0.0 for n in range(4) ])
33
+ else:
34
+ scores = 0
35
+ num_cap_per_audio = len(next(iter(key2refs.values())))
36
+
37
+ for i in range(num_cap_per_audio):
38
+ key2refs_i = {}
39
+ for key, refs in key2refs.items():
40
+ key2refs_i[key] = refs[:i] + refs[i+1:]
41
+ score, _ = scorer.compute_score(key2refs_i, key2pred)
42
+
43
+ if scorer.method() == "Bleu":
44
+ scores += np.array(score)
45
+ else:
46
+ scores += score
47
+
48
+ score = scores / num_cap_per_audio
49
+ return score
50
+
51
+
52
+ class Evaluator(object):
53
+
54
+ def eval_annotation(self, annotation, output):
55
+ captions = json.load(open(annotation, "r"))["audios"]
56
+
57
+ key2refs = {}
58
+ for audio_idx in range(len(captions)):
59
+ audio_id = captions[audio_idx]["audio_id"]
60
+ key2refs[audio_id] = []
61
+ for caption in captions[audio_idx]["captions"]:
62
+ key2refs[audio_id].append(caption["caption"])
63
+
64
+ from fense.fense import Fense
65
+ scores = {}
66
+ scorer = Fense()
67
+ scores[scorer.method()] = evaluate_annotation(copy.deepcopy(key2refs), scorer)
68
+
69
+ refs4eval = {}
70
+ for key, refs in key2refs.items():
71
+ refs4eval[key] = []
72
+ for idx, ref in enumerate(refs):
73
+ refs4eval[key].append({
74
+ "audio_id": key,
75
+ "id": idx,
76
+ "caption": ref
77
+ })
78
+
79
+ from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
80
+
81
+ tokenizer = PTBTokenizer()
82
+ key2refs = tokenizer.tokenize(refs4eval)
83
+
84
+
85
+ from pycocoevalcap.bleu.bleu import Bleu
86
+ from pycocoevalcap.cider.cider import Cider
87
+ from pycocoevalcap.rouge.rouge import Rouge
88
+ from pycocoevalcap.meteor.meteor import Meteor
89
+ from pycocoevalcap.spice.spice import Spice
90
+
91
+
92
+ scorers = [Bleu(), Rouge(), Cider(), Meteor(), Spice()]
93
+ for scorer in scorers:
94
+ scores[scorer.method()] = evaluate_annotation(copy.deepcopy(key2refs), scorer)
95
+
96
+ spider = 0
97
+ with open(output, "w") as f:
98
+ for name, score in scores.items():
99
+ if name == "Bleu":
100
+ for n in range(4):
101
+ f.write("Bleu-{}: {:6.3f}\n".format(n + 1, score[n]))
102
+ else:
103
+ f.write("{}: {:6.3f}\n".format(name, score))
104
+ if name in ["CIDEr", "SPICE"]:
105
+ spider += score
106
+ f.write("SPIDEr: {:6.3f}\n".format(spider / 2))
107
+
108
+ def eval_prediction(self, prediction, annotation, output):
109
+ ref_captions = json.load(open(annotation, "r"))["audios"]
110
+
111
+ key2refs = {}
112
+ for audio_idx in range(len(ref_captions)):
113
+ audio_id = ref_captions[audio_idx]["audio_id"]
114
+ key2refs[audio_id] = []
115
+ for caption in ref_captions[audio_idx]["captions"]:
116
+ key2refs[audio_id].append(caption["caption"])
117
+
118
+ pred_captions = json.load(open(prediction, "r"))["predictions"]
119
+
120
+ key2pred = {}
121
+ for audio_idx in range(len(pred_captions)):
122
+ item = pred_captions[audio_idx]
123
+ audio_id = item["filename"]
124
+ key2pred[audio_id] = [item["tokens"]]
125
+
126
+ from fense.fense import Fense
127
+ scores = {}
128
+ scorer = Fense()
129
+ scores[scorer.method()] = evaluate_prediction(key2pred, key2refs, scorer)
130
+
131
+ refs4eval = {}
132
+ for key, refs in key2refs.items():
133
+ refs4eval[key] = []
134
+ for idx, ref in enumerate(refs):
135
+ refs4eval[key].append({
136
+ "audio_id": key,
137
+ "id": idx,
138
+ "caption": ref
139
+ })
140
+
141
+ preds4eval = {}
142
+ for key, preds in key2pred.items():
143
+ preds4eval[key] = []
144
+ for idx, pred in enumerate(preds):
145
+ preds4eval[key].append({
146
+ "audio_id": key,
147
+ "id": idx,
148
+ "caption": pred
149
+ })
150
+
151
+ from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
152
+
153
+ tokenizer = PTBTokenizer()
154
+ key2refs = tokenizer.tokenize(refs4eval)
155
+ key2pred = tokenizer.tokenize(preds4eval)
156
+
157
+
158
+ from pycocoevalcap.bleu.bleu import Bleu
159
+ from pycocoevalcap.cider.cider import Cider
160
+ from pycocoevalcap.rouge.rouge import Rouge
161
+ from pycocoevalcap.meteor.meteor import Meteor
162
+ from pycocoevalcap.spice.spice import Spice
163
+
164
+ scorers = [Bleu(), Rouge(), Cider(), Meteor(), Spice()]
165
+ for scorer in scorers:
166
+ scores[scorer.method()] = evaluate_prediction(key2pred, key2refs, scorer)
167
+
168
+ spider = 0
169
+ with open(output, "w") as f:
170
+ for name, score in scores.items():
171
+ if name == "Bleu":
172
+ for n in range(4):
173
+ f.write("Bleu-{}: {:6.3f}\n".format(n + 1, score[n]))
174
+ else:
175
+ f.write("{}: {:6.3f}\n".format(name, score))
176
+ if name in ["CIDEr", "SPICE"]:
177
+ spider += score
178
+ f.write("SPIDEr: {:6.3f}\n".format(spider / 2))
179
+
180
+
181
+ if __name__ == "__main__":
182
+ fire.Fire(Evaluator)
audio_to_text/captioning/utils/fasttext/create_word_embedding.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ #!/usr/bin/env python3
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ from gensim.models import FastText
8
+ from tqdm import tqdm
9
+ import fire
10
+
11
+ import sys
12
+ import os
13
+ sys.path.append(os.getcwd())
14
+ from utils.build_vocab import Vocabulary
15
+
16
+ def create_embedding(caption_file: str,
17
+ vocab_file: str,
18
+ embed_size: int,
19
+ output: str,
20
+ **fasttext_kwargs):
21
+ caption_df = pd.read_json(caption_file)
22
+ caption_df["tokens"] = caption_df["tokens"].apply(lambda x: ["<start>"] + [token for token in x] + ["<end>"])
23
+
24
+ sentences = list(caption_df["tokens"].values)
25
+ vocabulary = torch.load(vocab_file, map_location="cpu")
26
+
27
+ epochs = fasttext_kwargs.get("epochs", 10)
28
+ model = FastText(size=embed_size, min_count=1, **fasttext_kwargs)
29
+ model.build_vocab(sentences=sentences)
30
+ model.train(sentences=sentences, total_examples=len(sentences), epochs=epochs)
31
+
32
+ word_embeddings = np.zeros((len(vocabulary), embed_size))
33
+
34
+ with tqdm(total=len(vocabulary), ascii=True) as pbar:
35
+ for word, idx in vocabulary.word2idx.items():
36
+ if word == "<pad>" or word == "<unk>":
37
+ continue
38
+ word_embeddings[idx] = model.wv[word]
39
+ pbar.update()
40
+
41
+ np.save(output, word_embeddings)
42
+
43
+ print("Finish writing fasttext embeddings to " + output)
44
+
45
+
46
+ if __name__ == "__main__":
47
+ fire.Fire(create_embedding)
48
+
49
+
50
+
audio_to_text/captioning/utils/lr_scheduler.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+
4
+
5
+ class ExponentialDecayScheduler(torch.optim.lr_scheduler._LRScheduler):
6
+
7
+ def __init__(self, optimizer, total_iters, final_lrs,
8
+ warmup_iters=3000, last_epoch=-1, verbose=False):
9
+ self.total_iters = total_iters
10
+ self.final_lrs = final_lrs
11
+ if not isinstance(self.final_lrs, list) and not isinstance(
12
+ self.final_lrs, tuple):
13
+ self.final_lrs = [self.final_lrs] * len(optimizer.param_groups)
14
+ self.warmup_iters = warmup_iters
15
+ self.bases = [0.0,] * len(optimizer.param_groups)
16
+ super().__init__(optimizer, last_epoch, verbose)
17
+ for i, (base_lr, final_lr) in enumerate(zip(self.base_lrs, self.final_lrs)):
18
+ base = (final_lr / base_lr) ** (1 / (
19
+ self.total_iters - self.warmup_iters))
20
+ self.bases[i] = base
21
+
22
+ def _get_closed_form_lr(self):
23
+ warmup_coeff = 1.0
24
+ current_iter = self._step_count
25
+ if current_iter < self.warmup_iters:
26
+ warmup_coeff = current_iter / self.warmup_iters
27
+ current_lrs = []
28
+ # if not self.linear_warmup:
29
+ # for base_lr, final_lr, base in zip(self.base_lrs, self.final_lrs, self.bases):
30
+ # # current_lr = warmup_coeff * base_lr * math.exp(((current_iter - self.warmup_iters) / self.total_iters) * math.log(final_lr / base_lr))
31
+ # current_lr = warmup_coeff * base_lr * (base ** (current_iter - self.warmup_iters))
32
+ # current_lrs.append(current_lr)
33
+ # else:
34
+ for base_lr, final_lr, base in zip(self.base_lrs, self.final_lrs,
35
+ self.bases):
36
+ if current_iter <= self.warmup_iters:
37
+ current_lr = warmup_coeff * base_lr
38
+ else:
39
+ # current_lr = warmup_coeff * base_lr * math.exp(((current_iter - self.warmup_iters) / self.total_iters) * math.log(final_lr / base_lr))
40
+ current_lr = base_lr * (base ** (current_iter - self.warmup_iters))
41
+ current_lrs.append(current_lr)
42
+ return current_lrs
43
+
44
+ def get_lr(self):
45
+ return self._get_closed_form_lr()
46
+
47
+
48
+ class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):
49
+
50
+ def __init__(self, optimizer, model_size=512, factor=1, warmup_iters=3000,
51
+ last_epoch=-1, verbose=False):
52
+ self.model_size = model_size
53
+ self.warmup_iters = warmup_iters
54
+ # self.factors = [group["lr"] / (self.model_size ** (-0.5) * self.warmup_iters ** (-0.5)) for group in optimizer.param_groups]
55
+ self.factor = factor
56
+ super().__init__(optimizer, last_epoch, verbose)
57
+
58
+ def _get_closed_form_lr(self):
59
+ current_iter = self._step_count
60
+ current_lrs = []
61
+ for _ in self.base_lrs:
62
+ current_lr = self.factor * \
63
+ (self.model_size ** (-0.5) * min(current_iter ** (-0.5),
64
+ current_iter * self.warmup_iters ** (-1.5)))
65
+ current_lrs.append(current_lr)
66
+ return current_lrs
67
+
68
+ def get_lr(self):
69
+ return self._get_closed_form_lr()
70
+
71
+
72
+ class CosineWithWarmup(torch.optim.lr_scheduler._LRScheduler):
73
+
74
+ def __init__(self, optimizer, total_iters, warmup_iters,
75
+ num_cycles=0.5, last_epoch=-1, verbose=False):
76
+ self.total_iters = total_iters
77
+ self.warmup_iters = warmup_iters
78
+ self.num_cycles = num_cycles
79
+ super().__init__(optimizer, last_epoch, verbose)
80
+
81
+ def lr_lambda(self, iteration):
82
+ if iteration < self.warmup_iters:
83
+ return float(iteration) / float(max(1, self.warmup_iters))
84
+ progress = float(iteration - self.warmup_iters) / float(max(1,
85
+ self.total_iters - self.warmup_iters))
86
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(
87
+ self.num_cycles) * 2.0 * progress)))
88
+
89
+ def _get_closed_form_lr(self):
90
+ current_iter = self._step_count
91
+ current_lrs = []
92
+ for base_lr in self.base_lrs:
93
+ current_lr = base_lr * self.lr_lambda(current_iter)
94
+ current_lrs.append(current_lr)
95
+ return current_lrs
96
+
97
+ def get_lr(self):
98
+ return self._get_closed_form_lr()
99
+
100
+
101
+ if __name__ == "__main__":
102
+ model = torch.nn.Linear(10, 5)
103
+ optimizer = torch.optim.Adam(model.parameters(), 5e-4)
104
+ epochs = 25
105
+ iters = 600
106
+ scheduler = CosineWithWarmup(optimizer, 600 * 25, 600 * 5,)
107
+ # scheduler = ExponentialDecayScheduler(optimizer, 600 * 25, 5e-7, 600 * 5)
108
+ criterion = torch.nn.MSELoss()
109
+ lrs = []
110
+ for epoch in range(1, epochs + 1):
111
+ for iteration in range(1, iters + 1):
112
+ optimizer.zero_grad()
113
+ x = torch.randn(4, 10)
114
+ y = torch.randn(4, 5)
115
+ loss = criterion(model(x), y)
116
+ loss.backward()
117
+ optimizer.step()
118
+ scheduler.step()
119
+ # print(f"lr: {scheduler.get_last_lr()}")
120
+ # lrs.append(scheduler.get_last_lr())
121
+ lrs.append(optimizer.param_groups[0]["lr"])
122
+ import matplotlib.pyplot as plt
123
+ plt.plot(list(range(1, len(lrs) + 1)), lrs, '-o', markersize=1)
124
+ # plt.legend(loc="best")
125
+ plt.xlabel("Iteration")
126
+ plt.ylabel("LR")
127
+
128
+ plt.savefig("lr_curve.png", dpi=100)
audio_to_text/captioning/utils/model_eval_diff.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import copy
4
+ import pickle
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import fire
9
+
10
+ sys.path.append(os.getcwd())
11
+
12
+
13
+ def coco_score(refs, pred, scorer):
14
+ if scorer.method() == "Bleu":
15
+ scores = np.array([ 0.0 for n in range(4) ])
16
+ else:
17
+ scores = 0
18
+ num_cap_per_audio = len(refs[list(refs.keys())[0]])
19
+
20
+ for i in range(num_cap_per_audio):
21
+ if i > 0:
22
+ for key in refs:
23
+ refs[key].insert(0, res[key][0])
24
+ res = {key: [refs[key].pop(),] for key in refs}
25
+ score, _ = scorer.compute_score(refs, pred)
26
+
27
+ if scorer.method() == "Bleu":
28
+ scores += np.array(score)
29
+ else:
30
+ scores += score
31
+
32
+ score = scores / num_cap_per_audio
33
+
34
+ for key in refs:
35
+ refs[key].insert(0, res[key][0])
36
+ score_allref, _ = scorer.compute_score(refs, pred)
37
+ diff = score_allref - score
38
+ return diff
39
+
40
+ def embedding_score(refs, pred, scorer):
41
+
42
+ num_cap_per_audio = len(refs[list(refs.keys())[0]])
43
+ scores = 0
44
+
45
+ for i in range(num_cap_per_audio):
46
+ res = {key: [refs[key][i],] for key in refs.keys() if len(refs[key]) == num_cap_per_audio}
47
+ refs_i = {key: np.concatenate([refs[key][:i], refs[key][i+1:]]) for key in refs.keys() if len(refs[key]) == num_cap_per_audio}
48
+ score, _ = scorer.compute_score(refs_i, pred)
49
+
50
+ scores += score
51
+
52
+ score = scores / num_cap_per_audio
53
+
54
+ score_allref, _ = scorer.compute_score(refs, pred)
55
+ diff = score_allref - score
56
+ return diff
57
+
58
+ def main(output_file, eval_caption_file, eval_embedding_file, output, zh=False):
59
+ output_df = pd.read_json(output_file)
60
+ output_df["key"] = output_df["filename"].apply(lambda x: os.path.splitext(os.path.basename(x))[0])
61
+ pred = output_df.groupby("key")["tokens"].apply(list).to_dict()
62
+
63
+ label_df = pd.read_json(eval_caption_file)
64
+ if zh:
65
+ refs = label_df.groupby("key")["tokens"].apply(list).to_dict()
66
+ else:
67
+ refs = label_df.groupby("key")["caption"].apply(list).to_dict()
68
+
69
+ from pycocoevalcap.bleu.bleu import Bleu
70
+ from pycocoevalcap.cider.cider import Cider
71
+ from pycocoevalcap.rouge.rouge import Rouge
72
+
73
+ scorer = Bleu(zh=zh)
74
+ bleu_scores = coco_score(copy.deepcopy(refs), pred, scorer)
75
+ scorer = Cider(zh=zh)
76
+ cider_score = coco_score(copy.deepcopy(refs), pred, scorer)
77
+ scorer = Rouge(zh=zh)
78
+ rouge_score = coco_score(copy.deepcopy(refs), pred, scorer)
79
+
80
+ if not zh:
81
+ from pycocoevalcap.meteor.meteor import Meteor
82
+ scorer = Meteor()
83
+ meteor_score = coco_score(copy.deepcopy(refs), pred, scorer)
84
+
85
+ from pycocoevalcap.spice.spice import Spice
86
+ scorer = Spice()
87
+ spice_score = coco_score(copy.deepcopy(refs), pred, scorer)
88
+
89
+ # from audiocaptioneval.sentbert.sentencebert import SentenceBert
90
+ # scorer = SentenceBert(zh=zh)
91
+ # with open(eval_embedding_file, "rb") as f:
92
+ # ref_embeddings = pickle.load(f)
93
+
94
+ # sent_bert = embedding_score(ref_embeddings, pred, scorer)
95
+
96
+ with open(output, "w") as f:
97
+ f.write("Diff:\n")
98
+ for n in range(4):
99
+ f.write("BLEU-{}: {:6.3f}\n".format(n+1, bleu_scores[n]))
100
+ f.write("CIDEr: {:6.3f}\n".format(cider_score))
101
+ f.write("ROUGE: {:6.3f}\n".format(rouge_score))
102
+ if not zh:
103
+ f.write("Meteor: {:6.3f}\n".format(meteor_score))
104
+ f.write("SPICE: {:6.3f}\n".format(spice_score))
105
+ # f.write("SentenceBert: {:6.3f}\n".format(sent_bert))
106
+
107
+
108
+
109
+ if __name__ == "__main__":
110
+ fire.Fire(main)
audio_to_text/captioning/utils/predict_nn.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ import argparse
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from h5py import File
7
+ import sklearn.metrics
8
+
9
+ random.seed(1)
10
+
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument("train_feature", type=str)
13
+ parser.add_argument("train_corpus", type=str)
14
+ parser.add_argument("pred_feature", type=str)
15
+ parser.add_argument("output_json", type=str)
16
+
17
+ args = parser.parse_args()
18
+ train_embs = []
19
+ train_idx_to_audioid = []
20
+ with File(args.train_feature, "r") as store:
21
+ for audio_id, embedding in tqdm(store.items(), ascii=True):
22
+ train_embs.append(embedding[()])
23
+ train_idx_to_audioid.append(audio_id)
24
+
25
+ train_annotation = json.load(open(args.train_corpus, "r"))["audios"]
26
+ train_audioid_to_tokens = {}
27
+ for item in train_annotation:
28
+ audio_id = item["audio_id"]
29
+ train_audioid_to_tokens[audio_id] = [cap_item["tokens"] for cap_item in item["captions"]]
30
+ train_embs = np.stack(train_embs)
31
+
32
+
33
+ pred_data = []
34
+ pred_embs = []
35
+ pred_idx_to_audioids = []
36
+ with File(args.pred_feature, "r") as store:
37
+ for audio_id, embedding in tqdm(store.items(), ascii=True):
38
+ pred_embs.append(embedding[()])
39
+ pred_idx_to_audioids.append(audio_id)
40
+ pred_embs = np.stack(pred_embs)
41
+
42
+ similarity = sklearn.metrics.pairwise.cosine_similarity(pred_embs, train_embs)
43
+ for idx, audio_id in enumerate(pred_idx_to_audioids):
44
+ train_idx = similarity[idx].argmax()
45
+ pred_data.append({
46
+ "filename": audio_id,
47
+ "tokens": random.choice(train_audioid_to_tokens[train_idx_to_audioid[train_idx]])
48
+ })
49
+ json.dump({"predictions": pred_data}, open(args.output_json, "w"), ensure_ascii=False, indent=4)
audio_to_text/captioning/utils/remove_optimizer.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+
4
+
5
+ def main(checkpoint):
6
+ state_dict = torch.load(checkpoint, map_location="cpu")
7
+ if "optimizer" in state_dict:
8
+ del state_dict["optimizer"]
9
+ if "lr_scheduler" in state_dict:
10
+ del state_dict["lr_scheduler"]
11
+ torch.save(state_dict, checkpoint)
12
+
13
+
14
+ if __name__ == "__main__":
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument("checkpoint", type=str)
17
+ args = parser.parse_args()
18
+ main(args.checkpoint)
audio_to_text/captioning/utils/report_results.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import argparse
3
+ import numpy as np
4
+
5
+ parser = argparse.ArgumentParser()
6
+ parser.add_argument("--input", help="input filename", type=str, nargs="+")
7
+ parser.add_argument("--output", help="output result file", default=None)
8
+
9
+ args = parser.parse_args()
10
+
11
+
12
+ scores = {}
13
+ for path in args.input:
14
+ with open(path, "r") as reader:
15
+ for line in reader.readlines():
16
+ metric, score = line.strip().split(": ")
17
+ score = float(score)
18
+ if metric not in scores:
19
+ scores[metric] = []
20
+ scores[metric].append(score)
21
+
22
+ if len(scores) == 0:
23
+ print("No experiment directory found, wrong path?")
24
+ exit(1)
25
+
26
+ with open(args.output, "w") as writer:
27
+ print("Average results: ", file=writer)
28
+ for metric, score in scores.items():
29
+ score = np.array(score)
30
+ mean = np.mean(score)
31
+ std = np.std(score)
32
+ print(f"{metric}: {mean:.3f} (±{std:.3f})", file=writer)
33
+ print("", file=writer)
34
+ print("Best results: ", file=writer)
35
+ for metric, score in scores.items():
36
+ score = np.max(score)
37
+ print(f"{metric}: {score:.3f}", file=writer)
audio_to_text/captioning/utils/tokenize_caption.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from tqdm import tqdm
3
+ import re
4
+ import fire
5
+
6
+
7
+ def tokenize_caption(input_json: str,
8
+ keep_punctuation: bool = False,
9
+ host_address: str = None,
10
+ character_level: bool = False,
11
+ zh: bool = True,
12
+ output_json: str = None):
13
+ """Build vocabulary from csv file with a given threshold to drop all counts < threshold
14
+
15
+ Args:
16
+ input_json(string): Preprossessed json file. Structure like this:
17
+ {
18
+ 'audios': [
19
+ {
20
+ 'audio_id': 'xxx',
21
+ 'captions': [
22
+ {
23
+ 'caption': 'xxx',
24
+ 'cap_id': 'xxx'
25
+ }
26
+ ]
27
+ },
28
+ ...
29
+ ]
30
+ }
31
+ threshold (int): Threshold to drop all words with counts < threshold
32
+ keep_punctuation (bool): Includes or excludes punctuation.
33
+
34
+ Returns:
35
+ vocab (Vocab): Object with the processed vocabulary
36
+ """
37
+ data = json.load(open(input_json, "r"))["audios"]
38
+
39
+ if zh:
40
+ from nltk.parse.corenlp import CoreNLPParser
41
+ from zhon.hanzi import punctuation
42
+ parser = CoreNLPParser(host_address)
43
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
44
+ for cap_idx in range(len(data[audio_idx]["captions"])):
45
+ caption = data[audio_idx]["captions"][cap_idx]["caption"]
46
+ # Remove all punctuations
47
+ if not keep_punctuation:
48
+ caption = re.sub("[{}]".format(punctuation), "", caption)
49
+ if character_level:
50
+ tokens = list(caption)
51
+ else:
52
+ tokens = list(parser.tokenize(caption))
53
+ data[audio_idx]["captions"][cap_idx]["tokens"] = " ".join(tokens)
54
+ else:
55
+ from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
56
+ captions = {}
57
+ for audio_idx in range(len(data)):
58
+ audio_id = data[audio_idx]["audio_id"]
59
+ captions[audio_id] = []
60
+ for cap_idx in range(len(data[audio_idx]["captions"])):
61
+ caption = data[audio_idx]["captions"][cap_idx]["caption"]
62
+ captions[audio_id].append({
63
+ "audio_id": audio_id,
64
+ "id": cap_idx,
65
+ "caption": caption
66
+ })
67
+ tokenizer = PTBTokenizer()
68
+ captions = tokenizer.tokenize(captions)
69
+ for audio_idx in tqdm(range(len(data)), leave=False, ascii=True):
70
+ audio_id = data[audio_idx]["audio_id"]
71
+ for cap_idx in range(len(data[audio_idx]["captions"])):
72
+ tokens = captions[audio_id][cap_idx]
73
+ data[audio_idx]["captions"][cap_idx]["tokens"] = tokens
74
+
75
+ if output_json:
76
+ json.dump(
77
+ { "audios": data }, open(output_json, "w"),
78
+ indent=4, ensure_ascii=not zh)
79
+ else:
80
+ json.dump(
81
+ { "audios": data }, open(input_json, "w"),
82
+ indent=4, ensure_ascii=not zh)
83
+
84
+
85
+ if __name__ == "__main__":
86
+ fire.Fire(tokenize_caption)
audio_to_text/captioning/utils/train_util.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ #!/usr/bin/env python3
3
+ import os
4
+ import sys
5
+ import logging
6
+ from typing import Callable, Dict, Union
7
+ import yaml
8
+ import torch
9
+ from torch.optim.swa_utils import AveragedModel as torch_average_model
10
+ import numpy as np
11
+ import pandas as pd
12
+ from pprint import pformat
13
+
14
+
15
+ def load_dict_from_csv(csv, cols):
16
+ df = pd.read_csv(csv, sep="\t")
17
+ output = dict(zip(df[cols[0]], df[cols[1]]))
18
+ return output
19
+
20
+
21
+ def init_logger(filename, level="INFO"):
22
+ formatter = logging.Formatter(
23
+ "[ %(levelname)s : %(asctime)s ] - %(message)s")
24
+ logger = logging.getLogger(__name__ + "." + filename)
25
+ logger.setLevel(getattr(logging, level))
26
+ # Log results to std
27
+ # stdhandler = logging.StreamHandler(sys.stdout)
28
+ # stdhandler.setFormatter(formatter)
29
+ # Dump log to file
30
+ filehandler = logging.FileHandler(filename)
31
+ filehandler.setFormatter(formatter)
32
+ logger.addHandler(filehandler)
33
+ # logger.addHandler(stdhandler)
34
+ return logger
35
+
36
+
37
+ def init_obj(module, config, **kwargs):# 'captioning.models.encoder'
38
+ obj_args = config["args"].copy()
39
+ obj_args.update(kwargs)
40
+ return getattr(module, config["type"])(**obj_args)
41
+
42
+
43
+ def pprint_dict(in_dict, outputfun=sys.stdout.write, formatter='yaml'):
44
+ """pprint_dict
45
+
46
+ :param outputfun: function to use, defaults to sys.stdout
47
+ :param in_dict: dict to print
48
+ """
49
+ if formatter == 'yaml':
50
+ format_fun = yaml.dump
51
+ elif formatter == 'pretty':
52
+ format_fun = pformat
53
+ for line in format_fun(in_dict).split('\n'):
54
+ outputfun(line)
55
+
56
+
57
+ def merge_a_into_b(a, b):
58
+ # merge dict a into dict b. values in a will overwrite b.
59
+ for k, v in a.items():
60
+ if isinstance(v, dict) and k in b:
61
+ assert isinstance(
62
+ b[k], dict
63
+ ), "Cannot inherit key '{}' from base!".format(k)
64
+ merge_a_into_b(v, b[k])
65
+ else:
66
+ b[k] = v
67
+
68
+
69
+ def load_config(config_file):
70
+ with open(config_file, "r") as reader:
71
+ config = yaml.load(reader, Loader=yaml.FullLoader)
72
+ if "inherit_from" in config:
73
+ base_config_file = config["inherit_from"]
74
+ base_config_file = os.path.join(
75
+ os.path.dirname(config_file), base_config_file
76
+ )
77
+ assert not os.path.samefile(config_file, base_config_file), \
78
+ "inherit from itself"
79
+ base_config = load_config(base_config_file)
80
+ del config["inherit_from"]
81
+ merge_a_into_b(config, base_config)
82
+ return base_config
83
+ return config
84
+
85
+
86
+ def parse_config_or_kwargs(config_file, **kwargs):
87
+ yaml_config = load_config(config_file)
88
+ # passed kwargs will override yaml config
89
+ args = dict(yaml_config, **kwargs)
90
+ return args
91
+
92
+
93
+ def store_yaml(config, config_file):
94
+ with open(config_file, "w") as con_writer:
95
+ yaml.dump(config, con_writer, indent=4, default_flow_style=False)
96
+
97
+
98
+ class MetricImprover:
99
+
100
+ def __init__(self, mode):
101
+ assert mode in ("min", "max")
102
+ self.mode = mode
103
+ # min: lower -> better; max: higher -> better
104
+ self.best_value = np.inf if mode == "min" else -np.inf
105
+
106
+ def compare(self, x, best_x):
107
+ return x < best_x if self.mode == "min" else x > best_x
108
+
109
+ def __call__(self, x):
110
+ if self.compare(x, self.best_value):
111
+ self.best_value = x
112
+ return True
113
+ return False
114
+
115
+ def state_dict(self):
116
+ return self.__dict__
117
+
118
+ def load_state_dict(self, state_dict):
119
+ self.__dict__.update(state_dict)
120
+
121
+
122
+ def fix_batchnorm(model: torch.nn.Module):
123
+ def inner(module):
124
+ class_name = module.__class__.__name__
125
+ if class_name.find("BatchNorm") != -1:
126
+ module.eval()
127
+ model.apply(inner)
128
+
129
+
130
+ def load_pretrained_model(model: torch.nn.Module,
131
+ pretrained: Union[str, Dict],
132
+ output_fn: Callable = sys.stdout.write):
133
+ if not isinstance(pretrained, dict) and not os.path.exists(pretrained):
134
+ output_fn(f"pretrained {pretrained} not exist!")
135
+ return
136
+
137
+ if hasattr(model, "load_pretrained"):
138
+ model.load_pretrained(pretrained)
139
+ return
140
+
141
+ if isinstance(pretrained, dict):
142
+ state_dict = pretrained
143
+ else:
144
+ state_dict = torch.load(pretrained, map_location="cpu")
145
+
146
+ if "model" in state_dict:
147
+ state_dict = state_dict["model"]
148
+ model_dict = model.state_dict()
149
+ pretrained_dict = {
150
+ k: v for k, v in state_dict.items() if (k in model_dict) and (
151
+ model_dict[k].shape == v.shape)
152
+ }
153
+ output_fn(f"Loading pretrained keys {pretrained_dict.keys()}")
154
+ model_dict.update(pretrained_dict)
155
+ model.load_state_dict(model_dict, strict=True)
156
+
157
+
158
+ class AveragedModel(torch_average_model):
159
+
160
+ def update_parameters(self, model):
161
+ for p_swa, p_model in zip(self.parameters(), model.parameters()):
162
+ device = p_swa.device
163
+ p_model_ = p_model.detach().to(device)
164
+ if self.n_averaged == 0:
165
+ p_swa.detach().copy_(p_model_)
166
+ else:
167
+ p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
168
+ self.n_averaged.to(device)))
169
+
170
+ for b_swa, b_model in zip(list(self.buffers())[1:], model.buffers()):
171
+ device = b_swa.device
172
+ b_model_ = b_model.detach().to(device)
173
+ if self.n_averaged == 0:
174
+ b_swa.detach().copy_(b_model_)
175
+ else:
176
+ b_swa.detach().copy_(self.avg_fn(b_swa.detach(), b_model_,
177
+ self.n_averaged.to(device)))
178
+ self.n_averaged += 1
audio_to_text/captioning/utils/word2vec/create_word_embedding.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ #!/usr/bin/env python3
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ import gensim
8
+ from gensim.models import Word2Vec
9
+ from tqdm import tqdm
10
+ import fire
11
+
12
+ import sys
13
+ import os
14
+ sys.path.append(os.getcwd())
15
+ from utils.build_vocab import Vocabulary
16
+
17
+ def create_embedding(vocab_file: str,
18
+ embed_size: int,
19
+ output: str,
20
+ caption_file: str = None,
21
+ pretrained_weights_path: str = None,
22
+ **word2vec_kwargs):
23
+ vocabulary = torch.load(vocab_file, map_location="cpu")
24
+
25
+ if pretrained_weights_path:
26
+ model = gensim.models.KeyedVectors.load_word2vec_format(
27
+ fname=pretrained_weights_path,
28
+ binary=True,
29
+ )
30
+ if model.vector_size != embed_size:
31
+ assert embed_size < model.vector_size, f"only reduce dimension, cannot add dimesion {model.vector_size} to {embed_size}"
32
+ from sklearn.decomposition import PCA
33
+ pca = PCA(n_components=embed_size)
34
+ model.vectors = pca.fit_transform(model.vectors)
35
+ else:
36
+ caption_df = pd.read_json(caption_file)
37
+ caption_df["tokens"] = caption_df["tokens"].apply(lambda x: ["<start>"] + [token for token in x] + ["<end>"])
38
+ sentences = list(caption_df["tokens"].values)
39
+ epochs = word2vec_kwargs.get("epochs", 10)
40
+ if "epochs" in word2vec_kwargs:
41
+ del word2vec_kwargs["epochs"]
42
+ model = Word2Vec(size=embed_size, min_count=1, **word2vec_kwargs)
43
+ model.build_vocab(sentences=sentences)
44
+ model.train(sentences=sentences, total_examples=len(sentences), epochs=epochs)
45
+
46
+ word_embeddings = np.random.randn(len(vocabulary), embed_size)
47
+
48
+ if isinstance(model, gensim.models.word2vec.Word2Vec):
49
+ model = model.wv
50
+ with tqdm(total=len(vocabulary), ascii=True) as pbar:
51
+ for word, idx in vocabulary.word2idx.items():
52
+ try:
53
+ word_embeddings[idx] = model.get_vector(word)
54
+ except KeyError:
55
+ print(f"word {word} not found in word2vec model, it is random initialized!")
56
+ pbar.update()
57
+
58
+ np.save(output, word_embeddings)
59
+
60
+ print("Finish writing word2vec embeddings to " + output)
61
+
62
+
63
+ if __name__ == "__main__":
64
+ fire.Fire(create_embedding)
65
+
66
+
67
+
audio_to_text/clotho_cntrstv_cnn14rnn_trm/config.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ encoder:
3
+ type: Cnn14RnnEncoder
4
+ args:
5
+ sample_rate: 32000
6
+ pretrained: ./audio_to_text/pretrained_feature_extractors/contrastive_pretrain_cnn14_bertm.pth
7
+ freeze_cnn: True
8
+ freeze_cnn_bn: True
9
+ bidirectional: True
10
+ dropout: 0.5
11
+ hidden_size: 256
12
+ num_layers: 3
13
+ decoder:
14
+ type: TransformerDecoder
15
+ args:
16
+ attn_emb_dim: 512
17
+ dropout: 0.2
18
+ emb_dim: 256
19
+ fc_emb_dim: 512
20
+ nlayers: 2
21
+ type: TransformerModel
22
+ args: {}
audio_to_text/clotho_cntrstv_cnn14rnn_trm/swa.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8d341dccafcdcfb7009c402afb07f314ab1d613a5f5c42d32407d6c2a821abf
3
+ size 41755865
audio_to_text/inference_waveform.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import librosa
4
+ import numpy as np
5
+ import torch
6
+ import audio_to_text.captioning.models
7
+ import audio_to_text.captioning.models.encoder
8
+ import audio_to_text.captioning.models.decoder
9
+ import audio_to_text.captioning.utils.train_util as train_util
10
+
11
+
12
+ def load_model(config, checkpoint):
13
+ ckpt = torch.load(checkpoint, "cpu")
14
+ encoder_cfg = config["model"]["encoder"]
15
+ encoder = train_util.init_obj(
16
+ audio_to_text.captioning.models.encoder,
17
+ encoder_cfg
18
+ )
19
+ if "pretrained" in encoder_cfg:
20
+ pretrained = encoder_cfg["pretrained"]
21
+ train_util.load_pretrained_model(encoder,
22
+ pretrained,
23
+ sys.stdout.write)
24
+ decoder_cfg = config["model"]["decoder"]
25
+ if "vocab_size" not in decoder_cfg["args"]:
26
+ decoder_cfg["args"]["vocab_size"] = len(ckpt["vocabulary"])
27
+ decoder = train_util.init_obj(
28
+ audio_to_text.captioning.models.decoder,
29
+ decoder_cfg
30
+ )
31
+ if "word_embedding" in decoder_cfg:
32
+ decoder.load_word_embedding(**decoder_cfg["word_embedding"])
33
+ if "pretrained" in decoder_cfg:
34
+ pretrained = decoder_cfg["pretrained"]
35
+ train_util.load_pretrained_model(decoder,
36
+ pretrained,
37
+ sys.stdout.write)
38
+ model = train_util.init_obj(audio_to_text.captioning.models, config["model"],
39
+ encoder=encoder, decoder=decoder)
40
+ train_util.load_pretrained_model(model, ckpt)
41
+ model.eval()
42
+ return {
43
+ "model": model,
44
+ "vocabulary": ckpt["vocabulary"]
45
+ }
46
+
47
+
48
+ def decode_caption(word_ids, vocabulary):
49
+ candidate = []
50
+ for word_id in word_ids:
51
+ word = vocabulary[word_id]
52
+ if word == "<end>":
53
+ break
54
+ elif word == "<start>":
55
+ continue
56
+ candidate.append(word)
57
+ candidate = " ".join(candidate)
58
+ return candidate
59
+
60
+
61
+ class AudioCapModel(object):
62
+ def __init__(self,weight_dir,device='cuda'):
63
+ config = os.path.join(weight_dir,'config.yaml')
64
+ self.config = train_util.parse_config_or_kwargs(config)
65
+ checkpoint = os.path.join(weight_dir,'swa.pth')
66
+ resumed = load_model(self.config, checkpoint)
67
+ model = resumed["model"]
68
+ self.vocabulary = resumed["vocabulary"]
69
+ self.model = model.to(device)
70
+ self.device = device
71
+
72
+ def caption(self,audio_list):
73
+ if isinstance(audio_list,np.ndarray):
74
+ audio_list = [audio_list]
75
+ elif isinstance(audio_list,str):
76
+ audio_list = [librosa.load(audio_list,sr=32000)[0]]
77
+
78
+ captions = []
79
+ for wav in audio_list:
80
+ inputwav = torch.as_tensor(wav).float().unsqueeze(0).to(self.device)
81
+ wav_len = torch.LongTensor([len(wav)])
82
+ input_dict = {
83
+ "mode": "inference",
84
+ "wav": inputwav,
85
+ "wav_len": wav_len,
86
+ "specaug": False,
87
+ "sample_method": "beam",
88
+ }
89
+ print(input_dict)
90
+ out_dict = self.model(input_dict)
91
+ caption_batch = [decode_caption(seq, self.vocabulary) for seq in \
92
+ out_dict["seq"].cpu().numpy()]
93
+ captions.extend(caption_batch)
94
+ return captions
95
+
96
+
97
+
98
+ def __call__(self, audio_list):
99
+ return self.caption(audio_list)
100
+
101
+
102
+
audio_to_text/pretrained_feature_extractors/contrastive_pretrain_cnn14_bertm.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c4faa86f30e77df235b5dc1fb6578a18ff2b8a1b0043f47e30acb9ccb53a336
3
+ size 494977221