WillHeld committed on
Commit
f3ba0cb
1 Parent(s): a4c82a3

SALMONN CODE

Browse files
models/__init__.py ADDED
File without changes
models/beats/BEATs.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+
11
+ import logging
12
+ from typing import Optional
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torchaudio.compliance.kaldi as ta_kaldi
17
+ from models.beats.backbone import TransformerEncoder
18
+ from torch.nn import LayerNorm
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
class BEATsConfig:
    """Plain attribute bag holding the BEATs model hyper-parameters.

    All options are first set to their defaults as instance attributes (in
    a fixed order); a ``cfg`` mapping, when given, then overrides or extends
    them through :meth:`update`.
    """

    def __init__(self, cfg=None):
        # ---- patch embedding ----
        # patch size of the patch embedding
        self.input_patch_size: int = -1
        # patch embedding dimension
        self.embed_dim: int = 512
        # include bias in conv encoder
        self.conv_bias: bool = False

        # ---- transformer encoder ----
        # num encoder layers in the transformer
        self.encoder_layers: int = 12
        # encoder embedding dimension
        self.encoder_embed_dim: int = 768
        # encoder embedding dimension for FFN
        self.encoder_ffn_embed_dim: int = 3072
        # num encoder attention heads
        self.encoder_attention_heads: int = 12
        # activation function to use
        self.activation_fn: str = "gelu"

        # ratio for layer-wise gradient decay
        self.layer_wise_gradient_decay_ratio: float = 1.0
        # apply layernorm first in the transformer
        self.layer_norm_first: bool = False
        # apply deep_norm in the transformer
        self.deep_norm: bool = False

        # ---- dropouts ----
        # dropout probability for the transformer
        self.dropout: float = 0.1
        # dropout probability for attention weights
        self.attention_dropout: float = 0.1
        # dropout probability after activation in FFN
        self.activation_dropout: float = 0.0
        # probability of dropping a transformer layer
        self.encoder_layerdrop: float = 0.0
        # dropout to apply to the input (after feature extraction)
        self.dropout_input: float = 0.0

        # ---- positional embeddings ----
        # number of filters for convolutional positional embeddings
        self.conv_pos: int = 128
        # number of groups for convolutional positional embedding
        self.conv_pos_groups: int = 16

        # ---- relative position embedding ----
        # apply relative position embedding
        self.relative_position_embedding: bool = False
        # number of buckets for relative position embedding
        self.num_buckets: int = 320
        # maximum distance for relative position embedding
        self.max_distance: int = 1280
        # apply gated relative position embedding
        self.gru_rel_pos: bool = False

        # ---- label predictor ----
        # whether the model is a fine-tuned model
        self.finetuned_model: bool = False
        # dropout probability for the predictor
        self.predictor_dropout: float = 0.1
        # target class number for the predictor
        self.predictor_class: int = 527

        if cfg is None:
            return
        self.update(cfg)

    def update(self, cfg: dict):
        """Copy every key/value of ``cfg`` onto this object as an attribute.

        Keys are not validated: unknown keys simply become new attributes.
        """
        vars(self).update(cfg)
82
+
83
+
84
class BEATs(nn.Module):
    """BEATs audio encoder with an optional classification head.

    Waveforms are converted to Kaldi-style filterbank features, cut into
    patches by a strided 2-D convolution, layer-normalised, optionally
    projected to the encoder width, and fed to ``TransformerEncoder``.
    If ``cfg.finetuned_model`` is true, a dropout + linear predictor is
    attached and its per-frame outputs are pooled over the time axis.
    """

    def __init__(
        self,
        cfg: BEATsConfig,
    ) -> None:
        super().__init__()
        logger.info(f"BEATs Config: {cfg.__dict__}")

        self.cfg = cfg
        self.embed = cfg.embed_dim

        # Only project when the patch width differs from the encoder width.
        if self.embed != cfg.encoder_embed_dim:
            self.post_extract_proj = nn.Linear(self.embed, cfg.encoder_embed_dim)
        else:
            self.post_extract_proj = None

        # Non-overlapping patches: kernel size and stride are the same value.
        self.input_patch_size = cfg.input_patch_size
        self.patch_embedding = nn.Conv2d(
            1,
            self.embed,
            kernel_size=self.input_patch_size,
            stride=self.input_patch_size,
            bias=cfg.conv_bias,
        )

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # deep_norm and layer_norm_first cannot both be enabled.
        assert not (cfg.deep_norm and cfg.layer_norm_first)
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        if not cfg.finetuned_model:
            self.predictor = None
        else:
            self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
            self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Shrink ``padding_mask`` along dim 1 to ``features.size(1)`` entries.

        Trailing mask columns that do not divide evenly are dropped; the rest
        is grouped per feature step, and a step counts as padded only when
        every mask entry in its group is set.
        """
        steps = features.size(1)
        leftover = padding_mask.size(1) % steps
        if leftover > 0:
            padding_mask = padding_mask[:, :-leftover]
        grouped = padding_mask.view(padding_mask.size(0), steps, -1)
        return grouped.all(-1)

    def preprocess(
        self,
        source: torch.Tensor,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        """Compute normalised 128-bin filterbank features for each waveform.

        Each item of ``source`` is scaled by 2**15 before ``ta_kaldi.fbank``
        (presumably to map [-1, 1] floats onto the 16-bit PCM range expected
        by Kaldi -- confirm against the data loader). The stacked result is
        normalised as ``(fbank - fbank_mean) / (2 * fbank_std)``.
        """
        scale = 2**15
        per_item = [
            ta_kaldi.fbank(
                wav.unsqueeze(0) * scale,
                num_mel_bins=128,
                sample_frequency=16000,
                frame_length=25,
                frame_shift=10,
            )
            for wav in source
        ]
        stacked = torch.stack(per_item, dim=0)
        return (stacked - fbank_mean) / (2 * fbank_std)

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
        feature_only=False,
    ):
        """Encode ``source`` and optionally classify it.

        Args:
            source: batch of waveforms (iterated item by item in
                ``preprocess``).
            padding_mask: optional mask over ``source``; it is down-sampled
                first to the filterbank resolution and then to the patch
                resolution.
            fbank_mean: mean used to normalise the filterbank features.
            fbank_std: standard deviation used to normalise them.
            feature_only: when true, skip the predictor even if present.

        Returns:
            ``(lprobs, padding_mask)`` with sigmoid scores pooled over time
            when a predictor exists and ``feature_only`` is false; otherwise
            ``(x, padding_mask)`` with the encoder output.
        """
        fbank = self.preprocess(
            source, fbank_mean=fbank_mean, fbank_std=fbank_std
        ).to(torch.float32)

        has_mask = padding_mask is not None
        if has_mask:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        # Add a channel axis, patchify, then flatten the patch grid into a
        # sequence: (batch, embed, ...) -> (batch, patches, embed).
        patches = self.patch_embedding(fbank.unsqueeze(1))
        patches = patches.reshape(patches.shape[0], patches.shape[1], -1)
        features = self.layer_norm(patches.transpose(1, 2))

        if has_mask:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)
        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
        )

        if feature_only or self.predictor is None:
            return x, padding_mask

        logits = self.predictor(self.predictor_dropout(x))

        if padding_mask is not None and padding_mask.any():
            # Masked mean: zero the padded frames, then divide the sum by
            # the number of unpadded frames of each item.
            logits[padding_mask] = 0
            logits = logits.sum(dim=1)
            valid = (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits)
            logits = logits / valid
        else:
            logits = logits.mean(dim=1)

        lprobs = torch.sigmoid(logits)
        return lprobs, padding_mask
models/beats/LICENSE_beats ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The MIT License (MIT)
2
+
3
+ Copyright (c) Microsoft Corporation
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
models/beats/Tokenizers.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import LayerNorm
14
+ import torchaudio.compliance.kaldi as ta_kaldi
15
+
16
+ from beats.backbone import (
17
+ TransformerEncoder,
18
+ )
19
+ from beats.quantizer import (
20
+ NormEMAVectorQuantizer,
21
+ )
22
+
23
+ import logging
24
+ from typing import Optional
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
class TokenizersConfig:
    """Plain attribute bag holding the BEATs tokenizer hyper-parameters.

    Defaults are assigned as instance attributes in a fixed order; a ``cfg``
    mapping, when given, then overrides or extends them via :meth:`update`.
    """

    def __init__(self, cfg=None):
        # ---- patch embedding ----
        # patch size of the patch embedding
        self.input_patch_size: int = -1
        # patch embedding dimension
        self.embed_dim: int = 512
        # include bias in conv encoder
        self.conv_bias: bool = False

        # ---- transformer encoder ----
        # num encoder layers in the transformer
        self.encoder_layers: int = 12
        # encoder embedding dimension
        self.encoder_embed_dim: int = 768
        # encoder embedding dimension for FFN
        self.encoder_ffn_embed_dim: int = 3072
        # num encoder attention heads
        self.encoder_attention_heads: int = 12
        # activation function to use
        self.activation_fn: str = "gelu"

        # apply layernorm first in the transformer
        self.layer_norm_first: bool = False
        # apply deep_norm in the transformer
        self.deep_norm: bool = False

        # ---- dropouts ----
        # dropout probability for the transformer
        self.dropout: float = 0.1
        # dropout probability for attention weights
        self.attention_dropout: float = 0.1
        # dropout probability after activation in FFN
        self.activation_dropout: float = 0.0
        # probability of dropping a transformer layer
        self.encoder_layerdrop: float = 0.0
        # dropout to apply to the input (after feature extraction)
        self.dropout_input: float = 0.0

        # ---- positional embeddings ----
        # number of filters for convolutional positional embeddings
        self.conv_pos: int = 128
        # number of groups for convolutional positional embedding
        self.conv_pos_groups: int = 16

        # ---- relative position embedding ----
        # apply relative position embedding
        self.relative_position_embedding: bool = False
        # number of buckets for relative position embedding
        self.num_buckets: int = 320
        # maximum distance for relative position embedding
        self.max_distance: int = 1280
        # apply gated relative position embedding
        self.gru_rel_pos: bool = False

        # ---- quantizer ----
        # codebook number in quantizer
        self.quant_n: int = 1024
        # codebook dimension in quantizer
        self.quant_dim: int = 256

        if cfg is None:
            return
        self.update(cfg)

    def update(self, cfg: dict):
        """Copy every key/value of ``cfg`` onto this object as an attribute.

        Keys are not validated: unknown keys simply become new attributes.
        """
        vars(self).update(cfg)
70
+
71
+
72
class Tokenizers(nn.Module):
    """BEATs acoustic tokenizer: encoder followed by a vector quantizer.

    Waveforms go through the same filterbank / patch-embedding / transformer
    pipeline as the encoder model; the encoder output is then projected by
    ``quantize_layer`` and mapped to codebook indices by
    ``NormEMAVectorQuantizer``.
    """

    def __init__(
        self,
        cfg: TokenizersConfig,
    ) -> None:
        super().__init__()
        logger.info(f"Tokenizers Config: {cfg.__dict__}")

        self.cfg = cfg
        self.embed = cfg.embed_dim

        # Only project when the patch width differs from the encoder width.
        if self.embed != cfg.encoder_embed_dim:
            self.post_extract_proj = nn.Linear(self.embed, cfg.encoder_embed_dim)
        else:
            self.post_extract_proj = None

        # Non-overlapping patches: kernel size and stride are the same value.
        self.input_patch_size = cfg.input_patch_size
        self.patch_embedding = nn.Conv2d(
            1,
            self.embed,
            kernel_size=self.input_patch_size,
            stride=self.input_patch_size,
            bias=cfg.conv_bias,
        )

        self.dropout_input = nn.Dropout(cfg.dropout_input)

        # deep_norm and layer_norm_first cannot both be enabled.
        assert not (cfg.deep_norm and cfg.layer_norm_first)
        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.quantize = NormEMAVectorQuantizer(
            n_embed=cfg.quant_n,
            embedding_dim=cfg.quant_dim,
            beta=1.0,
            kmeans_init=True,
            decay=0.99,
        )
        self.quant_n = cfg.quant_n
        # Maps encoder outputs down to the codebook dimension.
        self.quantize_layer = nn.Sequential(
            nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
            nn.Tanh(),
            nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim),
        )

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Shrink ``padding_mask`` along dim 1 to ``features.size(1)`` entries.

        Trailing mask columns that do not divide evenly are dropped; the rest
        is grouped per feature step, and a step counts as padded only when
        every mask entry in its group is set.
        """
        steps = features.size(1)
        leftover = padding_mask.size(1) % steps
        if leftover > 0:
            padding_mask = padding_mask[:, :-leftover]
        grouped = padding_mask.view(padding_mask.size(0), steps, -1)
        return grouped.all(-1)

    def preprocess(
        self,
        source: torch.Tensor,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ) -> torch.Tensor:
        """Compute normalised 128-bin filterbank features for each waveform.

        Each item of ``source`` is scaled by 2**15 before ``ta_kaldi.fbank``
        (presumably to map [-1, 1] floats onto the 16-bit PCM range expected
        by Kaldi -- confirm against the data loader). The stacked result is
        normalised as ``(fbank - fbank_mean) / (2 * fbank_std)``.
        """
        scale = 2**15
        per_item = [
            ta_kaldi.fbank(
                wav.unsqueeze(0) * scale,
                num_mel_bins=128,
                sample_frequency=16000,
                frame_length=25,
                frame_shift=10,
            )
            for wav in source
        ]
        stacked = torch.stack(per_item, dim=0)
        return (stacked - fbank_mean) / (2 * fbank_std)

    def extract_labels(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        fbank_mean: float = 15.41663,
        fbank_std: float = 6.55582,
    ):
        """Return the codebook indices the quantizer assigns to ``source``.

        Args:
            source: batch of waveforms (iterated item by item in
                ``preprocess``).
            padding_mask: optional mask over ``source``; it is down-sampled
                first to the filterbank resolution and then to the patch
                resolution before being handed to the encoder.
            fbank_mean: mean used to normalise the filterbank features.
            fbank_std: standard deviation used to normalise them.

        Returns:
            The third value returned by ``self.quantize`` (``embed_ind``).
        """
        fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)

        has_mask = padding_mask is not None
        if has_mask:
            padding_mask = self.forward_padding_mask(fbank, padding_mask)

        # Add a channel axis, patchify, then flatten the patch grid into a
        # sequence: (batch, embed, ...) -> (batch, patches, embed).
        patches = self.patch_embedding(fbank.unsqueeze(1))
        patches = patches.reshape(patches.shape[0], patches.shape[1], -1)
        features = self.layer_norm(patches.transpose(1, 2))

        if has_mask:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        x = self.dropout_input(features)
        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
        )

        # Only the indices are needed; the quantized features and the
        # embedding loss returned alongside them are discarded.
        _, _, embed_ind = self.quantize(self.quantize_layer(x))
        return embed_ind
models/beats/__init__.py ADDED
File without changes
models/beats/backbone.py ADDED
@@ -0,0 +1,814 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ from typing import Dict, Optional, Tuple
12
+
13
+ import numpy as np
14
+ import torch
15
+ import torch.nn.functional as F
16
+ from models.beats.modules import (
17
+ GLU_Linear,
18
+ GradMultiply,
19
+ SamePad,
20
+ get_activation_fn,
21
+ quant_noise,
22
+ )
23
+ from torch import Tensor, nn
24
+ from torch.nn import LayerNorm, Parameter
25
+
26
+
27
class TransformerEncoder(nn.Module):
    """Stack of ``TransformerSentenceEncoderLayer`` with conv positions.

    A weight-normalised grouped 1-D convolution supplies positional
    information that is added to the input. When relative position
    embeddings are enabled, every layer shares the bias table of layer 0.
    """

    def __init__(self, args):
        super().__init__()

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        # Convolutional positional embedding.
        conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        init_dropout = 0
        init_std = math.sqrt(
            (4 * (1.0 - init_dropout)) / (args.conv_pos * self.embedding_dim)
        )
        nn.init.normal_(conv.weight, mean=0, std=init_std)
        nn.init.constant_(conv.bias, 0)
        conv = nn.utils.weight_norm(conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(conv, SamePad(args.conv_pos), nn.GELU())

        # Relative position settings are optional on ``args``.
        if not hasattr(args, "relative_position_embedding"):
            self.relative_position_embedding = False
            self.num_buckets = 0
            self.max_distance = 0
        else:
            self.relative_position_embedding = args.relative_position_embedding
            self.num_buckets = args.num_buckets
            self.max_distance = args.max_distance

        def make_layer():
            # Builds one encoder layer from the shared configuration.
            return TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=args.encoder_ffn_embed_dim,
                num_attention_heads=args.encoder_attention_heads,
                dropout=self.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                activation_fn=args.activation_fn,
                layer_norm_first=args.layer_norm_first,
                deep_norm=args.deep_norm,
                has_relative_attention_bias=self.relative_position_embedding,
                num_buckets=self.num_buckets,
                max_distance=self.max_distance,
                gru_rel_pos=args.gru_rel_pos,
                encoder_layers=args.encoder_layers,
            )

        self.layers = nn.ModuleList(
            [make_layer() for _ in range(args.encoder_layers)]
        )

        if self.relative_position_embedding:
            # Replace each later layer's bias table with layer 0's table so
            # that a single embedding is shared across the whole stack.
            for idx in range(1, args.encoder_layers):
                attn = self.layers[idx].self_attn
                del attn.relative_attention_bias
                attn.relative_attention_bias = self.layers[
                    0
                ].self_attn.relative_attention_bias

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)

        if args.deep_norm:
            # Re-initialise after init_bert_params; value/output/FFN weights
            # use the deep-norm gain, query/key weights use gain 1.
            deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
            for idx in range(args.encoder_layers):
                block = self.layers[idx]
                attn = block.self_attn
                nn.init.xavier_normal_(attn.k_proj.weight, gain=1)
                nn.init.xavier_normal_(attn.v_proj.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(attn.q_proj.weight, gain=1)
                nn.init.xavier_normal_(attn.out_proj.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(block.fc1.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(block.fc2.weight, gain=deep_norm_beta)

        self.layer_wise_gradient_decay_ratio = getattr(
            args, "layer_wise_gradient_decay_ratio", 1
        )

    def forward(self, x, padding_mask=None, layer=None):
        """Encode ``x``; returns ``(output, layer_results)``.

        The final layer norm is applied only in the ``layer_norm_first``
        configuration and only when no target ``layer`` was requested.
        """
        x, layer_results = self.extract_features(x, padding_mask, layer)

        needs_final_norm = self.layer_norm_first and layer is None
        if needs_final_norm:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(self, x, padding_mask=None, tgt_layer=None):
        """Run the layer stack and return ``(output, layer_results)``.

        Note that padded positions of ``x`` are zeroed in place.
        ``layer_results`` is filled only when ``tgt_layer`` is given; the
        loop then stops after the layer with that index and its output is
        returned.
        """
        if padding_mask is not None:
            x[padding_mask] = 0

        # Conv1d works on (B, C, T), hence the transposes around it.
        positions = self.pos_conv(x.transpose(1, 2)).transpose(1, 2)
        x = x + positions

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        collect = tgt_layer is not None
        layer_results = []
        z = None
        if collect:
            layer_results.append((x, z))

        target_output = None
        pos_bias = None
        for idx, block in enumerate(self.layers):
            if self.layer_wise_gradient_decay_ratio != 1.0:
                x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)

            # LayerDrop: a number is drawn for every layer, but a layer can
            # only be skipped while training.
            draw = np.random.random()
            if not self.training or (draw > self.layerdrop):
                x, z, pos_bias = block(
                    x,
                    self_attn_padding_mask=padding_mask,
                    need_weights=False,
                    pos_bias=pos_bias,
                )

            if collect:
                layer_results.append((x, z))
            if idx == tgt_layer:
                target_output = x
                break

        if target_output is not None:
            x = target_output

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results
165
+
166
+
167
class TransformerSentenceEncoderLayer(nn.Module):
    """One transformer encoder layer: self-attention plus feed-forward.

    Three residual/normalisation layouts are supported:

    * ``layer_norm_first=True``: pre-norm (LayerNorm before each sub-block).
    * ``layer_norm_first=False``: post-norm (LayerNorm after each residual).
    * ``deep_norm=True`` (post-norm path only): the residual is scaled by
      ``(2 * encoder_layers) ** 0.25`` before being added.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
        deep_norm: bool = False,
        has_relative_attention_bias: bool = False,
        num_buckets: int = 0,
        max_distance: int = 0,
        rescale_init: bool = False,
        gru_rel_pos: bool = False,
        encoder_layers: int = 0,
    ) -> None:
        """Build the attention, feed-forward and normalisation sub-modules.

        Args:
            embedding_dim: model width.
            ffn_embedding_dim: hidden width of the feed-forward block.
            num_attention_heads: number of attention heads.
            dropout: dropout after attention and after the FFN output.
            attention_dropout: dropout on the attention weights.
            activation_dropout: dropout after the FFN activation.
            activation_fn: activation name; ``"glu"`` selects ``GLU_Linear``
                for ``fc1`` instead of ``nn.Linear`` + activation.
            layer_norm_first: use the pre-norm layout.
            deep_norm: scale residuals by the deep-norm alpha.
            has_relative_attention_bias: forwarded to the attention module.
            num_buckets: forwarded to the attention module.
            max_distance: forwarded to the attention module.
            rescale_init: forwarded to the attention module.
            gru_rel_pos: forwarded to the attention module.
            encoder_layers: total layer count, used for the deep-norm alpha.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        self.activation_name = activation_fn
        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
            has_relative_attention_bias=has_relative_attention_bias,
            num_buckets=num_buckets,
            max_distance=max_distance,
            rescale_init=rescale_init,
            gru_rel_pos=gru_rel_pos,
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)

        if self.activation_name == "glu":
            self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
        else:
            self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        self.final_layer_norm = LayerNorm(self.embedding_dim)

        self.deep_norm = deep_norm
        if self.deep_norm:
            self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
        else:
            self.deep_norm_alpha = 1

    def _feed_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fc1 (plus activation unless GLU), dropout2, fc2, dropout3."""
        if self.activation_name == "glu":
            # GLU_Linear is used as-is; no separate activation is applied.
            x = self.fc1(x)
        else:
            x = self.activation_fn(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return self.dropout3(x)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
        pos_bias=None,
    ):
        """Apply the layer to ``x``.

        Args:
            x: input passed as query, key and value to self-attention.
            self_attn_mask: optional attention mask.
            self_attn_padding_mask: optional key padding mask.
            need_weights: forwarded to the attention module in the post-norm
                path; the pre-norm path always passes ``False``.
            pos_bias: position bias to reuse, or ``None``.

        Returns:
            ``(x, attn, pos_bias)`` where ``attn`` and ``pos_bias`` are the
            second and third values returned by the attention module.
        """
        residual = x

        if self.layer_norm_first:
            # Pre-norm: normalise, run the sub-block, then add the residual.
            x = self.self_attn_layer_norm(x)
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=False,
                attn_mask=self_attn_mask,
                position_bias=pos_bias,
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            x = self._feed_forward(x)
            x = residual + x
        else:
            # Post-norm (optionally deep-norm): add the scaled residual
            # first, then normalise.
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=need_weights,
                attn_mask=self_attn_mask,
                position_bias=pos_bias,
            )

            x = self.dropout1(x)
            x = residual * self.deep_norm_alpha + x
            x = self.self_attn_layer_norm(x)

            residual = x
            x = self._feed_forward(x)
            x = residual * self.deep_norm_alpha + x
            x = self.final_layer_norm(x)

        return x, attn, pos_bias
290
+
291
+
292
+ class MultiheadAttention(nn.Module):
293
+ """Multi-headed attention.
294
+
295
+ See "Attention Is All You Need" for more details.
296
+ """
297
+
298
+ def __init__(
299
+ self,
300
+ embed_dim,
301
+ num_heads,
302
+ kdim=None,
303
+ vdim=None,
304
+ dropout=0.0,
305
+ bias=True,
306
+ add_bias_kv=False,
307
+ add_zero_attn=False,
308
+ self_attention=False,
309
+ encoder_decoder_attention=False,
310
+ q_noise=0.0,
311
+ qn_block_size=8,
312
+ has_relative_attention_bias=False,
313
+ num_buckets=32,
314
+ max_distance=128,
315
+ gru_rel_pos=False,
316
+ rescale_init=False,
317
+ ):
318
+ super().__init__()
319
+ self.embed_dim = embed_dim
320
+ self.kdim = kdim if kdim is not None else embed_dim
321
+ self.vdim = vdim if vdim is not None else embed_dim
322
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
323
+
324
+ self.num_heads = num_heads
325
+ self.dropout_module = nn.Dropout(dropout)
326
+
327
+ self.has_relative_attention_bias = has_relative_attention_bias
328
+ self.num_buckets = num_buckets
329
+ self.max_distance = max_distance
330
+ if self.has_relative_attention_bias:
331
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
332
+
333
+ self.head_dim = embed_dim // num_heads
334
+ self.q_head_dim = self.head_dim
335
+ self.k_head_dim = self.head_dim
336
+ assert (
337
+ self.head_dim * num_heads == self.embed_dim
338
+ ), "embed_dim must be divisible by num_heads"
339
+ self.scaling = self.head_dim**-0.5
340
+
341
+ self.self_attention = self_attention
342
+ self.encoder_decoder_attention = encoder_decoder_attention
343
+
344
+ assert not self.self_attention or self.qkv_same_dim, (
345
+ "Self-attention requires query, key and " "value to be of the same size"
346
+ )
347
+
348
+ k_bias = True
349
+ if rescale_init:
350
+ k_bias = False
351
+
352
+ k_embed_dim = embed_dim
353
+ q_embed_dim = embed_dim
354
+
355
+ self.k_proj = quant_noise(
356
+ nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
357
+ )
358
+ self.v_proj = quant_noise(
359
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
360
+ )
361
+ self.q_proj = quant_noise(
362
+ nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
363
+ )
364
+
365
+ self.out_proj = quant_noise(
366
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
367
+ )
368
+
369
+ if add_bias_kv:
370
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
371
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
372
+ else:
373
+ self.bias_k = self.bias_v = None
374
+
375
+ self.add_zero_attn = add_zero_attn
376
+
377
+ self.gru_rel_pos = gru_rel_pos
378
+ if self.gru_rel_pos:
379
+ self.grep_linear = nn.Linear(self.q_head_dim, 8)
380
+ self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
381
+
382
+ self.reset_parameters()
383
+
384
+ def reset_parameters(self):
385
+ if self.qkv_same_dim:
386
+ # Empirically observed the convergence to be much better with
387
+ # the scaled initialization
388
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
389
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
390
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
391
+ else:
392
+ nn.init.xavier_uniform_(self.k_proj.weight)
393
+ nn.init.xavier_uniform_(self.v_proj.weight)
394
+ nn.init.xavier_uniform_(self.q_proj.weight)
395
+
396
+ nn.init.xavier_uniform_(self.out_proj.weight)
397
+ if self.out_proj.bias is not None:
398
+ nn.init.constant_(self.out_proj.bias, 0.0)
399
+ if self.bias_k is not None:
400
+ nn.init.xavier_normal_(self.bias_k)
401
+ if self.bias_v is not None:
402
+ nn.init.xavier_normal_(self.bias_v)
403
+ if self.has_relative_attention_bias:
404
+ nn.init.xavier_normal_(self.relative_attention_bias.weight)
405
+
406
+ def _relative_positions_bucket(self, relative_positions, bidirectional=True):
407
+ num_buckets = self.num_buckets
408
+ max_distance = self.max_distance
409
+ relative_buckets = 0
410
+
411
+ if bidirectional:
412
+ num_buckets = num_buckets // 2
413
+ relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
414
+ relative_positions = torch.abs(relative_positions)
415
+ else:
416
+ relative_positions = -torch.min(
417
+ relative_positions, torch.zeros_like(relative_positions)
418
+ )
419
+
420
+ max_exact = num_buckets // 2
421
+ is_small = relative_positions < max_exact
422
+
423
+ relative_postion_if_large = max_exact + (
424
+ torch.log(relative_positions.float() / max_exact)
425
+ / math.log(max_distance / max_exact)
426
+ * (num_buckets - max_exact)
427
+ ).to(torch.long)
428
+ relative_postion_if_large = torch.min(
429
+ relative_postion_if_large,
430
+ torch.full_like(relative_postion_if_large, num_buckets - 1),
431
+ )
432
+
433
+ relative_buckets += torch.where(
434
+ is_small, relative_positions, relative_postion_if_large
435
+ )
436
+ return relative_buckets
437
+
438
+ def compute_bias(self, query_length, key_length):
439
+ context_position = torch.arange(query_length, dtype=torch.long)[:, None]
440
+ memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
441
+ relative_position = memory_position - context_position
442
+ relative_position_bucket = self._relative_positions_bucket(
443
+ relative_position, bidirectional=True
444
+ )
445
+ relative_position_bucket = relative_position_bucket.to(
446
+ self.relative_attention_bias.weight.device
447
+ )
448
+ values = self.relative_attention_bias(relative_position_bucket)
449
+ values = values.permute([2, 0, 1])
450
+ return values
451
+
452
def forward(
    self,
    query,
    key: Optional[Tensor],
    value: Optional[Tensor],
    key_padding_mask: Optional[Tensor] = None,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    need_weights: bool = True,
    static_kv: bool = False,
    attn_mask: Optional[Tensor] = None,
    before_softmax: bool = False,
    need_head_weights: bool = False,
    position_bias: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    """Multi-head attention. Input shape: Time x Batch x Channel

    Args:
        query: tensor of shape `(tgt_len, batch, embed_dim)`.
        key, value: source tensors; may be None when keys/values come from
            `incremental_state` (static_kv) or, for self-attention, are
            taken from `query`.
        key_padding_mask (ByteTensor, optional): mask to exclude
            keys that are pads, of shape `(batch, src_len)`, where
            padding elements are indicated by 1s.
        incremental_state (dict, optional): cache of previous keys/values;
            it is read and updated when given.
        need_weights (bool, optional): return the attention weights,
            averaged over heads (default: True).
        static_kv (bool, optional): reuse cached keys/values unchanged.
        attn_mask (ByteTensor, optional): typically used to
            implement causal attention, where the mask prevents the
            attention from looking forward in time (default: None).
        before_softmax (bool, optional): return the raw attention
            weights and values before the attention softmax.
        need_head_weights (bool, optional): return the attention
            weights for each head. Implies *need_weights*. Default:
            return the average attention weights over all heads.
        position_bias (Tensor, optional): relative position bias to add to
            the attention logits; computed via `self.compute_bias` when it
            is None and `self.has_relative_attention_bias` is set.

    Returns:
        `(attn, attn_weights, position_bias)`. `attn` has shape
        `(tgt_len, batch, embed_dim)`; `attn_weights` is None unless
        *need_weights* is set. With *before_softmax* the tuple is
        `(raw_attn_weights, v, position_bias)` instead.
    """
    if need_head_weights:
        need_weights = True

    is_tpu = query.device.type == "xla"

    tgt_len, bsz, embed_dim = query.size()
    src_len = tgt_len
    assert embed_dim == self.embed_dim
    assert list(query.size()) == [tgt_len, bsz, embed_dim]
    if key is not None:
        src_len, key_bsz, _ = key.size()
        if not torch.jit.is_scripting():
            assert key_bsz == bsz
            assert value is not None
            # Compare the (src_len, bsz) pair as a whole. The earlier form
            # `assert src_len, bsz == ...` treated the comparison as the
            # assertion message and therefore never checked the shape.
            assert (src_len, bsz) == tuple(value.shape[:2])

    if self.has_relative_attention_bias and position_bias is None:
        position_bias = self.compute_bias(tgt_len, src_len)
        # replicate the per-head bias for every batch element
        position_bias = (
            position_bias.unsqueeze(0)
            .repeat(bsz, 1, 1, 1)
            .view(bsz * self.num_heads, tgt_len, src_len)
        )

    if incremental_state is not None:
        saved_state = self._get_input_buffer(incremental_state)
        if saved_state is not None and "prev_key" in saved_state:
            # previous time steps are cached - no need to recompute
            # key and value if they are static
            if static_kv:
                assert self.encoder_decoder_attention and not self.self_attention
                key = value = None
    else:
        saved_state = None

    if self.self_attention:
        q = self.q_proj(query)
        k = self.k_proj(query)
        v = self.v_proj(query)
    elif self.encoder_decoder_attention:
        # encoder-decoder attention
        q = self.q_proj(query)
        if key is None:
            assert value is None
            k = v = None
        else:
            k = self.k_proj(key)
            v = self.v_proj(key)

    else:
        assert key is not None and value is not None
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
    q *= self.scaling
    # q is scaled down by alpha here and the logits are scaled back up after
    # the row maximum is subtracted (presumably to keep the matmul in range
    # for low-precision dtypes).
    alpha = 32
    q *= 1 / alpha

    if self.bias_k is not None:
        assert self.bias_v is not None
        k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
            )
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                ],
                dim=1,
            )

    # fold the heads into the batch dimension: (bsz * num_heads, len, head_dim)
    q = (
        q.contiguous()
        .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
        .transpose(0, 1)
    )
    if k is not None:
        k = (
            k.contiguous()
            .view(-1, bsz * self.num_heads, self.k_head_dim)
            .transpose(0, 1)
        )
    if v is not None:
        v = (
            v.contiguous()
            .view(-1, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )

    if saved_state is not None:
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        if "prev_key" in saved_state:
            _prev_key = saved_state["prev_key"]
            assert _prev_key is not None
            prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                assert k is not None
                k = torch.cat([prev_key, k], dim=1)
            src_len = k.size(1)
        if "prev_value" in saved_state:
            _prev_value = saved_state["prev_value"]
            assert _prev_value is not None
            prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                v = prev_value
            else:
                assert v is not None
                v = torch.cat([prev_value, v], dim=1)
        prev_key_padding_mask: Optional[Tensor] = None
        if "prev_key_padding_mask" in saved_state:
            prev_key_padding_mask = saved_state["prev_key_padding_mask"]
        assert k is not None and v is not None
        key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
            key_padding_mask=key_padding_mask,
            prev_key_padding_mask=prev_key_padding_mask,
            batch_size=bsz,
            src_len=k.size(1),
            static_kv=static_kv,
        )

        saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state["prev_key_padding_mask"] = key_padding_mask
        # In this branch incremental_state is never None
        assert incremental_state is not None
        incremental_state = self._set_input_buffer(incremental_state, saved_state)
    assert k is not None
    assert k.size(1) == src_len

    # This is part of a workaround to get around fork/join parallelism
    # not supporting Optional types.
    if key_padding_mask is not None and key_padding_mask.dim() == 0:
        key_padding_mask = None

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if self.add_zero_attn:
        assert v is not None
        # append one all-zero key/value position and widen the masks to match
        src_len += 1
        k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
        v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
            )
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    torch.zeros(key_padding_mask.size(0), 1).type_as(
                        key_padding_mask
                    ),
                ],
                dim=1,
            )

    attn_weights = torch.bmm(q, k.transpose(1, 2))
    # subtract the per-row maximum, then undo the 1/alpha scaling of q
    attn_weights = (
        attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]
    ) * alpha
    attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

    assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        attn_weights += attn_mask

    if key_padding_mask is not None:
        # don't attend to padding symbols
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        if not is_tpu:
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )
        else:
            attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
            attn_weights = attn_weights.transpose(0, 2)
        attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

    if before_softmax:
        return attn_weights, v, position_bias

    if position_bias is not None:
        attn_mask_rel_pos = position_bias
        if self.gru_rel_pos == 1:
            # recover the un-scaled query (undo both scaling and 1/alpha)
            query_layer = (
                q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
                * alpha
                / self.scaling
            )
            _B, _H, _L, __ = query_layer.size()
            gate_a, gate_b = torch.sigmoid(
                self.grep_linear(query_layer)
                .view(_B, _H, _L, 2, 4)
                .sum(-1, keepdim=False)
            ).chunk(2, dim=-1)
            gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
            # per-query gate that rescales the position bias
            attn_mask_rel_pos = (
                gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
            )

        attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())

        attn_weights = attn_weights + attn_mask_rel_pos

    attn_weights_float = F.softmax(attn_weights, dim=-1)
    attn_weights = attn_weights_float.type_as(attn_weights)
    attn_probs = self.dropout_module(attn_weights)

    assert v is not None
    attn = torch.bmm(attn_probs, v)
    assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
    attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn = self.out_proj(attn)
    attn_weights: Optional[Tensor] = None
    if need_weights:
        attn_weights = attn_weights_float.view(
            bsz, self.num_heads, tgt_len, src_len
        ).transpose(1, 0)
        if not need_head_weights:
            # average attention weights over heads
            attn_weights = attn_weights.mean(dim=0)

    return attn, attn_weights, position_bias
718
+
719
+ @staticmethod
720
+ def _append_prev_key_padding_mask(
721
+ key_padding_mask: Optional[Tensor],
722
+ prev_key_padding_mask: Optional[Tensor],
723
+ batch_size: int,
724
+ src_len: int,
725
+ static_kv: bool,
726
+ ) -> Optional[Tensor]:
727
+ # saved key padding masks have shape (bsz, seq_len)
728
+ if prev_key_padding_mask is not None and static_kv:
729
+ new_key_padding_mask = prev_key_padding_mask
730
+ elif prev_key_padding_mask is not None and key_padding_mask is not None:
731
+ new_key_padding_mask = torch.cat(
732
+ [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
733
+ )
734
+ # During incremental decoding, as the padding token enters and
735
+ # leaves the frame, there will be a time when prev or current
736
+ # is None
737
+ elif prev_key_padding_mask is not None:
738
+ if src_len > prev_key_padding_mask.size(1):
739
+ filler = torch.zeros(
740
+ (batch_size, src_len - prev_key_padding_mask.size(1)),
741
+ device=prev_key_padding_mask.device,
742
+ )
743
+ new_key_padding_mask = torch.cat(
744
+ [prev_key_padding_mask.float(), filler.float()], dim=1
745
+ )
746
+ else:
747
+ new_key_padding_mask = prev_key_padding_mask.float()
748
+ elif key_padding_mask is not None:
749
+ if src_len > key_padding_mask.size(1):
750
+ filler = torch.zeros(
751
+ (batch_size, src_len - key_padding_mask.size(1)),
752
+ device=key_padding_mask.device,
753
+ )
754
+ new_key_padding_mask = torch.cat(
755
+ [filler.float(), key_padding_mask.float()], dim=1
756
+ )
757
+ else:
758
+ new_key_padding_mask = key_padding_mask.float()
759
+ else:
760
+ new_key_padding_mask = prev_key_padding_mask
761
+ return new_key_padding_mask
762
+
763
+ def _get_input_buffer(
764
+ self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
765
+ ) -> Dict[str, Optional[Tensor]]:
766
+ result = self.get_incremental_state(incremental_state, "attn_state")
767
+ if result is not None:
768
+ return result
769
+ else:
770
+ empty_result: Dict[str, Optional[Tensor]] = {}
771
+ return empty_result
772
+
773
+ def _set_input_buffer(
774
+ self,
775
+ incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
776
+ buffer: Dict[str, Optional[Tensor]],
777
+ ):
778
+ return self.set_incremental_state(incremental_state, "attn_state", buffer)
779
+
780
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
    """Return ``attn_weights`` unchanged; the size arguments are not used here."""
    return attn_weights
782
+
783
+
784
def init_bert_params(module):
    """
    Initialize the weights specific to the BERT Model.

    Depending on the type of ``module``:
        1. ``nn.Linear``: weight is drawn from N(0, 0.02) and the bias,
           if present, is zeroed.
        2. ``nn.Embedding``: weight is drawn from N(0, 0.02) and the row
           at ``padding_idx``, if set, is zeroed.
        3. ``MultiheadAttention``: the q/k/v projection weights are drawn
           from N(0, 0.02).
    """

    def _reinit_normal(tensor):
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        sampled = tensor.cpu().normal_(mean=0.0, std=0.02)
        tensor.copy_(sampled.to(tensor.device))

    if isinstance(module, nn.Linear):
        _reinit_normal(module.weight.data)
        bias = module.bias
        if bias is not None:
            bias.data.zero_()
    if isinstance(module, nn.Embedding):
        _reinit_normal(module.weight.data)
        pad_idx = module.padding_idx
        if pad_idx is not None:
            module.weight.data[pad_idx].zero_()
    if isinstance(module, MultiheadAttention):
        # keep the q, k, v order so the random stream is consumed identically
        for proj in (module.q_proj, module.k_proj, module.v_proj):
            _reinit_normal(proj.weight.data)
models/beats/modules.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on fairseq code bases
7
+ # https://github.com/pytorch/fairseq
8
+ # --------------------------------------------------------
9
+
10
+ import math
11
+ import warnings
12
+ import torch
13
+ from torch import Tensor, nn
14
+ import torch.nn.functional as F
15
+
16
+
17
class GradMultiply(torch.autograd.Function):
    """Autograd function that passes values through and rescales gradients.

    ``forward`` returns ``x.new(x)`` (a tensor built from ``x``); ``backward``
    multiplies the incoming gradient by the ``scale`` given to ``forward``.
    """

    @staticmethod
    def forward(ctx, x, scale):
        # remember the factor for the backward pass
        ctx.scale = scale
        return x.new(x)

    @staticmethod
    def backward(ctx, grad):
        scaled_grad = grad * ctx.scale
        # second entry: no gradient flows to ``scale``
        return scaled_grad, None
27
+
28
+
29
class SamePad(nn.Module):
    """Trim trailing positions from the last axis of a 3-D tensor.

    The number of trimmed positions is ``kernel_size - 1`` when ``causal``
    is set, otherwise 1 for an even ``kernel_size`` and 0 for an odd one.
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            trim = kernel_size - 1
        else:
            trim = 0 if kernel_size % 2 else 1
        self.remove = trim

    def forward(self, x):
        if self.remove <= 0:
            return x
        return x[:, :, : -self.remove]
41
+
42
+
43
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def __init__(self):
        super().__init__()
        self.act = torch.nn.Sigmoid()

    def forward(self, x):
        gate = self.act(x)
        return x * gate
50
+
51
+
52
class GLU_Linear(nn.Module):
    """Linear projection followed by a gated linear unit.

    The input is projected to ``2 * output_dim`` features; the first half is
    the value and the second half the gate. The output is
    ``value * act(gate)``, where ``act`` depends on ``glu_type``
    ("sigmoid", "swish", "relu", "gelu"), or ``value * gate`` for
    "bilinear".

    Args:
        input_dim: size of the last input dimension.
        output_dim: size of the last output dimension.
        glu_type: gate activation name (see above).
        bias_in_glu: whether the linear projection has a bias.

    Raises:
        ValueError: if ``glu_type`` is not one of the supported names.
    """

    def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
        super(GLU_Linear, self).__init__()

        self.glu_type = glu_type
        self.output_dim = output_dim

        if glu_type == "sigmoid":
            self.glu_act = torch.nn.Sigmoid()
        elif glu_type == "swish":
            self.glu_act = Swish()
        elif glu_type == "relu":
            self.glu_act = torch.nn.ReLU()
        elif glu_type == "gelu":
            self.glu_act = torch.nn.GELU()
        elif glu_type != "bilinear":
            # Fail at construction instead of leaving ``glu_act`` undefined
            # and crashing later inside forward().
            raise ValueError("unsupported glu_type: {!r}".format(glu_type))

        self.linear = nn.Linear(input_dim, output_dim * 2, bool(bias_in_glu))

    def forward(self, x):
        # The channel dimension is assumed to be the last one; slicing with
        # an ellipsis makes this work for inputs of any rank.
        x = self.linear(x)
        value = x[..., : self.output_dim]
        gate = x[..., self.output_dim : self.output_dim * 2]

        if self.glu_type == "bilinear":
            return value * gate
        return value * self.glu_act(gate)
83
+
84
+
85
def gelu_accurate(x):
    """GELU via the tanh formula ``0.5 * x * (1 + tanh(a * (x + 0.044715 * x^3)))``.

    The constant ``a = sqrt(2 / pi)`` is computed once and cached on the
    function object as ``gelu_accurate._a``.
    """
    try:
        coeff = gelu_accurate._a
    except AttributeError:
        coeff = gelu_accurate._a = math.sqrt(2 / math.pi)
    inner = coeff * (x + 0.044715 * torch.pow(x, 3))
    return 0.5 * x * (1 + torch.tanh(inner))
91
+
92
+
93
def gelu(x: torch.Tensor) -> torch.Tensor:
    """Apply GELU in float32 and cast the result back to the type of ``x``."""
    activated = torch.nn.functional.gelu(x.float())
    return activated.type_as(x)
95
+
96
+
97
def get_activation_fn(activation: str):
    """Returns the activation function corresponding to `activation`

    Supported names: "relu", "gelu", "gelu_fast" (deprecated alias that emits
    a warning), "gelu_accurate", "tanh", "linear" and "glu" (the last two
    both map to the identity). Any other name raises RuntimeError.
    """
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return gelu
    if activation in ("gelu_fast", "gelu_accurate"):
        if activation == "gelu_fast":
            warnings.warn(
                "--activation-fn=gelu_fast has been renamed to gelu_accurate"
            )
        return gelu_accurate
    if activation == "tanh":
        return torch.tanh
    if activation in ("linear", "glu"):
        return lambda x: x
    raise RuntimeError("--activation-fn {} not supported".format(activation))
119
+
120
+
121
def quant_noise(module, p, block_size):
    """
    Wraps modules and applies quantization noise to the weights for
    subsequent quantization with Iterative Product Quantization as
    described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Module
        - p: amount of Quantization Noise
        - block_size: size of the blocks for subsequent quantization with iPQ

    Returns:
        The same ``module``. When ``p > 0`` a forward pre-hook is registered
        that, in training mode only, zeroes randomly chosen weight blocks and
        rescales the weight by ``1 / (1 - p)``.

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights,
          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper
          which consists in randomly dropping blocks
    """

    # if no quantization noise, don't register hook
    if p <= 0:
        return module

    # supported modules
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # test whether module.weight has the right sizes wrt block_size
    is_conv = module.weight.ndim == 4

    # 2D matrix
    if not is_conv:
        assert (
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"

    # 4D matrix
    else:
        # 1x1 convolutions
        if module.kernel_size == (1, 1):
            assert (
                module.in_channels % block_size == 0
            ), "Input channels must be a multiple of block sizes"
        # regular convolutions
        else:
            k = module.kernel_size[0] * module.kernel_size[1]
            assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, input):
        # no noise for evaluation
        if mod.training:
            if not is_conv:
                # gather weight and sizes
                weight = mod.weight
                in_features = weight.size(1)
                out_features = weight.size(0)

                # split weight matrix into blocks and randomly drop selected blocks
                mask = torch.zeros(
                    in_features // block_size * out_features, device=weight.device
                )
                mask.bernoulli_(p)
                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)

            else:
                # gather weight and sizes
                weight = mod.weight
                in_channels = mod.in_channels
                out_channels = mod.out_channels

                # split weight matrix into blocks and randomly drop selected blocks
                if mod.kernel_size == (1, 1):
                    mask = torch.zeros(
                        int(in_channels // block_size * out_channels),
                        device=weight.device,
                    )
                    mask.bernoulli_(p)
                    # Give the mask the two trailing singleton kernel axes so
                    # it lines up with the (out, in, 1, 1) weight; a 2-D
                    # (out, in) mask would be aligned against the wrong axes.
                    mask = mask.repeat_interleave(block_size, -1).view(
                        -1, in_channels, 1, 1
                    )
                else:
                    mask = torch.zeros(
                        weight.size(0), weight.size(1), device=weight.device
                    )
                    mask.bernoulli_(p)
                    mask = (
                        mask.unsqueeze(2)
                        .unsqueeze(3)
                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
                    )

            # scale weights and apply mask
            mask = mask.to(
                torch.bool
            )  # x.bool() is not currently supported in TorchScript
            s = 1 / (1 - p)
            mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
models/beats/quantizer.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
3
+ # Github source: https://github.com/microsoft/unilm/tree/master/beats
4
+ # Copyright (c) 2022 Microsoft
5
+ # Licensed under The MIT License [see LICENSE for details]
6
+ # Based on VQGAN code bases
7
+ # https://github.com/CompVis/taming-transformers
8
+ # --------------------------------------------------------
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.distributed as distributed
14
+
15
+ try:
16
+ from einops import rearrange, repeat
17
+ except ImportError:
18
+ pass
19
+
20
+
21
def l2norm(t):
    """Scale ``t`` to unit L2 norm along its last dimension."""
    unit = F.normalize(t, dim=-1, p=2)
    return unit
23
+
24
+
25
def ema_inplace(moving_avg, new, decay):
    """Update in place: ``moving_avg <- decay * moving_avg + (1 - decay) * new``."""
    buf = moving_avg.data
    buf.mul_(decay)
    buf.add_(new, alpha=(1 - decay))
27
+
28
+
29
def sample_vectors(samples, num):
    """Pick ``num`` rows of ``samples`` at random.

    Rows are drawn without replacement (random permutation) when at least
    ``num`` rows are available, and with replacement otherwise.
    """
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
    """Run k-means on the rows of ``samples``.

    Args:
        samples: 2-D tensor, one sample per row.
        num_clusters: number of centroids.
        num_iters: number of assignment/update rounds.
        use_cosine_sim: assign by largest dot product and L2-normalize the
            centroids after every update; otherwise assign by smallest
            squared Euclidean distance.

    Returns:
        ``(means, bins)``: the centroids and the number of samples assigned
        to each centroid in the last round (all zeros when ``num_iters`` is 0).
    """
    dim, dtype = samples.shape[-1], samples.dtype

    means = sample_vectors(samples, num_clusters)
    # defined up front so the return value exists even with zero iterations
    bins = torch.zeros(num_clusters, dtype=torch.long, device=samples.device)

    for _ in range(num_iters):
        if use_cosine_sim:
            dists = samples @ means.t()
        else:
            # Broadcast (n, 1, d) against (1, c, d); plain torch ops are used
            # so this no longer depends on the optional einops import.
            diffs = samples.unsqueeze(1) - means.unsqueeze(0)
            dists = -(diffs ** 2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # avoid dividing by zero for clusters that received no samples
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, buckets.unsqueeze(-1).expand(-1, dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        if use_cosine_sim:
            new_means = l2norm(new_means)

        # empty clusters keep their previous centroid
        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins
68
+
69
+
70
class EmbeddingEMA(nn.Module):
    """Codebook whose entries are maintained by exponential moving averages.

    ``weight`` holds the codebook, ``cluster_size`` and ``embed_avg`` hold the
    running statistics; none of them require gradients. The ``initted``
    buffer records whether the codebook has been initialized.

    The initial codebook is, in order of precedence: loaded from
    ``codebook_init_path`` if given; all zeros (to be filled by
    ``init_embed_``) if ``kmeans_init``; otherwise L2-normalized random
    normal vectors.
    """

    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''):
        super().__init__()
        self.num_tokens = num_tokens
        self.codebook_dim = codebook_dim
        self.decay = decay
        self.eps = eps

        if codebook_init_path != '':
            print(f"load init codebook weight from {codebook_init_path}")
            # NOTE: torch.load unpickles the file; only use trusted checkpoints.
            codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu')
            weight = codebook_ckpt_weight.clone()
            already_initted = True
        elif kmeans_init:
            weight = torch.zeros(num_tokens, codebook_dim)
            already_initted = False
        else:
            weight = l2norm(torch.randn(num_tokens, codebook_dim))
            already_initted = True
        self.register_buffer('initted', torch.Tensor([already_initted]))

        self.weight = nn.Parameter(weight, requires_grad=False)
        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
        self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
        self.update = True

    @torch.jit.ignore
    def init_embed_(self, data):
        """Initialize the codebook from ``data`` with k-means, once."""
        if self.initted:
            return
        print("Performing Kemans init for codebook")
        centroids, counts = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
        self.weight.data.copy_(centroids)
        self.cluster_size.data.copy_(counts)
        self.initted.data.copy_(torch.Tensor([True]))

    def forward(self, embed_id):
        """Look up codebook rows for the given indices."""
        return F.embedding(embed_id, self.weight)

    def cluster_size_ema_update(self, new_cluster_size):
        """EMA-update ``cluster_size`` in place with ``new_cluster_size``."""
        running = self.cluster_size.data
        running.mul_(self.decay)
        running.add_(new_cluster_size, alpha=1 - self.decay)

    def embed_avg_ema_update(self, new_embed_avg):
        """EMA-update ``embed_avg`` in place with ``new_embed_avg``."""
        running = self.embed_avg.data
        running.mul_(self.decay)
        running.add_(new_embed_avg, alpha=1 - self.decay)

    def weight_update(self, num_tokens):
        """Set ``weight`` to ``embed_avg`` divided by the smoothed cluster sizes."""
        total = self.cluster_size.sum()
        smoothed_cluster_size = (
            (self.cluster_size + self.eps) / (total + num_tokens * self.eps) * total
        )
        # normalize embedding average with smoothed cluster size
        self.weight.data.copy_(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
124
+
125
+
126
def norm_ema_inplace(moving_avg, new, decay):
    """EMA-update ``moving_avg`` in place, then L2-normalize it in place."""
    buf = moving_avg.data
    buf.mul_(decay).add_(new, alpha=(1 - decay))
    buf.copy_(l2norm(buf))
129
+
130
+
131
class NormEMAVectorQuantizer(nn.Module):
    """Vector quantizer with an L2-normalized, EMA-updated codebook.

    Args:
        n_embed: number of codebook entries.
        embedding_dim: dimension of each entry.
        beta: weight of the loss returned by ``forward``.
        decay: EMA decay used for the codebook and the usage statistics.
        eps: passed through to ``EmbeddingEMA``.
        statistic_code_usage: keep a ``cluster_size`` buffer with an EMA of
            how often each code is selected.
        kmeans_init: passed through to ``EmbeddingEMA``.
        codebook_init_path: passed through to ``EmbeddingEMA``.
    """

    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
                 statistic_code_usage=True, kmeans_init=False, codebook_init_path=''):
        super().__init__()
        self.codebook_dim = embedding_dim
        self.num_tokens = n_embed
        self.beta = beta
        self.decay = decay

        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path)

        self.statistic_code_usage = statistic_code_usage
        if statistic_code_usage:
            self.register_buffer('cluster_size', torch.zeros(n_embed))
        if distributed.is_available() and distributed.is_initialized():
            print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!")
            self.all_reduce_fn = distributed.all_reduce
        else:
            self.all_reduce_fn = nn.Identity()

    def reset_cluster_size(self, device):
        """Reset the usage statistics to zeros on ``device`` (if enabled)."""
        if self.statistic_code_usage:
            self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
            self.cluster_size = self.cluster_size.to(device)

    def forward(self, z):
        """Quantize ``z`` to its nearest codebook entries.

        ``z`` is L2-normalized along its last dimension, which must equal
        ``codebook_dim``, and each vector is mapped to the codebook entry with
        the smallest squared Euclidean distance. In training mode (and when
        ``self.embedding.update`` is set) the codebook is EMA-updated.

        Returns:
            ``(z_q, loss, encoding_indices)``: the quantized tensor with a
            straight-through gradient to ``z``, ``beta * mse(z_q.detach(), z)``,
            and the flat indices of the selected codes.
        """
        z = l2norm(z)
        z_flattened = z.reshape(-1, self.codebook_dim)

        self.embedding.init_embed_(z_flattened)

        # squared distance expanded as |z|^2 + |e|^2 - 2 z.e
        d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
            self.embedding.weight.pow(2).sum(dim=1) - 2 * \
            torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)  # 'n d -> d n'

        encoding_indices = torch.argmin(d, dim=1)

        z_q = self.embedding(encoding_indices).view(z.shape)

        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)

        # The ``cluster_size`` buffer only exists when statistic_code_usage is
        # enabled, so the usage statistics are only touched in that case.
        if not self.training and self.statistic_code_usage:
            with torch.no_grad():
                cluster_size = encodings.sum(0)
                self.all_reduce_fn(cluster_size)
                ema_inplace(self.cluster_size, cluster_size, self.decay)

        if self.training and self.embedding.update:
            # EMA cluster size

            bins = encodings.sum(0)
            self.all_reduce_fn(bins)

            if self.statistic_code_usage:
                ema_inplace(self.cluster_size, bins, self.decay)

            zero_mask = (bins == 0)
            bins = bins.masked_fill(zero_mask, 1.)

            embed_sum = z_flattened.t() @ encodings
            self.all_reduce_fn(embed_sum)

            embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
            embed_normalized = l2norm(embed_normalized)

            # codes that were not selected keep their current embedding
            embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight,
                                           embed_normalized)
            norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)

        # compute loss for embedding
        loss = self.beta * F.mse_loss(z_q.detach(), z)

        # preserve gradients
        z_q = z + (z_q - z).detach()

        return z_q, loss, encoding_indices
models/qformer/LICENSE_Lavis ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2022 Salesforce, Inc.
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9
+
10
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
11
+
12
+ 3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
models/qformer/LICENSE_MiniGPT4 ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright 2023 Deyao Zhu
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
9
+
10
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
11
+
12
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
13
+
14
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
models/qformer/LICENSE_VideoLlama ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2023, Multilingual NLP Team at Alibaba DAMO Academy
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
models/qformer/Qformer.py ADDED
@@ -0,0 +1,1217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from salesforce@LAVIS. Below is the original copyright:
3
+ * Copyright (c) 2023, salesforce.com, inc.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7
+ * By Junnan Li
8
+ * Based on huggingface code base
9
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
10
+ """
11
+
12
+ import math
13
+ import os
14
+ import warnings
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple, Dict, Any
17
+
18
+ import torch
19
+ from torch import Tensor, device, dtype, nn
20
+ import torch.utils.checkpoint
21
+ from torch import nn
22
+ from torch.nn import CrossEntropyLoss
23
+ import torch.nn.functional as F
24
+
25
+ from transformers.activations import ACT2FN
26
+ from transformers.file_utils import (
27
+ ModelOutput,
28
+ )
29
+ from transformers.modeling_outputs import (
30
+ BaseModelOutputWithPastAndCrossAttentions,
31
+ BaseModelOutputWithPoolingAndCrossAttentions,
32
+ CausalLMOutputWithCrossAttentions,
33
+ MaskedLMOutput,
34
+ MultipleChoiceModelOutput,
35
+ NextSentencePredictorOutput,
36
+ QuestionAnsweringModelOutput,
37
+ SequenceClassifierOutput,
38
+ TokenClassifierOutput,
39
+ )
40
+ from transformers.modeling_utils import (
41
+ PreTrainedModel,
42
+ apply_chunking_to_forward,
43
+ find_pruneable_heads_and_indices,
44
+ prune_linear_layer,
45
+ )
46
+ from transformers.utils import logging
47
+ from transformers.models.bert.configuration_bert import BertConfig
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
class BertEmbeddings(nn.Module):
    """Build the encoder input from token ids and/or pre-computed query embeddings.

    Token ids are looked up in ``word_embeddings`` and, when the position
    embedding type is ``"absolute"``, summed with learned position embeddings.
    If ``query_embeds`` is supplied it is placed in front of the token
    embeddings along dimension 1. The result is layer-normalised and passed
    through dropout.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        max_positions = config.max_position_embeddings

        self.word_embeddings = nn.Embedding(
            config.vocab_size, hidden, padding_idx=config.pad_token_id
        )
        self.position_embeddings = nn.Embedding(max_positions, hidden)

        # The attribute is deliberately named ``LayerNorm`` (not snake_case) so
        # that parameter names line up with TensorFlow checkpoint variables.
        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Buffer of shape (1, max_position_embeddings) holding 0..N-1; as a
        # registered buffer it is stored alongside the module's state.
        all_positions = torch.arange(max_positions).expand((1, -1))
        self.register_buffer("position_ids", all_positions)

        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        self.config = config

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        query_embeds=None,
        past_key_values_length=0,
    ):
        """Return the normalised, dropped-out embeddings.

        Args:
            input_ids: optional token ids; dimension 1 is taken as the
                sequence length.
            position_ids: optional explicit position indices. When omitted,
                a slice of the ``position_ids`` buffer starting at
                ``past_key_values_length`` is used.
            query_embeds: optional embeddings that are prepended to the token
                embeddings, or used on their own when ``input_ids`` is None.
            past_key_values_length: offset applied to the default positions.
        """
        seq_length = 0 if input_ids is None else input_ids.size()[1]

        if position_ids is None:
            start = past_key_values_length
            position_ids = self.position_ids[:, start : start + seq_length].clone()

        if input_ids is None:
            # Query-only input: no token or position lookup takes place.
            embeddings = query_embeds
        else:
            embeddings = self.word_embeddings(input_ids)
            if self.position_embedding_type == "absolute":
                embeddings = embeddings + self.position_embeddings(position_ids)
            if query_embeds is not None:
                embeddings = torch.cat((query_embeds, embeddings), dim=1)

        return self.dropout(self.LayerNorm(embeddings))
110
+
111
+
112
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention for self- and cross-attention.

    When built with ``is_cross_attention=True`` the key/value projections take
    inputs of width ``config.encoder_width``; otherwise they take
    ``config.hidden_size``. Cross-attention probabilities (and, when autograd
    is active, their gradients) can be recorded by setting
    ``save_attention = True`` on the instance.
    """

    def __init__(self, config, is_cross_attention):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
            config, "embedding_size"
        ):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        # Keys/values are projected from the encoder features in cross-attention.
        kv_in_features = (
            config.encoder_width if is_cross_attention else config.hidden_size
        )
        self.key = nn.Linear(kv_in_features, self.all_head_size)
        self.value = nn.Linear(kv_in_features, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * config.max_position_embeddings - 1, self.attention_head_size
            )
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        """Store ``attn_gradients``; used as the backward hook on the attention map."""
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        """Return the gradients stored by :meth:`save_attn_gradients`."""
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        """Store the given attention probabilities on the module."""
        self.attention_map = attention_map

    def get_attention_map(self):
        """Return the attention probabilities stored by :meth:`save_attention_map`."""
        return self.attention_map

    def transpose_for_scores(self, x):
        """Split the last dimension into heads and move the head axis to position 1.

        A ``(batch, seq, all_head_size)`` input becomes
        ``(batch, num_heads, seq, head_size)``.
        """
        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Run attention and return ``(context[, attention_probs], (key, value))``.

        Args:
            hidden_states: query-side input.
            attention_mask: additive mask applied to the scores; replaced by
                ``encoder_attention_mask`` when cross-attending.
            head_mask: optional multiplicative mask applied to the (dropped-out)
                attention probabilities.
            encoder_hidden_states: when given, keys and values are computed from
                it and the call is treated as cross-attention.
            encoder_attention_mask: additive mask used for cross-attention.
            past_key_value: optional ``(key, value)`` pair that is concatenated
                in front of the freshly computed keys/values (self-attention only).
            output_attentions: include the attention probabilities in the result.

        Returns:
            A tuple whose first element is the context tensor, followed by the
            attention probabilities if requested, and finally the
            ``(key_layer, value_layer)`` pair used for this call.
        """
        # Keys/values come from the encoder when encoder_hidden_states is given;
        # in that case the encoder's mask is the one that must be applied.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            if past_key_value is not None:
                # Extend the cached keys/values along the sequence axis.
                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

        query_layer = self.transpose_for_scores(self.query(hidden_states))

        past_key_value = (key_layer, value_layer)

        # Raw attention scores: dot product between queries and keys.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type in ("relative_key", "relative_key_query"):
            # NOTE(review): distances are derived from the query length only, so
            # this presumably expects key length == query length — confirm before
            # combining relative embeddings with cross-attention or a cache.
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(-1, 1)
            position_ids_r = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            ).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(
                distance + self.max_position_embeddings - 1
            )
            positional_embedding = positional_embedding.to(
                dtype=query_layer.dtype
            )  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum(
                    "bhld,lrd->bhlr", query_layer, positional_embedding
                )
                relative_position_scores_key = torch.einsum(
                    "bhrd,lrd->bhlr", key_layer, positional_embedding
                )
                attention_scores = (
                    attention_scores
                    + relative_position_scores_query
                    + relative_position_scores_key
                )

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # The mask is additive (added to the scores before the softmax).
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = torch.softmax(attention_scores, dim=-1)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            # A hook can only be attached to a tensor that requires grad; under
            # torch.no_grad() (typical at inference) register_hook would raise,
            # so only the attention map is recorded in that case.
            if attention_probs.requires_grad:
                attention_probs.register_hook(self.save_attn_gradients)

        # Dropout on the probabilities drops whole tokens to attend to, as in
        # the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # Merge the heads back: (batch, heads, seq, head) -> (batch, seq, all_head).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (
            (context_layer, attention_probs) if output_attentions else (context_layer,)
        )

        outputs = outputs + (past_key_value,)
        return outputs
277
+
278
+
279
class BertSelfOutput(nn.Module):
    """Output projection of an attention block: dense, dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states`` and normalise its sum with ``input_tensor``.

        ``input_tensor`` is the residual branch: it is added to the projected,
        dropped-out states before layer normalisation.
        """
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
291
+
292
+
293
class BertAttention(nn.Module):
    """Attention sub-layer: a ``BertSelfAttention`` followed by ``BertSelfOutput``.

    The same wrapper is used for self-attention and, with
    ``is_cross_attention=True``, for cross-attention.
    """

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.self = BertSelfAttention(config, is_cross_attention)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the projections of this layer.

        Does nothing when ``heads`` is empty. Otherwise the query/key/value
        and output projections are pruned and the head bookkeeping on the
        inner attention module is updated.
        """
        if len(heads) == 0:
            return

        inner = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads,
            inner.num_attention_heads,
            inner.attention_head_size,
            self.pruned_heads,
        )

        # Prune the three input projections, then the output projection
        # (the latter along its input dimension).
        inner.query = prune_linear_layer(inner.query, index)
        inner.key = prune_linear_layer(inner.key, index)
        inner.value = prune_linear_layer(inner.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the derived sizes in sync and remember what has been pruned.
        inner.num_attention_heads = inner.num_attention_heads - len(heads)
        inner.all_head_size = inner.attention_head_size * inner.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Apply attention and the output projection.

        Returns a tuple ``(attention_output, *rest)`` where ``rest`` is
        everything the inner attention module returned after its context
        tensor (attention probabilities if requested, then the key/value pair).
        """
        context, *rest = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # hidden_states doubles as the residual input of the output block.
        projected = self.output(context, hidden_states)
        return (projected, *rest)
348
+
349
+
350
class BertIntermediate(nn.Module):
    """First half of the feed-forward block: expand to ``intermediate_size`` and activate."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # ``hidden_act`` may be the name of an activation (looked up in ACT2FN)
        # or the activation callable itself.
        activation = config.hidden_act
        self.intermediate_act_fn = (
            ACT2FN[activation] if isinstance(activation, str) else activation
        )

    def forward(self, hidden_states):
        """Return ``act(dense(hidden_states))``."""
        return self.intermediate_act_fn(self.dense(hidden_states))
363
+
364
+
365
class BertOutput(nn.Module):
    """Second half of the feed-forward block: project back, dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states`` to the hidden width and normalise with the residual.

        ``input_tensor`` is added to the projected, dropped-out states before
        layer normalisation.
        """
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
377
+
378
+
379
class BertLayer(nn.Module):
    """One Q-Former transformer layer.

    Every layer has self-attention and two feed-forward branches: one for the
    leading ``query_length`` positions (``intermediate_query`` /
    ``output_query``) and one for the remaining positions (``intermediate`` /
    ``output``). A cross-attention sub-layer is added when the config enables
    cross-attention and ``layer_num`` is a multiple of
    ``config.cross_attention_freq``.
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.layer_num = layer_num

        self.has_cross_attention = bool(
            config.add_cross_attention
            and layer_num % config.cross_attention_freq == 0
        )
        if self.has_cross_attention:
            self.crossattention = BertAttention(
                config, is_cross_attention=config.add_cross_attention
            )

        # Feed-forward branch for the non-query positions.
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        # Separate feed-forward branch for the query positions.
        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        """Run the layer.

        Args:
            hidden_states: layer input; the first ``query_length`` positions
                along dimension 1 are treated as queries.
            attention_mask / head_mask: forwarded to the attention modules.
            encoder_hidden_states / encoder_attention_mask: inputs of the
                cross-attention sub-layer; ``encoder_hidden_states`` is
                required when cross-attention runs.
            past_key_value: optional cache; its first two entries are passed
                to self-attention.
            output_attentions: forwarded to the attention modules.
            query_length: number of leading query positions (0 disables the
                query branch and cross-attention).

        Returns:
            ``(layer_output, *attention_outputs, present_key_value)``.
        """
        cached_self_kv = None if past_key_value is None else past_key_value[:2]
        self_attn = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=cached_self_kv,
        )
        attention_output = self_attn[0]
        present_key_value = self_attn[-1]
        # Whatever sits between the context and the cache (attention probs, if any).
        outputs = self_attn[1:-1]

        if query_length > 0:
            query_states = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                assert (
                    encoder_hidden_states is not None
                ), "encoder_hidden_states must be given for cross-attention layers"
                cross_attn = self.crossattention(
                    query_states,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_states = cross_attn[0]
                # Append cross-attention probabilities when they were produced.
                outputs = outputs + cross_attn[1:-1]

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_states,
            )

            if attention_output.shape[1] > query_length:
                # Positions after the queries go through the other feed-forward
                # branch and are concatenated back behind the query outputs.
                trailing_output = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, trailing_output], dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )

        return (layer_output,) + outputs + (present_key_value,)

    def feed_forward_chunk(self, attention_output):
        """Feed-forward branch for non-query positions."""
        expanded = self.intermediate(attention_output)
        return self.output(expanded, attention_output)

    def feed_forward_chunk_query(self, attention_output):
        """Feed-forward branch for query positions."""
        expanded = self.intermediate_query(attention_output)
        return self.output_query(expanded, attention_output)
486
+
487
+
488
class BertEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` :class:`BertLayer` modules."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [BertLayer(config, i) for i in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Run all layers in sequence.

        Args:
            hidden_states: input to the first layer.
            attention_mask, encoder_hidden_states, encoder_attention_mask:
                forwarded unchanged to every layer.
            head_mask: optional per-layer head masks, indexed by layer.
            past_key_values: optional per-layer caches, indexed by layer.
            use_cache: collect each layer's last output (its key/value pair).
                Forced to ``False`` when gradient checkpointing is active.
            output_attentions: collect self-attention probabilities and, when
                ``config.add_cross_attention`` is set, cross-attention
                probabilities of the layers that actually produced them.
            output_hidden_states: collect the input of every layer plus the
                final output.
            return_dict: return a ``BaseModelOutputWithPastAndCrossAttentions``
                instead of a tuple of the non-``None`` results.
            query_length: forwarded to every layer.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = (
            () if output_attentions and self.config.add_cross_attention else None
        )

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:

                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module, layer_past):
                    # layer_past is bound per layer: checkpointing calls this
                    # closure again during backward, after the loop variable
                    # has moved on to later layers.
                    def custom_forward(*inputs):
                        return module(
                            *inputs, layer_past, output_attentions, query_length
                        )

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module, past_key_value),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if all_cross_attentions is not None:
                    # A layer returns (output, self_probs[, cross_probs], present).
                    # Only the entries between the self-attention probabilities
                    # and the trailing key/value pair are cross-attention maps;
                    # layers that skipped cross-attention contribute nothing.
                    all_cross_attentions = all_cross_attentions + tuple(
                        layer_outputs[2:-1]
                    )

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
591
+
592
+
593
class BertPooler(nn.Module):
    """Pool a sequence by transforming the hidden state at position 0."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return ``tanh(dense(hidden_states[:, 0]))``.

        "Pooling" here simply means taking the hidden state of the first
        position along dimension 1.
        """
        leading_state = hidden_states[:, 0]
        return self.activation(self.dense(leading_state))
606
+
607
+
608
class BertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform applied before the LM decoder."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        # ``hidden_act`` is either an activation name (resolved through ACT2FN)
        # or the activation callable itself.
        activation = config.hidden_act
        self.transform_act_fn = (
            ACT2FN[activation] if isinstance(activation, str) else activation
        )
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        """Return ``LayerNorm(act(dense(hidden_states)))``."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
623
+
624
+
625
class BertLMPredictionHead(nn.Module):
    """Language-model head: hidden-state transform followed by a vocabulary projection."""

    def __init__(self, config):
        super().__init__()
        vocab = config.vocab_size
        self.transform = BertPredictionHeadTransform(config)

        # The decoder is created without its own bias; a separate per-token
        # bias parameter is attached right below.
        self.decoder = nn.Linear(config.hidden_size, vocab, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab))

        # Share the parameter object between ``self.bias`` and the decoder so
        # that the bias is correctly resized with `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return vocabulary scores ``decoder(transform(hidden_states))``."""
        return self.decoder(self.transform(hidden_states))
643
+
644
+
645
class BertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing a ``BertLMPredictionHead`` under the ``predictions`` attribute."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return the prediction scores for ``sequence_output``."""
        return self.predictions(sequence_output)
653
+
654
+
655
class BertPreTrainedModel(PreTrainedModel):
    """
    Abstract base class providing weight initialization and the common interface for downloading and loading
    pretrained models.
    """

    # Class-level settings; presumably consumed by ``PreTrainedModel`` when
    # building and loading models — see the transformers documentation.
    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights of a single sub-module.

        LayerNorm gets weight 1 and bias 0. Linear and Embedding weights are
        drawn from a normal distribution with standard deviation
        ``config.initializer_range``; Linear biases, when present, are zeroed.
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, (nn.Linear, nn.Embedding)):
            # Plain normal init; the TF version uses truncated_normal instead
            # (cf https://github.com/pytorch/pytorch/pull/5617).
            std = self.config.initializer_range
            module.weight.data.normal_(mean=0.0, std=std)

        has_linear_bias = isinstance(module, nn.Linear) and module.bias is not None
        if has_linear_bias:
            module.bias.data.zero_()
676
+
677
+
678
+ class BertModel(BertPreTrainedModel):
679
+ """
680
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
681
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
682
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
683
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
684
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
685
+ input to the forward pass.
686
+ """
687
+
688
+ def __init__(self, config, add_pooling_layer=False):
689
+ super().__init__(config)
690
+ self.config = config
691
+
692
+ self.embeddings = BertEmbeddings(config)
693
+
694
+ self.encoder = BertEncoder(config)
695
+
696
+ self.pooler = BertPooler(config) if add_pooling_layer else None
697
+
698
+ self.init_weights()
699
+
700
+ def get_input_embeddings(self):
701
+ return self.embeddings.word_embeddings
702
+
703
+ def set_input_embeddings(self, value):
704
+ self.embeddings.word_embeddings = value
705
+
706
+ def _prune_heads(self, heads_to_prune):
707
+ """
708
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
709
+ class PreTrainedModel
710
+ """
711
+ for layer, heads in heads_to_prune.items():
712
+ self.encoder.layer[layer].attention.prune_heads(heads)
713
+
714
+ def get_extended_attention_mask(
715
+ self,
716
+ attention_mask: Tensor,
717
+ input_shape: Tuple[int],
718
+ device: device,
719
+ is_decoder: bool,
720
+ has_query: bool = False,
721
+ ) -> Tensor:
722
+ """
723
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
724
+
725
+ Arguments:
726
+ attention_mask (:obj:`torch.Tensor`):
727
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
728
+ input_shape (:obj:`Tuple[int]`):
729
+ The shape of the input to the model.
730
+ device: (:obj:`torch.device`):
731
+ The device of the input to the model.
732
+
733
+ Returns:
734
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
735
+ """
736
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
737
+ # ourselves in which case we just need to make it broadcastable to all heads.
738
+ if attention_mask.dim() == 3:
739
+ extended_attention_mask = attention_mask[:, None, :, :]
740
+ elif attention_mask.dim() == 2:
741
+ # Provided a padding mask of dimensions [batch_size, seq_length]
742
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
743
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
744
+ if is_decoder:
745
+ batch_size, seq_length = input_shape
746
+
747
+ seq_ids = torch.arange(seq_length, device=device)
748
+ causal_mask = (
749
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
750
+ <= seq_ids[None, :, None]
751
+ )
752
+
753
+ # add a prefix ones mask to the causal mask
754
+ # causal and attention masks must have same type with pytorch version < 1.3
755
+ causal_mask = causal_mask.to(attention_mask.dtype)
756
+
757
+ if causal_mask.shape[1] < attention_mask.shape[1]:
758
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
759
+ if has_query: # UniLM style attention mask
760
+ causal_mask = torch.cat(
761
+ [
762
+ torch.zeros(
763
+ (batch_size, prefix_seq_len, seq_length),
764
+ device=device,
765
+ dtype=causal_mask.dtype,
766
+ ),
767
+ causal_mask,
768
+ ],
769
+ axis=1,
770
+ )
771
+ causal_mask = torch.cat(
772
+ [
773
+ torch.ones(
774
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
775
+ device=device,
776
+ dtype=causal_mask.dtype,
777
+ ),
778
+ causal_mask,
779
+ ],
780
+ axis=-1,
781
+ )
782
+ extended_attention_mask = (
783
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
784
+ )
785
+ else:
786
+ extended_attention_mask = attention_mask[:, None, None, :]
787
+ else:
788
+ raise ValueError(
789
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
790
+ input_shape, attention_mask.shape
791
+ )
792
+ )
793
+
794
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
795
+ # masked positions, this operation will create a tensor which is 0.0 for
796
+ # positions we want to attend and -10000.0 for masked positions.
797
+ # Since we are adding it to the raw scores before the softmax, this is
798
+ # effectively the same as removing these entirely.
799
+ extended_attention_mask = extended_attention_mask.to(
800
+ dtype=self.dtype
801
+ ) # fp16 compatibility
802
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
803
+ return extended_attention_mask
804
+
805
def forward(
    self,
    input_ids=None,
    attention_mask=None,
    position_ids=None,
    head_mask=None,
    query_embeds=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_values=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    is_decoder=False,
):
    r"""
    Embed the inputs, build the attention masks and run the encoder stack.

    At least one of :obj:`input_ids` and :obj:`query_embeds` has to be given; both are forwarded to
    ``self.embeddings``. When :obj:`is_decoder` is true :obj:`input_ids` is mandatory, because its shape
    is what the self-attention mask is built from.

    encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
        Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
        the model is configured as a decoder. A list of such tensors is also accepted; the first entry is then
        used to size the default cross-attention mask.
    encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
        Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
        the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
        - 1 for tokens that are **not masked**,
        - 0 for tokens that are **masked**.
        A list of masks is inverted element by element. When omitted, an all-ones mask is used.
    past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
        Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
        If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
        (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
        instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
    use_cache (:obj:`bool`, `optional`):
        If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
        decoding (see :obj:`past_key_values`).

    Returns:
        ``(sequence_output, pooled_output) + encoder_outputs[1:]`` when :obj:`return_dict` is false, otherwise a
        :obj:`BaseModelOutputWithPoolingAndCrossAttentions`. ``pooled_output`` is ``None`` when the model has no
        pooler.

    Raises:
        ValueError: if neither :obj:`input_ids` nor :obj:`query_embeds` is given, or if :obj:`is_decoder` is
            true while :obj:`input_ids` is ``None``.
    """
    # Fall back to the configuration for every output switch the caller left unset.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    # use_cache = use_cache if use_cache is not None else self.config.use_cache

    # Explicit exceptions instead of `assert`, which disappears under `python -O`.
    if input_ids is None and query_embeds is None:
        raise ValueError("You have to specify query_embeds when input_ids is None")
    if is_decoder and input_ids is None:
        # The decoder mask below is sized from input_ids.shape; fail early with a
        # readable message instead of an AttributeError on None.
        raise ValueError("You have to specify input_ids when is_decoder is True")

    # The cached keys also hold the query positions, so those are subtracted to get
    # the number of already-processed text positions.
    if past_key_values is not None:
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length
        )
    else:
        past_key_values_length = 0

    query_length = query_embeds.shape[1] if query_embeds is not None else 0

    embedding_output = self.embeddings(
        input_ids=input_ids,
        position_ids=position_ids,
        query_embeds=query_embeds,
        past_key_values_length=past_key_values_length,
    )

    input_shape = embedding_output.size()[:-1]
    batch_size, seq_length = input_shape
    device = embedding_output.device

    if attention_mask is None:
        attention_mask = torch.ones(
            (batch_size, seq_length + past_key_values_length), device=device
        )

    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
    # ourselves in which case we just need to make it broadcastable to all heads.
    if is_decoder:
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask,
            input_ids.shape,
            device,
            is_decoder,
            has_query=(query_embeds is not None),
        )
    else:
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, input_shape, device, is_decoder
        )

    # If a 2D or 3D attention mask is provided for the cross-attention
    # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
    if encoder_hidden_states is not None:
        # isinstance (rather than an exact type comparison) also accepts list subclasses.
        if isinstance(encoder_hidden_states, list):
            shape_reference = encoder_hidden_states[0]
        else:
            shape_reference = encoder_hidden_states
        encoder_batch_size, encoder_sequence_length, _ = shape_reference.size()

        if isinstance(encoder_attention_mask, list):
            encoder_extended_attention_mask = [
                self.invert_attention_mask(mask) for mask in encoder_attention_mask
            ]
        else:
            if encoder_attention_mask is None:
                # No mask supplied: attend to every encoder position.
                encoder_attention_mask = torch.ones(
                    (encoder_batch_size, encoder_sequence_length), device=device
                )
            encoder_extended_attention_mask = self.invert_attention_mask(
                encoder_attention_mask
            )
    else:
        encoder_extended_attention_mask = None

    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape bsz x n_heads x N x N
    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
    # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

    encoder_outputs = self.encoder(
        embedding_output,
        attention_mask=extended_attention_mask,
        head_mask=head_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_extended_attention_mask,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        query_length=query_length,
    )
    sequence_output = encoder_outputs[0]
    pooled_output = (
        self.pooler(sequence_output) if self.pooler is not None else None
    )

    if not return_dict:
        return (sequence_output, pooled_output) + encoder_outputs[1:]

    return BaseModelOutputWithPoolingAndCrossAttentions(
        last_hidden_state=sequence_output,
        pooler_output=pooled_output,
        past_key_values=encoder_outputs.past_key_values,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
        cross_attentions=encoder_outputs.cross_attentions,
    )
967
+
968
+
969
class BertLMHeadModel(BertPreTrainedModel):
    """BERT backbone (without pooler) topped with an MLM head, used for next-token prediction."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the decoder layer of the prediction head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the decoder layer of the prediction head."""
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=True,
        reduction="mean",
    ):
        r"""
        Run the backbone and the LM head, optionally computing the next-token loss.

        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Hidden states of an encoder, consumed by the cross-attention layers.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Cross-attention padding mask with values in ``[0, 1]``: 1 marks tokens that are **not masked**,
            0 marks tokens that are **masked**.
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Targets of the left-to-right language-modeling loss. Indices should be in
            ``[-100, 0, ..., config.vocab_size]``; positions set to ``-100`` are ignored. Passing labels
            switches :obj:`use_cache` off.
        past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`):
            Precomputed key/value states used to speed up decoding. When given, :obj:`query_embeds` is
            dropped before the backbone is called.
        use_cache (:obj:`bool`, `optional`):
            Forwarded to the backbone so that key/value states are returned for later decoding steps.
        return_logits (:obj:`bool`):
            When true, return only the prediction scores with the last position removed.
        reduction (:obj:`str`):
            Reduction handed to :obj:`CrossEntropyLoss`; with ``"none"`` the per-token losses are summed
            per sample.

        Returns:
            The shifted logits when :obj:`return_logits` is set; otherwise a tuple
            ``([loss,] prediction_scores, *outputs[2:])`` when :obj:`return_dict` is false, or a
            :obj:`CausalLMOutputWithCrossAttentions`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        # Training with labels never needs the key/value cache.
        if labels is not None:
            use_cache = False
        # With a cache present the query part has already been processed.
        if past_key_values is not None:
            query_embeds = None

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        # Keep only the text positions: the query tokens come first in the sequence.
        text_hidden = outputs[0]
        if query_embeds is not None:
            text_hidden = text_hidden[:, query_embeds.shape[1] :, :]

        prediction_scores = self.cls(text_hidden)

        if return_logits:
            return prediction_scores[:, :-1, :].contiguous()

        lm_loss = None
        if labels is not None:
            # Next-token prediction: position t is scored against the label at t + 1.
            shifted_scores = prediction_scores[:, :-1, :].contiguous()
            shifted_targets = labels[:, 1:].contiguous()
            criterion = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
            lm_loss = criterion(
                shifted_scores.view(-1, self.config.vocab_size),
                shifted_targets.view(-1),
            )
            if reduction == "none":
                lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

        if return_dict:
            return CausalLMOutputWithCrossAttentions(
                loss=lm_loss,
                logits=prediction_scores,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                cross_attentions=outputs.cross_attentions,
            )

        output = (prediction_scores,) + outputs[2:]
        if lm_loss is None:
            return output
        return (lm_loss,) + output

    def prepare_inputs_for_generation(
        self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
    ):
        """Assemble the keyword arguments of one decoding step.

        The attention mask (all ones when not supplied) is prefixed with an all-ones mask covering
        the query positions, and only the last input id is kept once a cache (``past``) exists.
        """
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
        query_mask = input_ids.new_ones(query_embeds.shape[:-1])
        full_mask = torch.cat([query_mask, attention_mask], dim=-1)

        # cut decoder_input_ids if past is used
        step_ids = input_ids if past is None else input_ids[:, -1:]

        return {
            "input_ids": step_ids,
            "query_embeds": query_embeds,
            "attention_mask": full_mask,
            "past_key_values": past,
            "encoder_hidden_states": model_kwargs.get("encoder_hidden_states"),
            "encoder_attention_mask": model_kwargs.get("encoder_attention_mask"),
            "is_decoder": True,
        }

    def _reorder_cache(self, past, beam_idx):
        """Re-index every cached state along dim 0 with ``beam_idx``."""
        return tuple(
            tuple(state.index_select(0, beam_idx) for state in layer_past)
            for layer_past in past
        )
1130
+
1131
+
1132
class BertForMaskedLM(BertPreTrainedModel):
    """BERT backbone (without pooler) topped with an MLM head for masked-token prediction."""

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        self.init_weights()

    def get_output_embeddings(self):
        """Return the decoder layer of the prediction head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the decoder layer of the prediction head."""
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        query_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        return_logits=False,
        is_decoder=False,
    ):
        r"""
        Run the backbone and the MLM head, optionally computing the masked-LM loss.

        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
            config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
            (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
        query_embeds (`optional`):
            Query embeddings forwarded to the backbone. When given, the leading ``query_embeds.shape[1]``
            positions of the backbone output are dropped before the head is applied.
        return_logits (:obj:`bool`):
            When true, return the prediction scores directly.

        Returns:
            The prediction scores when :obj:`return_logits` is set; otherwise a tuple
            ``([loss,] prediction_scores, *outputs[2:])`` when :obj:`return_dict` is false, or a
            :obj:`MaskedLMOutput`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            query_embeds=query_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            is_decoder=is_decoder,
        )

        # Always bind sequence_output: previously it was only assigned inside the
        # query_embeds branch, so a call without query embeddings raised
        # UnboundLocalError on the next line.
        sequence_output = outputs[0]
        if query_embeds is not None:
            # The query tokens come first; keep only the text positions.
            sequence_output = sequence_output[:, query_embeds.shape[1] :, :]
        prediction_scores = self.cls(sequence_output)

        if return_logits:
            return prediction_scores

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
models/salmonn.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2023) Tsinghua University, Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modifications to allow device parallelism.
16
+ # Copyright (2024) William Held
17
+ # Same terms as above apply
18
+
19
+
20
+ import librosa
21
+ import soundfile as sf
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ from peft import LoraConfig, TaskType, get_peft_model
26
+ from transformers import (
27
+ LlamaForCausalLM,
28
+ LlamaTokenizer,
29
+ WhisperFeatureExtractor,
30
+ WhisperModel,
31
+ )
32
+
33
+ from models.beats.BEATs import BEATs, BEATsConfig
34
+ from models.qformer.Qformer import BertConfig, BertLMHeadModel
35
+
36
+
37
class SALMONN(nn.Module):
    """Audio-conditioned text generation built from Whisper, BEATs, a Q-Former and LLaMA.

    A waveform is encoded by the Whisper encoder and by BEATs; both feature
    streams are layer-normalised, concatenated along the channel axis, cut into
    windows, summarised per window by a Q-Former, projected to the LLaMA hidden
    size and spliced into a text prompt that LLaMA then continues.

    NOTE(review): only BEATs is switched to eval mode in ``__init__``; callers
    presumably call ``.eval()`` on the whole module before ``generate`` — confirm.
    """

    def __init__(
        self,
        ckpt,
        whisper_path,
        beats_path,
        vicuna_path,
        speech_qformer_token_num=1,
        speech_qformer_layer=2,
        lora=True,
        device="cuda:0",
        lora_alpha=32,
        lora_rank=8,
        lora_dropout=0.1,
        second_per_frame=0.333333,
        second_stride=0.333333,
        low_resource=False,
    ):
        """Load all sub-models, then the SALMONN checkpoint on top of them.

        Args:
            ckpt: Path of a checkpoint whose ``"model"`` entry is loaded with
                ``strict=False``.
            whisper_path: Passed to ``WhisperFeatureExtractor.from_pretrained``
                and ``WhisperModel.from_pretrained``.
            beats_path: Path of a BEATs checkpoint with ``"cfg"`` and
                ``"model"`` entries.
            vicuna_path: Passed to ``LlamaForCausalLM.from_pretrained`` and
                ``LlamaTokenizer.from_pretrained``.
            speech_qformer_token_num: Number of Q-Former query tokens.
            speech_qformer_layer: Number of Q-Former hidden layers.
            lora: Whether to wrap LLaMA with a LoRA adapter.
            device: Device for the Whisper encoder, BEATs, the layer norms, the
                Q-Former, the query tokens and the projection. LLaMA itself is
                placed by ``device_map="auto"``.
            lora_alpha: ``lora_alpha`` of the ``LoraConfig``.
            lora_rank: ``r`` of the ``LoraConfig``.
            lora_dropout: ``lora_dropout`` of the ``LoraConfig``.
            second_per_frame: Window length used when splitting features in
                ``generate``.
            second_stride: Window stride used when splitting features in
                ``generate``.
            low_resource: Load LLaMA with ``load_in_8bit=True``.
        """
        super().__init__()

        # feature_extractor
        self.feature_extractor = WhisperFeatureExtractor.from_pretrained(whisper_path)

        # whisper
        self.speech_encoder = WhisperModel.from_pretrained(whisper_path).encoder.to(
            device
        )
        self.ln_speech = nn.LayerNorm(self.speech_encoder.config.d_model).to(device)

        # beats (frozen, eval mode)
        self.beats_ckpt = beats_path
        beats_checkpoint = torch.load(self.beats_ckpt, map_location=device)
        beats_cfg = BEATsConfig(beats_checkpoint["cfg"])
        beats = BEATs(beats_cfg)
        beats.load_state_dict(beats_checkpoint["model"])
        self.beats = beats
        self.beats.to(device)
        self.ln_audio = nn.LayerNorm(self.beats.cfg.encoder_embed_dim).to(device)
        for param in self.beats.parameters():
            param.requires_grad = False
        self.beats.eval()

        # init speech Qformer; its cross-attention width is the concatenation of
        # the Whisper and BEATs feature sizes.
        self.speech_Qformer, self.speech_query_tokens = self.init_speech_Qformer(
            speech_qformer_token_num,
            self.speech_encoder.config.d_model + self.beats.cfg.encoder_embed_dim,
            speech_qformer_layer,
        )
        self.speech_Qformer.to(device)
        # Tensor.to() returns a new tensor instead of moving in place, so a bare
        # `self.speech_query_tokens.to(device)` had no effect. Re-point the
        # parameter's storage so it stays the same registered nn.Parameter.
        self.speech_query_tokens.data = self.speech_query_tokens.data.to(device)
        self.second_per_frame = second_per_frame
        self.second_stride = second_stride

        # vicuna
        if not low_resource:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                vicuna_path,
                torch_dtype=torch.float16,
                device_map="auto",
            )
        else:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                vicuna_path,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                device_map="auto",
            )

        # lora
        self.lora = lora
        if lora:
            self.peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=True,
                r=lora_rank,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
                target_modules=None,
            )
            self.llama_model = get_peft_model(self.llama_model, self.peft_config)

        # tokenizer
        self.llama_tokenizer = LlamaTokenizer.from_pretrained(
            vicuna_path, use_fast=False
        )
        self.llama_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.llama_tokenizer.padding_side = "right"

        # proj: Q-Former hidden size -> LLaMA hidden size
        self.speech_llama_proj = nn.Linear(
            self.speech_Qformer.config.hidden_size, self.llama_model.config.hidden_size
        ).to(device)

        # load ckpt. Deserialise onto the CPU so loading does not depend on the
        # device the checkpoint was saved from; load_state_dict then copies each
        # tensor onto the device of the matching parameter.
        ckpt_dict = torch.load(ckpt, map_location="cpu")["model"]
        self.load_state_dict(ckpt_dict, strict=False)

    @torch.no_grad()  # inference only: the result is decoded text, no graph is needed
    def generate(
        self,
        wav_path,
        prompt,
        prompt_pattern="USER: <Speech><SpeechHere></Speech> {}\nASSISTANT:",
        device="cuda:0",
        max_length=200,
        max_new_tokens=128,
        num_beams=1,
        do_sample=True,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=1.0,
        temperature=1.0,
        logits_processor=None,
        streamer=None,
    ):
        """Generate text for one audio file and one text prompt.

        Args:
            wav_path: Audio file read with ``soundfile``. Only the first channel
                and at most the first 30 seconds are used; audio is resampled to
                16 kHz when needed.
            prompt: Text inserted into ``prompt_pattern`` via ``str.format``.
            prompt_pattern: Template containing ``<SpeechHere>``, which marks
                where the audio embeddings are spliced in.
            device: Device the audio features and BOS ids are created on.
            max_length, max_new_tokens, num_beams, do_sample, min_length, top_p,
            repetition_penalty, length_penalty, temperature, streamer:
                Forwarded unchanged to ``self.llama_model.generate``.
            logits_processor: Optional single processor; it is wrapped in a list
                before being forwarded.

        Returns:
            The output of ``self.llama_tokenizer.batch_decode`` on the generated
            ids, with special tokens skipped.

        Raises:
            ValueError: if the formatted prompt contains no ``<SpeechHere>``
                marker (the split below then yields a single piece).
        """
        # read wav
        wav, sr = sf.read(wav_path)
        if wav.ndim == 2:
            wav = wav[:, 0]  # first channel only
        if len(wav) > 30 * sr:
            wav = wav[: 30 * sr]
        if sr != 16000:
            wav = librosa.resample(wav, orig_sr=sr, target_sr=16000, res_type="fft")

        # whisper
        spectrogram = self.feature_extractor(
            wav, return_tensors="pt", sampling_rate=16000
        ).input_features.to(
            device
        )  # [1, 80, 3000]
        speech_embeds = self.speech_encoder(
            spectrogram, return_dict=True
        ).last_hidden_state

        # beats (no padding: the mask is all False)
        raw_wav = torch.from_numpy(wav).to(device).unsqueeze(0)
        audio_padding_mask = torch.zeros(raw_wav.shape, device=device).bool()
        audio_embeds, _ = self.beats.extract_features(
            raw_wav, padding_mask=audio_padding_mask, feature_only=True
        )

        # auditory embeds: pad BEATs features along the time axis to the Whisper
        # length, then concatenate both streams along the channel axis.
        speech_embeds = self.ln_speech(speech_embeds)
        audio_embeds = self.ln_audio(audio_embeds)
        audio_embeds = F.pad(
            audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1))
        )
        speech_embeds = torch.cat([speech_embeds, audio_embeds], dim=-1)

        # split frames: window length and stride are given in seconds and
        # converted to frame counts relative to a 30 s span of T frames.
        B, T, C = speech_embeds.shape
        kernel = round(T * self.second_per_frame / 30.0)
        stride = round(T * self.second_stride / 30.0)
        kernel = (1, kernel)
        stride = (1, stride)
        speech_embeds_tr = speech_embeds.transpose(1, 2).unsqueeze(2)
        speech_embeds_overlap = F.unfold(
            speech_embeds_tr, kernel_size=kernel, dilation=1, padding=0, stride=stride
        )
        _, _, L = speech_embeds_overlap.shape
        speech_embeds_overlap = speech_embeds_overlap.view(B, -1, kernel[1], L)
        speech_embeds_overlap = torch.permute(speech_embeds_overlap, [0, 3, 2, 1])
        # One row per window: (B * L, frames_per_window, C)
        speech_embeds = speech_embeds_overlap.reshape(-1, kernel[1], C)
        speech_atts = torch.ones(
            speech_embeds.size()[:-1], dtype=torch.long, device=speech_embeds.device
        )

        # Qformer: the same query tokens attend to every window.
        query_tokens = self.speech_query_tokens.expand(speech_embeds.shape[0], -1, -1)
        query_output = self.speech_Qformer.bert(
            query_embeds=query_tokens.to(device),
            encoder_hidden_states=speech_embeds,
            encoder_attention_mask=speech_atts,
            return_dict=True,
        )
        speech_embeds = self.speech_llama_proj(query_output.last_hidden_state)
        # Fold the windows back into one sequence per batch item.
        speech_embeds = speech_embeds.view(B, -1, speech_embeds.size(2)).contiguous()

        # USER: <Speech>speech_embeds<Speech> prompt\nASSISTANT:
        # The PEFT wrapper adds one more `.model` level in front of the embeddings.
        embed_tokens = (
            self.llama_model.model.model.embed_tokens
            if self.lora
            else self.llama_model.model.embed_tokens
        )
        # Split once so a prompt that itself contains the marker cannot produce
        # more than two pieces.
        prompt_left, prompts_right = prompt_pattern.format(prompt).split(
            "<SpeechHere>", 1
        )
        prompt_left_ids = (
            self.llama_tokenizer(
                prompt_left, return_tensors="pt", add_special_tokens=False
            )
            .to(speech_embeds.device)
            .input_ids
        )
        prompt_left_embeds = embed_tokens(prompt_left_ids)
        prompt_right_ids = (
            self.llama_tokenizer(
                prompts_right, return_tensors="pt", add_special_tokens=False
            )
            .to(speech_embeds.device)
            .input_ids
        )
        prompt_right_embeds = embed_tokens(prompt_right_ids)

        bos_ids = (
            torch.ones([1, 1], dtype=torch.long, device=device)
            * self.llama_tokenizer.bos_token_id
        )
        bos_embeds = embed_tokens(bos_ids)

        # Sub-models may sit on different devices; gather everything on the
        # device of the BOS embedding before concatenating along the sequence axis.
        embed_list = [bos_embeds, prompt_left_embeds, speech_embeds, prompt_right_embeds]
        embeds = torch.cat(
            [embed.to(bos_embeds.device) for embed in embed_list], dim=1
        )
        atts = torch.ones(embeds.size()[:-1], dtype=torch.long, device=embeds.device)

        # generate
        output = self.llama_model.generate(
            inputs_embeds=embeds,
            max_length=max_length,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=do_sample,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
            attention_mask=atts,
            bos_token_id=self.llama_tokenizer.bos_token_id,
            eos_token_id=self.llama_tokenizer.eos_token_id,
            pad_token_id=self.llama_tokenizer.pad_token_id,
            logits_processor=[logits_processor] if logits_processor is not None else None,
            streamer=streamer,
        )

        output_text = self.llama_tokenizer.batch_decode(
            output, add_special_tokens=False, skip_special_tokens=True
        )

        return output_text

    def init_speech_Qformer(self, num_query_token, speech_width, num_hidden_layers=2):
        """Build the Q-Former and its learnable query tokens.

        Args:
            num_query_token: Number of query tokens (also stored as
                ``query_length`` in the config).
            speech_width: Stored as ``encoder_width`` in the config.
            num_hidden_layers: Number of hidden layers of the Q-Former.

        Returns:
            ``(Qformer, query_tokens)`` where ``query_tokens`` is an
            ``nn.Parameter`` of shape ``(1, num_query_token, hidden_size)``
            drawn from a normal distribution with ``std=initializer_range``.
        """
        encoder_config = BertConfig()
        encoder_config.num_hidden_layers = num_hidden_layers
        encoder_config.encoder_width = speech_width
        # Cross-attention is enabled with a frequency setting of 1.
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = 1
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens