Spaces:
Running
on
Zero
Running
on
Zero
SALMONN CODE
Browse files- models/__init__.py +0 -0
- models/beats/BEATs.py +206 -0
- models/beats/LICENSE_beats +21 -0
- models/beats/Tokenizers.py +172 -0
- models/beats/__init__.py +0 -0
- models/beats/backbone.py +814 -0
- models/beats/modules.py +218 -0
- models/beats/quantizer.py +215 -0
- models/qformer/LICENSE_Lavis +14 -0
- models/qformer/LICENSE_MiniGPT4 +14 -0
- models/qformer/LICENSE_VideoLlama +28 -0
- models/qformer/Qformer.py +1217 -0
- models/salmonn.py +307 -0
models/__init__.py
ADDED
File without changes
|
models/beats/BEATs.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
|
11 |
+
import logging
|
12 |
+
from typing import Optional
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torchaudio.compliance.kaldi as ta_kaldi
|
17 |
+
from models.beats.backbone import TransformerEncoder
|
18 |
+
from torch.nn import LayerNorm
|
19 |
+
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
|
23 |
+
class BEATsConfig:
|
24 |
+
def __init__(self, cfg=None):
|
25 |
+
self.input_patch_size: int = -1 # path size of patch embedding
|
26 |
+
self.embed_dim: int = 512 # patch embedding dimension
|
27 |
+
self.conv_bias: bool = False # include bias in conv encoder
|
28 |
+
|
29 |
+
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
30 |
+
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
31 |
+
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
32 |
+
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
33 |
+
self.activation_fn: str = "gelu" # activation function to use
|
34 |
+
|
35 |
+
self.layer_wise_gradient_decay_ratio: float = (
|
36 |
+
1.0 # ratio for layer-wise gradient decay
|
37 |
+
)
|
38 |
+
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
39 |
+
self.deep_norm: bool = False # apply deep_norm first in the transformer
|
40 |
+
|
41 |
+
# dropouts
|
42 |
+
self.dropout: float = 0.1 # dropout probability for the transformer
|
43 |
+
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
44 |
+
self.activation_dropout: float = (
|
45 |
+
0.0 # dropout probability after activation in FFN
|
46 |
+
)
|
47 |
+
self.encoder_layerdrop: float = (
|
48 |
+
0.0 # probability of dropping a tarnsformer layer
|
49 |
+
)
|
50 |
+
self.dropout_input: float = (
|
51 |
+
0.0 # dropout to apply to the input (after feat extr)
|
52 |
+
)
|
53 |
+
|
54 |
+
# positional embeddings
|
55 |
+
self.conv_pos: int = (
|
56 |
+
128 # number of filters for convolutional positional embeddings
|
57 |
+
)
|
58 |
+
self.conv_pos_groups: int = (
|
59 |
+
16 # number of groups for convolutional positional embedding
|
60 |
+
)
|
61 |
+
|
62 |
+
# relative position embedding
|
63 |
+
self.relative_position_embedding: bool = (
|
64 |
+
False # apply relative position embedding
|
65 |
+
)
|
66 |
+
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
67 |
+
self.max_distance: int = (
|
68 |
+
1280 # maximum distance for relative position embedding
|
69 |
+
)
|
70 |
+
self.gru_rel_pos: bool = False # apply gated relative position embedding
|
71 |
+
|
72 |
+
# label predictor
|
73 |
+
self.finetuned_model: bool = False # whether the model is a fine-tuned model.
|
74 |
+
self.predictor_dropout: float = 0.1 # dropout probability for the predictor
|
75 |
+
self.predictor_class: int = 527 # target class number for the predictor
|
76 |
+
|
77 |
+
if cfg is not None:
|
78 |
+
self.update(cfg)
|
79 |
+
|
80 |
+
def update(self, cfg: dict):
|
81 |
+
self.__dict__.update(cfg)
|
82 |
+
|
83 |
+
|
84 |
+
class BEATs(nn.Module):
|
85 |
+
def __init__(
|
86 |
+
self,
|
87 |
+
cfg: BEATsConfig,
|
88 |
+
) -> None:
|
89 |
+
super().__init__()
|
90 |
+
logger.info(f"BEATs Config: {cfg.__dict__}")
|
91 |
+
|
92 |
+
self.cfg = cfg
|
93 |
+
|
94 |
+
self.embed = cfg.embed_dim
|
95 |
+
self.post_extract_proj = (
|
96 |
+
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
97 |
+
if self.embed != cfg.encoder_embed_dim
|
98 |
+
else None
|
99 |
+
)
|
100 |
+
|
101 |
+
self.input_patch_size = cfg.input_patch_size
|
102 |
+
self.patch_embedding = nn.Conv2d(
|
103 |
+
1,
|
104 |
+
self.embed,
|
105 |
+
kernel_size=self.input_patch_size,
|
106 |
+
stride=self.input_patch_size,
|
107 |
+
bias=cfg.conv_bias,
|
108 |
+
)
|
109 |
+
|
110 |
+
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
111 |
+
|
112 |
+
assert not cfg.deep_norm or not cfg.layer_norm_first
|
113 |
+
self.encoder = TransformerEncoder(cfg)
|
114 |
+
self.layer_norm = LayerNorm(self.embed)
|
115 |
+
|
116 |
+
if cfg.finetuned_model:
|
117 |
+
self.predictor_dropout = nn.Dropout(cfg.predictor_dropout)
|
118 |
+
self.predictor = nn.Linear(cfg.encoder_embed_dim, cfg.predictor_class)
|
119 |
+
else:
|
120 |
+
self.predictor = None
|
121 |
+
|
122 |
+
def forward_padding_mask(
|
123 |
+
self,
|
124 |
+
features: torch.Tensor,
|
125 |
+
padding_mask: torch.Tensor,
|
126 |
+
) -> torch.Tensor:
|
127 |
+
extra = padding_mask.size(1) % features.size(1)
|
128 |
+
if extra > 0:
|
129 |
+
padding_mask = padding_mask[:, :-extra]
|
130 |
+
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
|
131 |
+
padding_mask = padding_mask.all(-1)
|
132 |
+
return padding_mask
|
133 |
+
|
134 |
+
def preprocess(
|
135 |
+
self,
|
136 |
+
source: torch.Tensor,
|
137 |
+
fbank_mean: float = 15.41663,
|
138 |
+
fbank_std: float = 6.55582,
|
139 |
+
) -> torch.Tensor:
|
140 |
+
fbanks = []
|
141 |
+
for waveform in source:
|
142 |
+
waveform = waveform.unsqueeze(0) * 2**15
|
143 |
+
fbank = ta_kaldi.fbank(
|
144 |
+
waveform,
|
145 |
+
num_mel_bins=128,
|
146 |
+
sample_frequency=16000,
|
147 |
+
frame_length=25,
|
148 |
+
frame_shift=10,
|
149 |
+
)
|
150 |
+
fbanks.append(fbank)
|
151 |
+
fbank = torch.stack(fbanks, dim=0)
|
152 |
+
fbank = (fbank - fbank_mean) / (2 * fbank_std)
|
153 |
+
return fbank
|
154 |
+
|
155 |
+
def extract_features(
|
156 |
+
self,
|
157 |
+
source: torch.Tensor,
|
158 |
+
padding_mask: Optional[torch.Tensor] = None,
|
159 |
+
fbank_mean: float = 15.41663,
|
160 |
+
fbank_std: float = 6.55582,
|
161 |
+
feature_only=False,
|
162 |
+
):
|
163 |
+
fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std).to(
|
164 |
+
torch.float32
|
165 |
+
)
|
166 |
+
|
167 |
+
if padding_mask is not None:
|
168 |
+
padding_mask = self.forward_padding_mask(fbank, padding_mask)
|
169 |
+
|
170 |
+
fbank = fbank.unsqueeze(1)
|
171 |
+
features = self.patch_embedding(fbank)
|
172 |
+
features = features.reshape(features.shape[0], features.shape[1], -1)
|
173 |
+
features = features.transpose(1, 2)
|
174 |
+
features = self.layer_norm(features)
|
175 |
+
|
176 |
+
if padding_mask is not None:
|
177 |
+
padding_mask = self.forward_padding_mask(features, padding_mask)
|
178 |
+
|
179 |
+
if self.post_extract_proj is not None:
|
180 |
+
features = self.post_extract_proj(features)
|
181 |
+
|
182 |
+
x = self.dropout_input(features)
|
183 |
+
|
184 |
+
x, layer_results = self.encoder(
|
185 |
+
x,
|
186 |
+
padding_mask=padding_mask,
|
187 |
+
)
|
188 |
+
|
189 |
+
if not feature_only and self.predictor is not None:
|
190 |
+
x = self.predictor_dropout(x)
|
191 |
+
logits = self.predictor(x)
|
192 |
+
|
193 |
+
if padding_mask is not None and padding_mask.any():
|
194 |
+
logits[padding_mask] = 0
|
195 |
+
logits = logits.sum(dim=1)
|
196 |
+
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
|
197 |
+
logits
|
198 |
+
)
|
199 |
+
else:
|
200 |
+
logits = logits.mean(dim=1)
|
201 |
+
|
202 |
+
lprobs = torch.sigmoid(logits)
|
203 |
+
|
204 |
+
return lprobs, padding_mask
|
205 |
+
else:
|
206 |
+
return x, padding_mask
|
models/beats/LICENSE_beats
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The MIT License (MIT)
|
2 |
+
|
3 |
+
Copyright (c) Microsoft Corporation
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
models/beats/Tokenizers.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch.nn import LayerNorm
|
14 |
+
import torchaudio.compliance.kaldi as ta_kaldi
|
15 |
+
|
16 |
+
from beats.backbone import (
|
17 |
+
TransformerEncoder,
|
18 |
+
)
|
19 |
+
from beats.quantizer import (
|
20 |
+
NormEMAVectorQuantizer,
|
21 |
+
)
|
22 |
+
|
23 |
+
import logging
|
24 |
+
from typing import Optional
|
25 |
+
|
26 |
+
logger = logging.getLogger(__name__)
|
27 |
+
|
28 |
+
|
29 |
+
class TokenizersConfig:
|
30 |
+
def __init__(self, cfg=None):
|
31 |
+
self.input_patch_size: int = -1 # path size of patch embedding
|
32 |
+
self.embed_dim: int = 512 # patch embedding dimension
|
33 |
+
self.conv_bias: bool = False # include bias in conv encoder
|
34 |
+
|
35 |
+
self.encoder_layers: int = 12 # num encoder layers in the transformer
|
36 |
+
self.encoder_embed_dim: int = 768 # encoder embedding dimension
|
37 |
+
self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
|
38 |
+
self.encoder_attention_heads: int = 12 # num encoder attention heads
|
39 |
+
self.activation_fn: str = "gelu" # activation function to use
|
40 |
+
|
41 |
+
self.layer_norm_first: bool = False # apply layernorm first in the transformer
|
42 |
+
self.deep_norm: bool = False # apply deep_norm first in the transformer
|
43 |
+
|
44 |
+
# dropouts
|
45 |
+
self.dropout: float = 0.1 # dropout probability for the transformer
|
46 |
+
self.attention_dropout: float = 0.1 # dropout probability for attention weights
|
47 |
+
self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
|
48 |
+
self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer
|
49 |
+
self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
|
50 |
+
|
51 |
+
# positional embeddings
|
52 |
+
self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
|
53 |
+
self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
|
54 |
+
|
55 |
+
# relative position embedding
|
56 |
+
self.relative_position_embedding: bool = False # apply relative position embedding
|
57 |
+
self.num_buckets: int = 320 # number of buckets for relative position embedding
|
58 |
+
self.max_distance: int = 1280 # maximum distance for relative position embedding
|
59 |
+
self.gru_rel_pos: bool = False # apply gated relative position embedding
|
60 |
+
|
61 |
+
# quantizer
|
62 |
+
self.quant_n: int = 1024 # codebook number in quantizer
|
63 |
+
self.quant_dim: int = 256 # codebook dimension in quantizer
|
64 |
+
|
65 |
+
if cfg is not None:
|
66 |
+
self.update(cfg)
|
67 |
+
|
68 |
+
def update(self, cfg: dict):
|
69 |
+
self.__dict__.update(cfg)
|
70 |
+
|
71 |
+
|
72 |
+
class Tokenizers(nn.Module):
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
cfg: TokenizersConfig,
|
76 |
+
) -> None:
|
77 |
+
super().__init__()
|
78 |
+
logger.info(f"Tokenizers Config: {cfg.__dict__}")
|
79 |
+
|
80 |
+
self.cfg = cfg
|
81 |
+
|
82 |
+
self.embed = cfg.embed_dim
|
83 |
+
self.post_extract_proj = (
|
84 |
+
nn.Linear(self.embed, cfg.encoder_embed_dim)
|
85 |
+
if self.embed != cfg.encoder_embed_dim
|
86 |
+
else None
|
87 |
+
)
|
88 |
+
|
89 |
+
self.input_patch_size = cfg.input_patch_size
|
90 |
+
self.patch_embedding = nn.Conv2d(1, self.embed, kernel_size=self.input_patch_size, stride=self.input_patch_size,
|
91 |
+
bias=cfg.conv_bias)
|
92 |
+
|
93 |
+
self.dropout_input = nn.Dropout(cfg.dropout_input)
|
94 |
+
|
95 |
+
assert not cfg.deep_norm or not cfg.layer_norm_first
|
96 |
+
self.encoder = TransformerEncoder(cfg)
|
97 |
+
self.layer_norm = LayerNorm(self.embed)
|
98 |
+
|
99 |
+
self.quantize = NormEMAVectorQuantizer(
|
100 |
+
n_embed=cfg.quant_n, embedding_dim=cfg.quant_dim, beta=1.0, kmeans_init=True, decay=0.99,
|
101 |
+
)
|
102 |
+
self.quant_n = cfg.quant_n
|
103 |
+
self.quantize_layer = nn.Sequential(
|
104 |
+
nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim),
|
105 |
+
nn.Tanh(),
|
106 |
+
nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize
|
107 |
+
)
|
108 |
+
|
109 |
+
def forward_padding_mask(
|
110 |
+
self,
|
111 |
+
features: torch.Tensor,
|
112 |
+
padding_mask: torch.Tensor,
|
113 |
+
) -> torch.Tensor:
|
114 |
+
extra = padding_mask.size(1) % features.size(1)
|
115 |
+
if extra > 0:
|
116 |
+
padding_mask = padding_mask[:, :-extra]
|
117 |
+
padding_mask = padding_mask.view(
|
118 |
+
padding_mask.size(0), features.size(1), -1
|
119 |
+
)
|
120 |
+
padding_mask = padding_mask.all(-1)
|
121 |
+
return padding_mask
|
122 |
+
|
123 |
+
def preprocess(
|
124 |
+
self,
|
125 |
+
source: torch.Tensor,
|
126 |
+
fbank_mean: float = 15.41663,
|
127 |
+
fbank_std: float = 6.55582,
|
128 |
+
) -> torch.Tensor:
|
129 |
+
fbanks = []
|
130 |
+
for waveform in source:
|
131 |
+
waveform = waveform.unsqueeze(0) * 2 ** 15
|
132 |
+
fbank = ta_kaldi.fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10)
|
133 |
+
fbanks.append(fbank)
|
134 |
+
fbank = torch.stack(fbanks, dim=0)
|
135 |
+
fbank = (fbank - fbank_mean) / (2 * fbank_std)
|
136 |
+
return fbank
|
137 |
+
|
138 |
+
def extract_labels(
|
139 |
+
self,
|
140 |
+
source: torch.Tensor,
|
141 |
+
padding_mask: Optional[torch.Tensor] = None,
|
142 |
+
fbank_mean: float = 15.41663,
|
143 |
+
fbank_std: float = 6.55582,
|
144 |
+
):
|
145 |
+
fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std)
|
146 |
+
|
147 |
+
if padding_mask is not None:
|
148 |
+
padding_mask = self.forward_padding_mask(fbank, padding_mask)
|
149 |
+
|
150 |
+
fbank = fbank.unsqueeze(1)
|
151 |
+
features = self.patch_embedding(fbank)
|
152 |
+
features = features.reshape(features.shape[0], features.shape[1], -1)
|
153 |
+
features = features.transpose(1, 2)
|
154 |
+
features = self.layer_norm(features)
|
155 |
+
|
156 |
+
if padding_mask is not None:
|
157 |
+
padding_mask = self.forward_padding_mask(features, padding_mask)
|
158 |
+
|
159 |
+
if self.post_extract_proj is not None:
|
160 |
+
features = self.post_extract_proj(features)
|
161 |
+
|
162 |
+
x = self.dropout_input(features)
|
163 |
+
|
164 |
+
x, layer_results = self.encoder(
|
165 |
+
x,
|
166 |
+
padding_mask=padding_mask,
|
167 |
+
)
|
168 |
+
|
169 |
+
quantize_input = self.quantize_layer(x)
|
170 |
+
quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input)
|
171 |
+
|
172 |
+
return embed_ind
|
models/beats/__init__.py
ADDED
File without changes
|
models/beats/backbone.py
ADDED
@@ -0,0 +1,814 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
import math
|
11 |
+
from typing import Dict, Optional, Tuple
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from models.beats.modules import (
|
17 |
+
GLU_Linear,
|
18 |
+
GradMultiply,
|
19 |
+
SamePad,
|
20 |
+
get_activation_fn,
|
21 |
+
quant_noise,
|
22 |
+
)
|
23 |
+
from torch import Tensor, nn
|
24 |
+
from torch.nn import LayerNorm, Parameter
|
25 |
+
|
26 |
+
|
27 |
+
class TransformerEncoder(nn.Module):
|
28 |
+
def __init__(self, args):
|
29 |
+
super().__init__()
|
30 |
+
|
31 |
+
self.dropout = args.dropout
|
32 |
+
self.embedding_dim = args.encoder_embed_dim
|
33 |
+
|
34 |
+
self.pos_conv = nn.Conv1d(
|
35 |
+
self.embedding_dim,
|
36 |
+
self.embedding_dim,
|
37 |
+
kernel_size=args.conv_pos,
|
38 |
+
padding=args.conv_pos // 2,
|
39 |
+
groups=args.conv_pos_groups,
|
40 |
+
)
|
41 |
+
dropout = 0
|
42 |
+
std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
|
43 |
+
nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
|
44 |
+
nn.init.constant_(self.pos_conv.bias, 0)
|
45 |
+
|
46 |
+
self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
|
47 |
+
self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
|
48 |
+
|
49 |
+
if hasattr(args, "relative_position_embedding"):
|
50 |
+
self.relative_position_embedding = args.relative_position_embedding
|
51 |
+
self.num_buckets = args.num_buckets
|
52 |
+
self.max_distance = args.max_distance
|
53 |
+
else:
|
54 |
+
self.relative_position_embedding = False
|
55 |
+
self.num_buckets = 0
|
56 |
+
self.max_distance = 0
|
57 |
+
|
58 |
+
self.layers = nn.ModuleList(
|
59 |
+
[
|
60 |
+
TransformerSentenceEncoderLayer(
|
61 |
+
embedding_dim=self.embedding_dim,
|
62 |
+
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
63 |
+
num_attention_heads=args.encoder_attention_heads,
|
64 |
+
dropout=self.dropout,
|
65 |
+
attention_dropout=args.attention_dropout,
|
66 |
+
activation_dropout=args.activation_dropout,
|
67 |
+
activation_fn=args.activation_fn,
|
68 |
+
layer_norm_first=args.layer_norm_first,
|
69 |
+
deep_norm=args.deep_norm,
|
70 |
+
has_relative_attention_bias=self.relative_position_embedding,
|
71 |
+
num_buckets=self.num_buckets,
|
72 |
+
max_distance=self.max_distance,
|
73 |
+
gru_rel_pos=args.gru_rel_pos,
|
74 |
+
encoder_layers=args.encoder_layers,
|
75 |
+
)
|
76 |
+
for i in range(args.encoder_layers)
|
77 |
+
]
|
78 |
+
)
|
79 |
+
if self.relative_position_embedding:
|
80 |
+
for i in range(1, args.encoder_layers):
|
81 |
+
del self.layers[i].self_attn.relative_attention_bias
|
82 |
+
self.layers[i].self_attn.relative_attention_bias = self.layers[
|
83 |
+
0
|
84 |
+
].self_attn.relative_attention_bias
|
85 |
+
|
86 |
+
self.layer_norm_first = args.layer_norm_first
|
87 |
+
self.layer_norm = LayerNorm(self.embedding_dim)
|
88 |
+
self.layerdrop = args.encoder_layerdrop
|
89 |
+
|
90 |
+
self.apply(init_bert_params)
|
91 |
+
|
92 |
+
if args.deep_norm:
|
93 |
+
deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
|
94 |
+
for i in range(args.encoder_layers):
|
95 |
+
nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
|
96 |
+
nn.init.xavier_normal_(
|
97 |
+
self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta
|
98 |
+
)
|
99 |
+
nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
|
100 |
+
nn.init.xavier_normal_(
|
101 |
+
self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta
|
102 |
+
)
|
103 |
+
nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
|
104 |
+
nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)
|
105 |
+
|
106 |
+
self.layer_wise_gradient_decay_ratio = getattr(
|
107 |
+
args, "layer_wise_gradient_decay_ratio", 1
|
108 |
+
)
|
109 |
+
|
110 |
+
def forward(self, x, padding_mask=None, layer=None):
|
111 |
+
x, layer_results = self.extract_features(x, padding_mask, layer)
|
112 |
+
|
113 |
+
if self.layer_norm_first and layer is None:
|
114 |
+
x = self.layer_norm(x)
|
115 |
+
|
116 |
+
return x, layer_results
|
117 |
+
|
118 |
+
def extract_features(self, x, padding_mask=None, tgt_layer=None):
|
119 |
+
|
120 |
+
if padding_mask is not None:
|
121 |
+
x[padding_mask] = 0
|
122 |
+
|
123 |
+
x_conv = self.pos_conv(x.transpose(1, 2))
|
124 |
+
x_conv = x_conv.transpose(1, 2)
|
125 |
+
x = x + x_conv
|
126 |
+
|
127 |
+
if not self.layer_norm_first:
|
128 |
+
x = self.layer_norm(x)
|
129 |
+
|
130 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
131 |
+
|
132 |
+
# B x T x C -> T x B x C
|
133 |
+
x = x.transpose(0, 1)
|
134 |
+
|
135 |
+
layer_results = []
|
136 |
+
z = None
|
137 |
+
if tgt_layer is not None:
|
138 |
+
layer_results.append((x, z))
|
139 |
+
r = None
|
140 |
+
pos_bias = None
|
141 |
+
for i, layer in enumerate(self.layers):
|
142 |
+
if self.layer_wise_gradient_decay_ratio != 1.0:
|
143 |
+
x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
|
144 |
+
dropout_probability = np.random.random()
|
145 |
+
if not self.training or (dropout_probability > self.layerdrop):
|
146 |
+
x, z, pos_bias = layer(
|
147 |
+
x,
|
148 |
+
self_attn_padding_mask=padding_mask,
|
149 |
+
need_weights=False,
|
150 |
+
pos_bias=pos_bias,
|
151 |
+
)
|
152 |
+
if tgt_layer is not None:
|
153 |
+
layer_results.append((x, z))
|
154 |
+
if i == tgt_layer:
|
155 |
+
r = x
|
156 |
+
break
|
157 |
+
|
158 |
+
if r is not None:
|
159 |
+
x = r
|
160 |
+
|
161 |
+
# T x B x C -> B x T x C
|
162 |
+
x = x.transpose(0, 1)
|
163 |
+
|
164 |
+
return x, layer_results
|
165 |
+
|
166 |
+
|
167 |
+
class TransformerSentenceEncoderLayer(nn.Module):
|
168 |
+
def __init__(
|
169 |
+
self,
|
170 |
+
embedding_dim: float = 768,
|
171 |
+
ffn_embedding_dim: float = 3072,
|
172 |
+
num_attention_heads: float = 8,
|
173 |
+
dropout: float = 0.1,
|
174 |
+
attention_dropout: float = 0.1,
|
175 |
+
activation_dropout: float = 0.1,
|
176 |
+
activation_fn: str = "relu",
|
177 |
+
layer_norm_first: bool = False,
|
178 |
+
deep_norm: bool = False,
|
179 |
+
has_relative_attention_bias: bool = False,
|
180 |
+
num_buckets: int = 0,
|
181 |
+
max_distance: int = 0,
|
182 |
+
rescale_init: bool = False,
|
183 |
+
gru_rel_pos: bool = False,
|
184 |
+
encoder_layers: int = 0,
|
185 |
+
) -> None:
|
186 |
+
|
187 |
+
super().__init__()
|
188 |
+
self.embedding_dim = embedding_dim
|
189 |
+
self.dropout = dropout
|
190 |
+
self.activation_dropout = activation_dropout
|
191 |
+
|
192 |
+
self.activation_name = activation_fn
|
193 |
+
self.activation_fn = get_activation_fn(activation_fn)
|
194 |
+
self.self_attn = MultiheadAttention(
|
195 |
+
self.embedding_dim,
|
196 |
+
num_attention_heads,
|
197 |
+
dropout=attention_dropout,
|
198 |
+
self_attention=True,
|
199 |
+
has_relative_attention_bias=has_relative_attention_bias,
|
200 |
+
num_buckets=num_buckets,
|
201 |
+
max_distance=max_distance,
|
202 |
+
rescale_init=rescale_init,
|
203 |
+
gru_rel_pos=gru_rel_pos,
|
204 |
+
)
|
205 |
+
|
206 |
+
self.dropout1 = nn.Dropout(dropout)
|
207 |
+
self.dropout2 = nn.Dropout(self.activation_dropout)
|
208 |
+
self.dropout3 = nn.Dropout(dropout)
|
209 |
+
|
210 |
+
self.layer_norm_first = layer_norm_first
|
211 |
+
|
212 |
+
self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
|
213 |
+
|
214 |
+
if self.activation_name == "glu":
|
215 |
+
self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
|
216 |
+
else:
|
217 |
+
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
218 |
+
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
219 |
+
|
220 |
+
self.final_layer_norm = LayerNorm(self.embedding_dim)
|
221 |
+
|
222 |
+
self.deep_norm = deep_norm
|
223 |
+
if self.deep_norm:
|
224 |
+
self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
|
225 |
+
else:
|
226 |
+
self.deep_norm_alpha = 1
|
227 |
+
|
228 |
+
def forward(
|
229 |
+
self,
|
230 |
+
x: torch.Tensor,
|
231 |
+
self_attn_mask: torch.Tensor = None,
|
232 |
+
self_attn_padding_mask: torch.Tensor = None,
|
233 |
+
need_weights: bool = False,
|
234 |
+
pos_bias=None,
|
235 |
+
):
|
236 |
+
residual = x
|
237 |
+
|
238 |
+
if self.layer_norm_first:
|
239 |
+
x = self.self_attn_layer_norm(x)
|
240 |
+
x, attn, pos_bias = self.self_attn(
|
241 |
+
query=x,
|
242 |
+
key=x,
|
243 |
+
value=x,
|
244 |
+
key_padding_mask=self_attn_padding_mask,
|
245 |
+
need_weights=False,
|
246 |
+
attn_mask=self_attn_mask,
|
247 |
+
position_bias=pos_bias,
|
248 |
+
)
|
249 |
+
x = self.dropout1(x)
|
250 |
+
x = residual + x
|
251 |
+
|
252 |
+
residual = x
|
253 |
+
x = self.final_layer_norm(x)
|
254 |
+
if self.activation_name == "glu":
|
255 |
+
x = self.fc1(x)
|
256 |
+
else:
|
257 |
+
x = self.activation_fn(self.fc1(x))
|
258 |
+
x = self.dropout2(x)
|
259 |
+
x = self.fc2(x)
|
260 |
+
x = self.dropout3(x)
|
261 |
+
x = residual + x
|
262 |
+
else:
|
263 |
+
x, attn, pos_bias = self.self_attn(
|
264 |
+
query=x,
|
265 |
+
key=x,
|
266 |
+
value=x,
|
267 |
+
key_padding_mask=self_attn_padding_mask,
|
268 |
+
need_weights=need_weights,
|
269 |
+
attn_mask=self_attn_mask,
|
270 |
+
position_bias=pos_bias,
|
271 |
+
)
|
272 |
+
|
273 |
+
x = self.dropout1(x)
|
274 |
+
x = residual * self.deep_norm_alpha + x
|
275 |
+
|
276 |
+
x = self.self_attn_layer_norm(x)
|
277 |
+
|
278 |
+
residual = x
|
279 |
+
if self.activation_name == "glu":
|
280 |
+
x = self.fc1(x)
|
281 |
+
else:
|
282 |
+
x = self.activation_fn(self.fc1(x))
|
283 |
+
x = self.dropout2(x)
|
284 |
+
x = self.fc2(x)
|
285 |
+
x = self.dropout3(x)
|
286 |
+
x = residual * self.deep_norm_alpha + x
|
287 |
+
x = self.final_layer_norm(x)
|
288 |
+
|
289 |
+
return x, attn, pos_bias
|
290 |
+
|
291 |
+
|
292 |
+
class MultiheadAttention(nn.Module):
|
293 |
+
"""Multi-headed attention.
|
294 |
+
|
295 |
+
See "Attention Is All You Need" for more details.
|
296 |
+
"""
|
297 |
+
|
298 |
+
def __init__(
|
299 |
+
self,
|
300 |
+
embed_dim,
|
301 |
+
num_heads,
|
302 |
+
kdim=None,
|
303 |
+
vdim=None,
|
304 |
+
dropout=0.0,
|
305 |
+
bias=True,
|
306 |
+
add_bias_kv=False,
|
307 |
+
add_zero_attn=False,
|
308 |
+
self_attention=False,
|
309 |
+
encoder_decoder_attention=False,
|
310 |
+
q_noise=0.0,
|
311 |
+
qn_block_size=8,
|
312 |
+
has_relative_attention_bias=False,
|
313 |
+
num_buckets=32,
|
314 |
+
max_distance=128,
|
315 |
+
gru_rel_pos=False,
|
316 |
+
rescale_init=False,
|
317 |
+
):
|
318 |
+
super().__init__()
|
319 |
+
self.embed_dim = embed_dim
|
320 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
321 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
322 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
323 |
+
|
324 |
+
self.num_heads = num_heads
|
325 |
+
self.dropout_module = nn.Dropout(dropout)
|
326 |
+
|
327 |
+
self.has_relative_attention_bias = has_relative_attention_bias
|
328 |
+
self.num_buckets = num_buckets
|
329 |
+
self.max_distance = max_distance
|
330 |
+
if self.has_relative_attention_bias:
|
331 |
+
self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
|
332 |
+
|
333 |
+
self.head_dim = embed_dim // num_heads
|
334 |
+
self.q_head_dim = self.head_dim
|
335 |
+
self.k_head_dim = self.head_dim
|
336 |
+
assert (
|
337 |
+
self.head_dim * num_heads == self.embed_dim
|
338 |
+
), "embed_dim must be divisible by num_heads"
|
339 |
+
self.scaling = self.head_dim**-0.5
|
340 |
+
|
341 |
+
self.self_attention = self_attention
|
342 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
343 |
+
|
344 |
+
assert not self.self_attention or self.qkv_same_dim, (
|
345 |
+
"Self-attention requires query, key and " "value to be of the same size"
|
346 |
+
)
|
347 |
+
|
348 |
+
k_bias = True
|
349 |
+
if rescale_init:
|
350 |
+
k_bias = False
|
351 |
+
|
352 |
+
k_embed_dim = embed_dim
|
353 |
+
q_embed_dim = embed_dim
|
354 |
+
|
355 |
+
self.k_proj = quant_noise(
|
356 |
+
nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
|
357 |
+
)
|
358 |
+
self.v_proj = quant_noise(
|
359 |
+
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
360 |
+
)
|
361 |
+
self.q_proj = quant_noise(
|
362 |
+
nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
|
363 |
+
)
|
364 |
+
|
365 |
+
self.out_proj = quant_noise(
|
366 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
367 |
+
)
|
368 |
+
|
369 |
+
if add_bias_kv:
|
370 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
371 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
372 |
+
else:
|
373 |
+
self.bias_k = self.bias_v = None
|
374 |
+
|
375 |
+
self.add_zero_attn = add_zero_attn
|
376 |
+
|
377 |
+
self.gru_rel_pos = gru_rel_pos
|
378 |
+
if self.gru_rel_pos:
|
379 |
+
self.grep_linear = nn.Linear(self.q_head_dim, 8)
|
380 |
+
self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
|
381 |
+
|
382 |
+
self.reset_parameters()
|
383 |
+
|
384 |
+
def reset_parameters(self):
|
385 |
+
if self.qkv_same_dim:
|
386 |
+
# Empirically observed the convergence to be much better with
|
387 |
+
# the scaled initialization
|
388 |
+
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
389 |
+
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
390 |
+
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
391 |
+
else:
|
392 |
+
nn.init.xavier_uniform_(self.k_proj.weight)
|
393 |
+
nn.init.xavier_uniform_(self.v_proj.weight)
|
394 |
+
nn.init.xavier_uniform_(self.q_proj.weight)
|
395 |
+
|
396 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
397 |
+
if self.out_proj.bias is not None:
|
398 |
+
nn.init.constant_(self.out_proj.bias, 0.0)
|
399 |
+
if self.bias_k is not None:
|
400 |
+
nn.init.xavier_normal_(self.bias_k)
|
401 |
+
if self.bias_v is not None:
|
402 |
+
nn.init.xavier_normal_(self.bias_v)
|
403 |
+
if self.has_relative_attention_bias:
|
404 |
+
nn.init.xavier_normal_(self.relative_attention_bias.weight)
|
405 |
+
|
406 |
+
def _relative_positions_bucket(self, relative_positions, bidirectional=True):
|
407 |
+
num_buckets = self.num_buckets
|
408 |
+
max_distance = self.max_distance
|
409 |
+
relative_buckets = 0
|
410 |
+
|
411 |
+
if bidirectional:
|
412 |
+
num_buckets = num_buckets // 2
|
413 |
+
relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
|
414 |
+
relative_positions = torch.abs(relative_positions)
|
415 |
+
else:
|
416 |
+
relative_positions = -torch.min(
|
417 |
+
relative_positions, torch.zeros_like(relative_positions)
|
418 |
+
)
|
419 |
+
|
420 |
+
max_exact = num_buckets // 2
|
421 |
+
is_small = relative_positions < max_exact
|
422 |
+
|
423 |
+
relative_postion_if_large = max_exact + (
|
424 |
+
torch.log(relative_positions.float() / max_exact)
|
425 |
+
/ math.log(max_distance / max_exact)
|
426 |
+
* (num_buckets - max_exact)
|
427 |
+
).to(torch.long)
|
428 |
+
relative_postion_if_large = torch.min(
|
429 |
+
relative_postion_if_large,
|
430 |
+
torch.full_like(relative_postion_if_large, num_buckets - 1),
|
431 |
+
)
|
432 |
+
|
433 |
+
relative_buckets += torch.where(
|
434 |
+
is_small, relative_positions, relative_postion_if_large
|
435 |
+
)
|
436 |
+
return relative_buckets
|
437 |
+
|
438 |
+
def compute_bias(self, query_length, key_length):
|
439 |
+
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
|
440 |
+
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
|
441 |
+
relative_position = memory_position - context_position
|
442 |
+
relative_position_bucket = self._relative_positions_bucket(
|
443 |
+
relative_position, bidirectional=True
|
444 |
+
)
|
445 |
+
relative_position_bucket = relative_position_bucket.to(
|
446 |
+
self.relative_attention_bias.weight.device
|
447 |
+
)
|
448 |
+
values = self.relative_attention_bias(relative_position_bucket)
|
449 |
+
values = values.permute([2, 0, 1])
|
450 |
+
return values
|
451 |
+
|
452 |
+
def forward(
|
453 |
+
self,
|
454 |
+
query,
|
455 |
+
key: Optional[Tensor],
|
456 |
+
value: Optional[Tensor],
|
457 |
+
key_padding_mask: Optional[Tensor] = None,
|
458 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
459 |
+
need_weights: bool = True,
|
460 |
+
static_kv: bool = False,
|
461 |
+
attn_mask: Optional[Tensor] = None,
|
462 |
+
before_softmax: bool = False,
|
463 |
+
need_head_weights: bool = False,
|
464 |
+
position_bias: Optional[Tensor] = None,
|
465 |
+
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
466 |
+
"""Input shape: Time x Batch x Channel
|
467 |
+
|
468 |
+
Args:
|
469 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
470 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
471 |
+
padding elements are indicated by 1s.
|
472 |
+
need_weights (bool, optional): return the attention weights,
|
473 |
+
averaged over heads (default: False).
|
474 |
+
attn_mask (ByteTensor, optional): typically used to
|
475 |
+
implement causal attention, where the mask prevents the
|
476 |
+
attention from looking forward in time (default: None).
|
477 |
+
before_softmax (bool, optional): return the raw attention
|
478 |
+
weights and values before the attention softmax.
|
479 |
+
need_head_weights (bool, optional): return the attention
|
480 |
+
weights for each head. Implies *need_weights*. Default:
|
481 |
+
return the average attention weights over all heads.
|
482 |
+
"""
|
483 |
+
if need_head_weights:
|
484 |
+
need_weights = True
|
485 |
+
|
486 |
+
is_tpu = query.device.type == "xla"
|
487 |
+
|
488 |
+
tgt_len, bsz, embed_dim = query.size()
|
489 |
+
src_len = tgt_len
|
490 |
+
assert embed_dim == self.embed_dim
|
491 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
492 |
+
if key is not None:
|
493 |
+
src_len, key_bsz, _ = key.size()
|
494 |
+
if not torch.jit.is_scripting():
|
495 |
+
assert key_bsz == bsz
|
496 |
+
assert value is not None
|
497 |
+
assert src_len, bsz == value.shape[:2]
|
498 |
+
|
499 |
+
if self.has_relative_attention_bias and position_bias is None:
|
500 |
+
position_bias = self.compute_bias(tgt_len, src_len)
|
501 |
+
position_bias = (
|
502 |
+
position_bias.unsqueeze(0)
|
503 |
+
.repeat(bsz, 1, 1, 1)
|
504 |
+
.view(bsz * self.num_heads, tgt_len, src_len)
|
505 |
+
)
|
506 |
+
|
507 |
+
if incremental_state is not None:
|
508 |
+
saved_state = self._get_input_buffer(incremental_state)
|
509 |
+
if saved_state is not None and "prev_key" in saved_state:
|
510 |
+
# previous time steps are cached - no need to recompute
|
511 |
+
# key and value if they are static
|
512 |
+
if static_kv:
|
513 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
514 |
+
key = value = None
|
515 |
+
else:
|
516 |
+
saved_state = None
|
517 |
+
|
518 |
+
if self.self_attention:
|
519 |
+
q = self.q_proj(query)
|
520 |
+
k = self.k_proj(query)
|
521 |
+
v = self.v_proj(query)
|
522 |
+
elif self.encoder_decoder_attention:
|
523 |
+
# encoder-decoder attention
|
524 |
+
q = self.q_proj(query)
|
525 |
+
if key is None:
|
526 |
+
assert value is None
|
527 |
+
k = v = None
|
528 |
+
else:
|
529 |
+
k = self.k_proj(key)
|
530 |
+
v = self.v_proj(key)
|
531 |
+
|
532 |
+
else:
|
533 |
+
assert key is not None and value is not None
|
534 |
+
q = self.q_proj(query)
|
535 |
+
k = self.k_proj(key)
|
536 |
+
v = self.v_proj(value)
|
537 |
+
q *= self.scaling
|
538 |
+
alpha = 32
|
539 |
+
q *= 1 / alpha
|
540 |
+
|
541 |
+
if self.bias_k is not None:
|
542 |
+
assert self.bias_v is not None
|
543 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
544 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
545 |
+
if attn_mask is not None:
|
546 |
+
attn_mask = torch.cat(
|
547 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
548 |
+
)
|
549 |
+
if key_padding_mask is not None:
|
550 |
+
key_padding_mask = torch.cat(
|
551 |
+
[
|
552 |
+
key_padding_mask,
|
553 |
+
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
554 |
+
],
|
555 |
+
dim=1,
|
556 |
+
)
|
557 |
+
|
558 |
+
q = (
|
559 |
+
q.contiguous()
|
560 |
+
.view(tgt_len, bsz * self.num_heads, self.q_head_dim)
|
561 |
+
.transpose(0, 1)
|
562 |
+
)
|
563 |
+
if k is not None:
|
564 |
+
k = (
|
565 |
+
k.contiguous()
|
566 |
+
.view(-1, bsz * self.num_heads, self.k_head_dim)
|
567 |
+
.transpose(0, 1)
|
568 |
+
)
|
569 |
+
if v is not None:
|
570 |
+
v = (
|
571 |
+
v.contiguous()
|
572 |
+
.view(-1, bsz * self.num_heads, self.head_dim)
|
573 |
+
.transpose(0, 1)
|
574 |
+
)
|
575 |
+
|
576 |
+
if saved_state is not None:
|
577 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
578 |
+
if "prev_key" in saved_state:
|
579 |
+
_prev_key = saved_state["prev_key"]
|
580 |
+
assert _prev_key is not None
|
581 |
+
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
582 |
+
if static_kv:
|
583 |
+
k = prev_key
|
584 |
+
else:
|
585 |
+
assert k is not None
|
586 |
+
k = torch.cat([prev_key, k], dim=1)
|
587 |
+
src_len = k.size(1)
|
588 |
+
if "prev_value" in saved_state:
|
589 |
+
_prev_value = saved_state["prev_value"]
|
590 |
+
assert _prev_value is not None
|
591 |
+
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
592 |
+
if static_kv:
|
593 |
+
v = prev_value
|
594 |
+
else:
|
595 |
+
assert v is not None
|
596 |
+
v = torch.cat([prev_value, v], dim=1)
|
597 |
+
prev_key_padding_mask: Optional[Tensor] = None
|
598 |
+
if "prev_key_padding_mask" in saved_state:
|
599 |
+
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
600 |
+
assert k is not None and v is not None
|
601 |
+
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
|
602 |
+
key_padding_mask=key_padding_mask,
|
603 |
+
prev_key_padding_mask=prev_key_padding_mask,
|
604 |
+
batch_size=bsz,
|
605 |
+
src_len=k.size(1),
|
606 |
+
static_kv=static_kv,
|
607 |
+
)
|
608 |
+
|
609 |
+
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
610 |
+
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
611 |
+
saved_state["prev_key_padding_mask"] = key_padding_mask
|
612 |
+
# In this branch incremental_state is never None
|
613 |
+
assert incremental_state is not None
|
614 |
+
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
615 |
+
assert k is not None
|
616 |
+
assert k.size(1) == src_len
|
617 |
+
|
618 |
+
# This is part of a workaround to get around fork/join parallelism
|
619 |
+
# not supporting Optional types.
|
620 |
+
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
621 |
+
key_padding_mask = None
|
622 |
+
|
623 |
+
if key_padding_mask is not None:
|
624 |
+
assert key_padding_mask.size(0) == bsz
|
625 |
+
assert key_padding_mask.size(1) == src_len
|
626 |
+
|
627 |
+
if self.add_zero_attn:
|
628 |
+
assert v is not None
|
629 |
+
src_len += 1
|
630 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
631 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
632 |
+
if attn_mask is not None:
|
633 |
+
attn_mask = torch.cat(
|
634 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
635 |
+
)
|
636 |
+
if key_padding_mask is not None:
|
637 |
+
key_padding_mask = torch.cat(
|
638 |
+
[
|
639 |
+
key_padding_mask,
|
640 |
+
torch.zeros(key_padding_mask.size(0), 1).type_as(
|
641 |
+
key_padding_mask
|
642 |
+
),
|
643 |
+
],
|
644 |
+
dim=1,
|
645 |
+
)
|
646 |
+
|
647 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
648 |
+
attn_weights = (
|
649 |
+
attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]
|
650 |
+
) * alpha
|
651 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
652 |
+
|
653 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
654 |
+
|
655 |
+
if attn_mask is not None:
|
656 |
+
attn_mask = attn_mask.unsqueeze(0)
|
657 |
+
attn_weights += attn_mask
|
658 |
+
|
659 |
+
if key_padding_mask is not None:
|
660 |
+
# don't attend to padding symbols
|
661 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
662 |
+
if not is_tpu:
|
663 |
+
attn_weights = attn_weights.masked_fill(
|
664 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
665 |
+
float("-inf"),
|
666 |
+
)
|
667 |
+
else:
|
668 |
+
attn_weights = attn_weights.transpose(0, 2)
|
669 |
+
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
|
670 |
+
attn_weights = attn_weights.transpose(0, 2)
|
671 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
672 |
+
|
673 |
+
if before_softmax:
|
674 |
+
return attn_weights, v, position_bias
|
675 |
+
|
676 |
+
if position_bias is not None:
|
677 |
+
attn_mask_rel_pos = position_bias
|
678 |
+
if self.gru_rel_pos == 1:
|
679 |
+
query_layer = (
|
680 |
+
q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
|
681 |
+
* alpha
|
682 |
+
/ self.scaling
|
683 |
+
)
|
684 |
+
_B, _H, _L, __ = query_layer.size()
|
685 |
+
gate_a, gate_b = torch.sigmoid(
|
686 |
+
self.grep_linear(query_layer)
|
687 |
+
.view(_B, _H, _L, 2, 4)
|
688 |
+
.sum(-1, keepdim=False)
|
689 |
+
).chunk(2, dim=-1)
|
690 |
+
gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
|
691 |
+
attn_mask_rel_pos = (
|
692 |
+
gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias
|
693 |
+
)
|
694 |
+
|
695 |
+
attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())
|
696 |
+
|
697 |
+
attn_weights = attn_weights + attn_mask_rel_pos
|
698 |
+
|
699 |
+
attn_weights_float = F.softmax(attn_weights, dim=-1)
|
700 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
701 |
+
attn_probs = self.dropout_module(attn_weights)
|
702 |
+
|
703 |
+
assert v is not None
|
704 |
+
attn = torch.bmm(attn_probs, v)
|
705 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
706 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
707 |
+
attn = self.out_proj(attn)
|
708 |
+
attn_weights: Optional[Tensor] = None
|
709 |
+
if need_weights:
|
710 |
+
attn_weights = attn_weights_float.view(
|
711 |
+
bsz, self.num_heads, tgt_len, src_len
|
712 |
+
).transpose(1, 0)
|
713 |
+
if not need_head_weights:
|
714 |
+
# average attention weights over heads
|
715 |
+
attn_weights = attn_weights.mean(dim=0)
|
716 |
+
|
717 |
+
return attn, attn_weights, position_bias
|
718 |
+
|
719 |
+
@staticmethod
|
720 |
+
def _append_prev_key_padding_mask(
|
721 |
+
key_padding_mask: Optional[Tensor],
|
722 |
+
prev_key_padding_mask: Optional[Tensor],
|
723 |
+
batch_size: int,
|
724 |
+
src_len: int,
|
725 |
+
static_kv: bool,
|
726 |
+
) -> Optional[Tensor]:
|
727 |
+
# saved key padding masks have shape (bsz, seq_len)
|
728 |
+
if prev_key_padding_mask is not None and static_kv:
|
729 |
+
new_key_padding_mask = prev_key_padding_mask
|
730 |
+
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
731 |
+
new_key_padding_mask = torch.cat(
|
732 |
+
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
|
733 |
+
)
|
734 |
+
# During incremental decoding, as the padding token enters and
|
735 |
+
# leaves the frame, there will be a time when prev or current
|
736 |
+
# is None
|
737 |
+
elif prev_key_padding_mask is not None:
|
738 |
+
if src_len > prev_key_padding_mask.size(1):
|
739 |
+
filler = torch.zeros(
|
740 |
+
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
741 |
+
device=prev_key_padding_mask.device,
|
742 |
+
)
|
743 |
+
new_key_padding_mask = torch.cat(
|
744 |
+
[prev_key_padding_mask.float(), filler.float()], dim=1
|
745 |
+
)
|
746 |
+
else:
|
747 |
+
new_key_padding_mask = prev_key_padding_mask.float()
|
748 |
+
elif key_padding_mask is not None:
|
749 |
+
if src_len > key_padding_mask.size(1):
|
750 |
+
filler = torch.zeros(
|
751 |
+
(batch_size, src_len - key_padding_mask.size(1)),
|
752 |
+
device=key_padding_mask.device,
|
753 |
+
)
|
754 |
+
new_key_padding_mask = torch.cat(
|
755 |
+
[filler.float(), key_padding_mask.float()], dim=1
|
756 |
+
)
|
757 |
+
else:
|
758 |
+
new_key_padding_mask = key_padding_mask.float()
|
759 |
+
else:
|
760 |
+
new_key_padding_mask = prev_key_padding_mask
|
761 |
+
return new_key_padding_mask
|
762 |
+
|
763 |
+
def _get_input_buffer(
|
764 |
+
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
765 |
+
) -> Dict[str, Optional[Tensor]]:
|
766 |
+
result = self.get_incremental_state(incremental_state, "attn_state")
|
767 |
+
if result is not None:
|
768 |
+
return result
|
769 |
+
else:
|
770 |
+
empty_result: Dict[str, Optional[Tensor]] = {}
|
771 |
+
return empty_result
|
772 |
+
|
773 |
+
def _set_input_buffer(
|
774 |
+
self,
|
775 |
+
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
776 |
+
buffer: Dict[str, Optional[Tensor]],
|
777 |
+
):
|
778 |
+
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
779 |
+
|
780 |
+
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
781 |
+
return attn_weights
|
782 |
+
|
783 |
+
|
784 |
+
def init_bert_params(module):
|
785 |
+
"""
|
786 |
+
Initialize the weights specific to the BERT Model.
|
787 |
+
This overrides the default initializations depending on the specified arguments.
|
788 |
+
1. If normal_init_linear_weights is set then weights of linear
|
789 |
+
layer will be initialized using the normal distribution and
|
790 |
+
bais will be set to the specified value.
|
791 |
+
2. If normal_init_embed_weights is set then weights of embedding
|
792 |
+
layer will be initialized using the normal distribution.
|
793 |
+
3. If normal_init_proj_weights is set then weights of
|
794 |
+
in_project_weight for MultiHeadAttention initialized using
|
795 |
+
the normal distribution (to be validated).
|
796 |
+
"""
|
797 |
+
|
798 |
+
def normal_(data):
|
799 |
+
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
800 |
+
# so that the RNG is consistent with and without FSDP
|
801 |
+
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
|
802 |
+
|
803 |
+
if isinstance(module, nn.Linear):
|
804 |
+
normal_(module.weight.data)
|
805 |
+
if module.bias is not None:
|
806 |
+
module.bias.data.zero_()
|
807 |
+
if isinstance(module, nn.Embedding):
|
808 |
+
normal_(module.weight.data)
|
809 |
+
if module.padding_idx is not None:
|
810 |
+
module.weight.data[module.padding_idx].zero_()
|
811 |
+
if isinstance(module, MultiheadAttention):
|
812 |
+
normal_(module.q_proj.weight.data)
|
813 |
+
normal_(module.k_proj.weight.data)
|
814 |
+
normal_(module.v_proj.weight.data)
|
models/beats/modules.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on fairseq code bases
|
7 |
+
# https://github.com/pytorch/fairseq
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
import math
|
11 |
+
import warnings
|
12 |
+
import torch
|
13 |
+
from torch import Tensor, nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
|
17 |
+
class GradMultiply(torch.autograd.Function):
|
18 |
+
@staticmethod
|
19 |
+
def forward(ctx, x, scale):
|
20 |
+
ctx.scale = scale
|
21 |
+
res = x.new(x)
|
22 |
+
return res
|
23 |
+
|
24 |
+
@staticmethod
|
25 |
+
def backward(ctx, grad):
|
26 |
+
return grad * ctx.scale, None
|
27 |
+
|
28 |
+
|
29 |
+
class SamePad(nn.Module):
|
30 |
+
def __init__(self, kernel_size, causal=False):
|
31 |
+
super().__init__()
|
32 |
+
if causal:
|
33 |
+
self.remove = kernel_size - 1
|
34 |
+
else:
|
35 |
+
self.remove = 1 if kernel_size % 2 == 0 else 0
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
if self.remove > 0:
|
39 |
+
x = x[:, :, : -self.remove]
|
40 |
+
return x
|
41 |
+
|
42 |
+
|
43 |
+
class Swish(nn.Module):
|
44 |
+
def __init__(self):
|
45 |
+
super(Swish, self).__init__()
|
46 |
+
self.act = torch.nn.Sigmoid()
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
return x * self.act(x)
|
50 |
+
|
51 |
+
|
52 |
+
class GLU_Linear(nn.Module):
|
53 |
+
def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
|
54 |
+
super(GLU_Linear, self).__init__()
|
55 |
+
|
56 |
+
self.glu_type = glu_type
|
57 |
+
self.output_dim = output_dim
|
58 |
+
|
59 |
+
if glu_type == "sigmoid":
|
60 |
+
self.glu_act = torch.nn.Sigmoid()
|
61 |
+
elif glu_type == "swish":
|
62 |
+
self.glu_act = Swish()
|
63 |
+
elif glu_type == "relu":
|
64 |
+
self.glu_act = torch.nn.ReLU()
|
65 |
+
elif glu_type == "gelu":
|
66 |
+
self.glu_act = torch.nn.GELU()
|
67 |
+
|
68 |
+
if bias_in_glu:
|
69 |
+
self.linear = nn.Linear(input_dim, output_dim * 2, True)
|
70 |
+
else:
|
71 |
+
self.linear = nn.Linear(input_dim, output_dim * 2, False)
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
# to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
|
75 |
+
x = self.linear(x)
|
76 |
+
|
77 |
+
if self.glu_type == "bilinear":
|
78 |
+
x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
|
79 |
+
else:
|
80 |
+
x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
|
81 |
+
|
82 |
+
return x
|
83 |
+
|
84 |
+
|
85 |
+
def gelu_accurate(x):
|
86 |
+
if not hasattr(gelu_accurate, "_a"):
|
87 |
+
gelu_accurate._a = math.sqrt(2 / math.pi)
|
88 |
+
return (
|
89 |
+
0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
|
90 |
+
)
|
91 |
+
|
92 |
+
|
93 |
+
def gelu(x: torch.Tensor) -> torch.Tensor:
|
94 |
+
return torch.nn.functional.gelu(x.float()).type_as(x)
|
95 |
+
|
96 |
+
|
97 |
+
def get_activation_fn(activation: str):
|
98 |
+
"""Returns the activation function corresponding to `activation`"""
|
99 |
+
|
100 |
+
if activation == "relu":
|
101 |
+
return F.relu
|
102 |
+
elif activation == "gelu":
|
103 |
+
return gelu
|
104 |
+
elif activation == "gelu_fast":
|
105 |
+
warnings.warn(
|
106 |
+
"--activation-fn=gelu_fast has been renamed to gelu_accurate"
|
107 |
+
)
|
108 |
+
return gelu_accurate
|
109 |
+
elif activation == "gelu_accurate":
|
110 |
+
return gelu_accurate
|
111 |
+
elif activation == "tanh":
|
112 |
+
return torch.tanh
|
113 |
+
elif activation == "linear":
|
114 |
+
return lambda x: x
|
115 |
+
elif activation == "glu":
|
116 |
+
return lambda x: x
|
117 |
+
else:
|
118 |
+
raise RuntimeError("--activation-fn {} not supported".format(activation))
|
119 |
+
|
120 |
+
|
121 |
+
def quant_noise(module, p, block_size):
|
122 |
+
"""
|
123 |
+
Wraps modules and applies quantization noise to the weights for
|
124 |
+
subsequent quantization with Iterative Product Quantization as
|
125 |
+
described in "Training with Quantization Noise for Extreme Model Compression"
|
126 |
+
|
127 |
+
Args:
|
128 |
+
- module: nn.Module
|
129 |
+
- p: amount of Quantization Noise
|
130 |
+
- block_size: size of the blocks for subsequent quantization with iPQ
|
131 |
+
|
132 |
+
Remarks:
|
133 |
+
- Module weights must have the right sizes wrt the block size
|
134 |
+
- Only Linear, Embedding and Conv2d modules are supported for the moment
|
135 |
+
- For more detail on how to quantize by blocks with convolutional weights,
|
136 |
+
see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
|
137 |
+
- We implement the simplest form of noise here as stated in the paper
|
138 |
+
which consists in randomly dropping blocks
|
139 |
+
"""
|
140 |
+
|
141 |
+
# if no quantization noise, don't register hook
|
142 |
+
if p <= 0:
|
143 |
+
return module
|
144 |
+
|
145 |
+
# supported modules
|
146 |
+
assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
|
147 |
+
|
148 |
+
# test whether module.weight has the right sizes wrt block_size
|
149 |
+
is_conv = module.weight.ndim == 4
|
150 |
+
|
151 |
+
# 2D matrix
|
152 |
+
if not is_conv:
|
153 |
+
assert (
|
154 |
+
module.weight.size(1) % block_size == 0
|
155 |
+
), "Input features must be a multiple of block sizes"
|
156 |
+
|
157 |
+
# 4D matrix
|
158 |
+
else:
|
159 |
+
# 1x1 convolutions
|
160 |
+
if module.kernel_size == (1, 1):
|
161 |
+
assert (
|
162 |
+
module.in_channels % block_size == 0
|
163 |
+
), "Input channels must be a multiple of block sizes"
|
164 |
+
# regular convolutions
|
165 |
+
else:
|
166 |
+
k = module.kernel_size[0] * module.kernel_size[1]
|
167 |
+
assert k % block_size == 0, "Kernel size must be a multiple of block size"
|
168 |
+
|
169 |
+
def _forward_pre_hook(mod, input):
|
170 |
+
# no noise for evaluation
|
171 |
+
if mod.training:
|
172 |
+
if not is_conv:
|
173 |
+
# gather weight and sizes
|
174 |
+
weight = mod.weight
|
175 |
+
in_features = weight.size(1)
|
176 |
+
out_features = weight.size(0)
|
177 |
+
|
178 |
+
# split weight matrix into blocks and randomly drop selected blocks
|
179 |
+
mask = torch.zeros(
|
180 |
+
in_features // block_size * out_features, device=weight.device
|
181 |
+
)
|
182 |
+
mask.bernoulli_(p)
|
183 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
|
184 |
+
|
185 |
+
else:
|
186 |
+
# gather weight and sizes
|
187 |
+
weight = mod.weight
|
188 |
+
in_channels = mod.in_channels
|
189 |
+
out_channels = mod.out_channels
|
190 |
+
|
191 |
+
# split weight matrix into blocks and randomly drop selected blocks
|
192 |
+
if mod.kernel_size == (1, 1):
|
193 |
+
mask = torch.zeros(
|
194 |
+
int(in_channels // block_size * out_channels),
|
195 |
+
device=weight.device,
|
196 |
+
)
|
197 |
+
mask.bernoulli_(p)
|
198 |
+
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
|
199 |
+
else:
|
200 |
+
mask = torch.zeros(
|
201 |
+
weight.size(0), weight.size(1), device=weight.device
|
202 |
+
)
|
203 |
+
mask.bernoulli_(p)
|
204 |
+
mask = (
|
205 |
+
mask.unsqueeze(2)
|
206 |
+
.unsqueeze(3)
|
207 |
+
.repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
|
208 |
+
)
|
209 |
+
|
210 |
+
# scale weights and apply mask
|
211 |
+
mask = mask.to(
|
212 |
+
torch.bool
|
213 |
+
) # x.bool() is not currently supported in TorchScript
|
214 |
+
s = 1 / (1 - p)
|
215 |
+
mod.weight.data = s * weight.masked_fill(mask, 0)
|
216 |
+
|
217 |
+
module.register_forward_pre_hook(_forward_pre_hook)
|
218 |
+
return module
|
models/beats/quantizer.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# BEATs: Audio Pre-Training with Acoustic Tokenizers (https://arxiv.org/abs/2212.09058)
|
3 |
+
# Github source: https://github.com/microsoft/unilm/tree/master/beats
|
4 |
+
# Copyright (c) 2022 Microsoft
|
5 |
+
# Licensed under The MIT License [see LICENSE for details]
|
6 |
+
# Based on VQGAN code bases
|
7 |
+
# https://github.com/CompVis/taming-transformers
|
8 |
+
# --------------------------------------------------------'
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import torch.distributed as distributed
|
14 |
+
|
15 |
+
try:
|
16 |
+
from einops import rearrange, repeat
|
17 |
+
except ImportError:
|
18 |
+
pass
|
19 |
+
|
20 |
+
|
21 |
+
def l2norm(t):
|
22 |
+
return F.normalize(t, p=2, dim=-1)
|
23 |
+
|
24 |
+
|
25 |
+
def ema_inplace(moving_avg, new, decay):
|
26 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
27 |
+
|
28 |
+
|
29 |
+
def sample_vectors(samples, num):
|
30 |
+
num_samples, device = samples.shape[0], samples.device
|
31 |
+
|
32 |
+
if num_samples >= num:
|
33 |
+
indices = torch.randperm(num_samples, device=device)[:num]
|
34 |
+
else:
|
35 |
+
indices = torch.randint(0, num_samples, (num,), device=device)
|
36 |
+
|
37 |
+
return samples[indices]
|
38 |
+
|
39 |
+
|
40 |
+
def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
|
41 |
+
dim, dtype, device = samples.shape[-1], samples.dtype, samples.device
|
42 |
+
|
43 |
+
means = sample_vectors(samples, num_clusters)
|
44 |
+
|
45 |
+
for _ in range(num_iters):
|
46 |
+
if use_cosine_sim:
|
47 |
+
dists = samples @ means.t()
|
48 |
+
else:
|
49 |
+
diffs = rearrange(samples, 'n d -> n () d') \
|
50 |
+
- rearrange(means, 'c d -> () c d')
|
51 |
+
dists = -(diffs ** 2).sum(dim=-1)
|
52 |
+
|
53 |
+
buckets = dists.max(dim=-1).indices
|
54 |
+
bins = torch.bincount(buckets, minlength=num_clusters)
|
55 |
+
zero_mask = bins == 0
|
56 |
+
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
57 |
+
|
58 |
+
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
59 |
+
new_means.scatter_add_(0, repeat(buckets, 'n -> n d', d=dim), samples)
|
60 |
+
new_means = new_means / bins_min_clamped[..., None]
|
61 |
+
|
62 |
+
if use_cosine_sim:
|
63 |
+
new_means = l2norm(new_means)
|
64 |
+
|
65 |
+
means = torch.where(zero_mask[..., None], means, new_means)
|
66 |
+
|
67 |
+
return means, bins
|
68 |
+
|
69 |
+
|
70 |
+
class EmbeddingEMA(nn.Module):
|
71 |
+
def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5, kmeans_init=True, codebook_init_path=''):
|
72 |
+
super().__init__()
|
73 |
+
self.num_tokens = num_tokens
|
74 |
+
self.codebook_dim = codebook_dim
|
75 |
+
self.decay = decay
|
76 |
+
self.eps = eps
|
77 |
+
if codebook_init_path == '':
|
78 |
+
if not kmeans_init:
|
79 |
+
weight = torch.randn(num_tokens, codebook_dim)
|
80 |
+
weight = l2norm(weight)
|
81 |
+
else:
|
82 |
+
weight = torch.zeros(num_tokens, codebook_dim)
|
83 |
+
self.register_buffer('initted', torch.Tensor([not kmeans_init]))
|
84 |
+
else:
|
85 |
+
print(f"load init codebook weight from {codebook_init_path}")
|
86 |
+
codebook_ckpt_weight = torch.load(codebook_init_path, map_location='cpu')
|
87 |
+
weight = codebook_ckpt_weight.clone()
|
88 |
+
self.register_buffer('initted', torch.Tensor([True]))
|
89 |
+
|
90 |
+
self.weight = nn.Parameter(weight, requires_grad=False)
|
91 |
+
self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
|
92 |
+
self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
|
93 |
+
# self.register_buffer('initted', torch.Tensor([not kmeans_init]))
|
94 |
+
self.update = True
|
95 |
+
|
96 |
+
@torch.jit.ignore
|
97 |
+
def init_embed_(self, data):
|
98 |
+
if self.initted:
|
99 |
+
return
|
100 |
+
print("Performing Kemans init for codebook")
|
101 |
+
embed, cluster_size = kmeans(data, self.num_tokens, 10, use_cosine_sim=True)
|
102 |
+
self.weight.data.copy_(embed)
|
103 |
+
self.cluster_size.data.copy_(cluster_size)
|
104 |
+
self.initted.data.copy_(torch.Tensor([True]))
|
105 |
+
|
106 |
+
def forward(self, embed_id):
|
107 |
+
return F.embedding(embed_id, self.weight)
|
108 |
+
|
109 |
+
def cluster_size_ema_update(self, new_cluster_size):
|
110 |
+
self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
|
111 |
+
|
112 |
+
def embed_avg_ema_update(self, new_embed_avg):
|
113 |
+
self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
|
114 |
+
|
115 |
+
def weight_update(self, num_tokens):
|
116 |
+
n = self.cluster_size.sum()
|
117 |
+
smoothed_cluster_size = (
|
118 |
+
(self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
|
119 |
+
)
|
120 |
+
# normalize embedding average with smoothed cluster size
|
121 |
+
embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
|
122 |
+
# embed_normalized = l2norm(self.embed_avg / smoothed_cluster_size.unsqueeze(1))
|
123 |
+
self.weight.data.copy_(embed_normalized)
|
124 |
+
|
125 |
+
|
126 |
+
def norm_ema_inplace(moving_avg, new, decay):
|
127 |
+
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
128 |
+
moving_avg.data.copy_(l2norm(moving_avg.data))
|
129 |
+
|
130 |
+
|
131 |
+
class NormEMAVectorQuantizer(nn.Module):
|
132 |
+
def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
|
133 |
+
statistic_code_usage=True, kmeans_init=False, codebook_init_path=''):
|
134 |
+
super().__init__()
|
135 |
+
self.codebook_dim = embedding_dim
|
136 |
+
self.num_tokens = n_embed
|
137 |
+
self.beta = beta
|
138 |
+
self.decay = decay
|
139 |
+
|
140 |
+
# learnable = True if orthogonal_reg_weight > 0 else False
|
141 |
+
self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps, kmeans_init, codebook_init_path)
|
142 |
+
|
143 |
+
self.statistic_code_usage = statistic_code_usage
|
144 |
+
if statistic_code_usage:
|
145 |
+
self.register_buffer('cluster_size', torch.zeros(n_embed))
|
146 |
+
if distributed.is_available() and distributed.is_initialized():
|
147 |
+
print("ddp is enable, so use ddp_reduce to sync the statistic_code_usage for each gpu!")
|
148 |
+
self.all_reduce_fn = distributed.all_reduce
|
149 |
+
else:
|
150 |
+
self.all_reduce_fn = nn.Identity()
|
151 |
+
|
152 |
+
def reset_cluster_size(self, device):
|
153 |
+
if self.statistic_code_usage:
|
154 |
+
self.register_buffer('cluster_size', torch.zeros(self.num_tokens))
|
155 |
+
self.cluster_size = self.cluster_size.to(device)
|
156 |
+
|
157 |
+
def forward(self, z):
|
158 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
159 |
+
# z, 'b c h w -> b h w c'
|
160 |
+
# z = rearrange(z, 'b c h w -> b h w c')
|
161 |
+
# z = z.transpose(1, 2)
|
162 |
+
z = l2norm(z)
|
163 |
+
z_flattened = z.reshape(-1, self.codebook_dim)
|
164 |
+
|
165 |
+
self.embedding.init_embed_(z_flattened)
|
166 |
+
|
167 |
+
d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
|
168 |
+
self.embedding.weight.pow(2).sum(dim=1) - 2 * \
|
169 |
+
torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
|
170 |
+
|
171 |
+
encoding_indices = torch.argmin(d, dim=1)
|
172 |
+
|
173 |
+
z_q = self.embedding(encoding_indices).view(z.shape)
|
174 |
+
|
175 |
+
encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
|
176 |
+
|
177 |
+
if not self.training:
|
178 |
+
with torch.no_grad():
|
179 |
+
cluster_size = encodings.sum(0)
|
180 |
+
self.all_reduce_fn(cluster_size)
|
181 |
+
ema_inplace(self.cluster_size, cluster_size, self.decay)
|
182 |
+
|
183 |
+
if self.training and self.embedding.update:
|
184 |
+
# EMA cluster size
|
185 |
+
|
186 |
+
bins = encodings.sum(0)
|
187 |
+
self.all_reduce_fn(bins)
|
188 |
+
|
189 |
+
# self.embedding.cluster_size_ema_update(bins)
|
190 |
+
ema_inplace(self.cluster_size, bins, self.decay)
|
191 |
+
|
192 |
+
zero_mask = (bins == 0)
|
193 |
+
bins = bins.masked_fill(zero_mask, 1.)
|
194 |
+
|
195 |
+
embed_sum = z_flattened.t() @ encodings
|
196 |
+
self.all_reduce_fn(embed_sum)
|
197 |
+
|
198 |
+
embed_normalized = (embed_sum / bins.unsqueeze(0)).t()
|
199 |
+
embed_normalized = l2norm(embed_normalized)
|
200 |
+
|
201 |
+
embed_normalized = torch.where(zero_mask[..., None], self.embedding.weight,
|
202 |
+
embed_normalized)
|
203 |
+
norm_ema_inplace(self.embedding.weight, embed_normalized, self.decay)
|
204 |
+
|
205 |
+
# compute loss for embedding
|
206 |
+
loss = self.beta * F.mse_loss(z_q.detach(), z)
|
207 |
+
|
208 |
+
# preserve gradients
|
209 |
+
z_q = z + (z_q - z).detach()
|
210 |
+
|
211 |
+
# reshape back to match original input shape
|
212 |
+
# z_q, 'b h w c -> b c h w'
|
213 |
+
# z_q = rearrange(z_q, 'b h w c -> b c h w')
|
214 |
+
# z_q = z_q.transpose(1, 2)
|
215 |
+
return z_q, loss, encoding_indices
|
models/qformer/LICENSE_Lavis
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BSD 3-Clause License
|
2 |
+
|
3 |
+
Copyright (c) 2022 Salesforce, Inc.
|
4 |
+
All rights reserved.
|
5 |
+
|
6 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
7 |
+
|
8 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
9 |
+
|
10 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
11 |
+
|
12 |
+
3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
13 |
+
|
14 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
models/qformer/LICENSE_MiniGPT4
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BSD 3-Clause License
|
2 |
+
|
3 |
+
Copyright 2023 Deyao Zhu
|
4 |
+
All rights reserved.
|
5 |
+
|
6 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
7 |
+
|
8 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
9 |
+
|
10 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
11 |
+
|
12 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
13 |
+
|
14 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
models/qformer/LICENSE_VideoLlama
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
BSD 3-Clause License
|
2 |
+
|
3 |
+
Copyright (c) 2023, Multilingual NLP Team at Alibaba DAMO Academy
|
4 |
+
|
5 |
+
Redistribution and use in source and binary forms, with or without
|
6 |
+
modification, are permitted provided that the following conditions are met:
|
7 |
+
|
8 |
+
1. Redistributions of source code must retain the above copyright notice, this
|
9 |
+
list of conditions and the following disclaimer.
|
10 |
+
|
11 |
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
12 |
+
this list of conditions and the following disclaimer in the documentation
|
13 |
+
and/or other materials provided with the distribution.
|
14 |
+
|
15 |
+
3. Neither the name of the copyright holder nor the names of its
|
16 |
+
contributors may be used to endorse or promote products derived from
|
17 |
+
this software without specific prior written permission.
|
18 |
+
|
19 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
20 |
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
21 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
22 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
23 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
24 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
25 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
26 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
27 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
28 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
models/qformer/Qformer.py
ADDED
@@ -0,0 +1,1217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Adapted from salesforce@LAVIS. Below is the original copyright:
|
3 |
+
* Copyright (c) 2023, salesforce.com, inc.
|
4 |
+
* All rights reserved.
|
5 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
6 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
7 |
+
* By Junnan Li
|
8 |
+
* Based on huggingface code base
|
9 |
+
* https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
|
10 |
+
"""
|
11 |
+
|
12 |
+
import math
|
13 |
+
import os
|
14 |
+
import warnings
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import Optional, Tuple, Dict, Any
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from torch import Tensor, device, dtype, nn
|
20 |
+
import torch.utils.checkpoint
|
21 |
+
from torch import nn
|
22 |
+
from torch.nn import CrossEntropyLoss
|
23 |
+
import torch.nn.functional as F
|
24 |
+
|
25 |
+
from transformers.activations import ACT2FN
|
26 |
+
from transformers.file_utils import (
|
27 |
+
ModelOutput,
|
28 |
+
)
|
29 |
+
from transformers.modeling_outputs import (
|
30 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
31 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
32 |
+
CausalLMOutputWithCrossAttentions,
|
33 |
+
MaskedLMOutput,
|
34 |
+
MultipleChoiceModelOutput,
|
35 |
+
NextSentencePredictorOutput,
|
36 |
+
QuestionAnsweringModelOutput,
|
37 |
+
SequenceClassifierOutput,
|
38 |
+
TokenClassifierOutput,
|
39 |
+
)
|
40 |
+
from transformers.modeling_utils import (
|
41 |
+
PreTrainedModel,
|
42 |
+
apply_chunking_to_forward,
|
43 |
+
find_pruneable_heads_and_indices,
|
44 |
+
prune_linear_layer,
|
45 |
+
)
|
46 |
+
from transformers.utils import logging
|
47 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
48 |
+
|
49 |
+
logger = logging.get_logger(__name__)
|
50 |
+
|
51 |
+
|
52 |
+
class BertEmbeddings(nn.Module):
|
53 |
+
"""Construct the embeddings from word and position embeddings."""
|
54 |
+
|
55 |
+
def __init__(self, config):
|
56 |
+
super().__init__()
|
57 |
+
self.word_embeddings = nn.Embedding(
|
58 |
+
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
|
59 |
+
)
|
60 |
+
self.position_embeddings = nn.Embedding(
|
61 |
+
config.max_position_embeddings, config.hidden_size
|
62 |
+
)
|
63 |
+
|
64 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
65 |
+
# any TensorFlow checkpoint file
|
66 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
67 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
68 |
+
|
69 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
70 |
+
self.register_buffer(
|
71 |
+
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
|
72 |
+
)
|
73 |
+
self.position_embedding_type = getattr(
|
74 |
+
config, "position_embedding_type", "absolute"
|
75 |
+
)
|
76 |
+
|
77 |
+
self.config = config
|
78 |
+
|
79 |
+
def forward(
|
80 |
+
self,
|
81 |
+
input_ids=None,
|
82 |
+
position_ids=None,
|
83 |
+
query_embeds=None,
|
84 |
+
past_key_values_length=0,
|
85 |
+
):
|
86 |
+
if input_ids is not None:
|
87 |
+
seq_length = input_ids.size()[1]
|
88 |
+
else:
|
89 |
+
seq_length = 0
|
90 |
+
|
91 |
+
if position_ids is None:
|
92 |
+
position_ids = self.position_ids[
|
93 |
+
:, past_key_values_length : seq_length + past_key_values_length
|
94 |
+
].clone()
|
95 |
+
|
96 |
+
if input_ids is not None:
|
97 |
+
embeddings = self.word_embeddings(input_ids)
|
98 |
+
if self.position_embedding_type == "absolute":
|
99 |
+
position_embeddings = self.position_embeddings(position_ids)
|
100 |
+
embeddings = embeddings + position_embeddings
|
101 |
+
|
102 |
+
if query_embeds is not None:
|
103 |
+
embeddings = torch.cat((query_embeds, embeddings), dim=1)
|
104 |
+
else:
|
105 |
+
embeddings = query_embeds
|
106 |
+
|
107 |
+
embeddings = self.LayerNorm(embeddings)
|
108 |
+
embeddings = self.dropout(embeddings)
|
109 |
+
return embeddings
|
110 |
+
|
111 |
+
|
112 |
+
class BertSelfAttention(nn.Module):
|
113 |
+
def __init__(self, config, is_cross_attention):
|
114 |
+
super().__init__()
|
115 |
+
self.config = config
|
116 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
|
117 |
+
config, "embedding_size"
|
118 |
+
):
|
119 |
+
raise ValueError(
|
120 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
121 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
122 |
+
)
|
123 |
+
|
124 |
+
self.num_attention_heads = config.num_attention_heads
|
125 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
126 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
127 |
+
|
128 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
129 |
+
if is_cross_attention:
|
130 |
+
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
131 |
+
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
132 |
+
else:
|
133 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
134 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
135 |
+
|
136 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
137 |
+
self.position_embedding_type = getattr(
|
138 |
+
config, "position_embedding_type", "absolute"
|
139 |
+
)
|
140 |
+
if (
|
141 |
+
self.position_embedding_type == "relative_key"
|
142 |
+
or self.position_embedding_type == "relative_key_query"
|
143 |
+
):
|
144 |
+
self.max_position_embeddings = config.max_position_embeddings
|
145 |
+
self.distance_embedding = nn.Embedding(
|
146 |
+
2 * config.max_position_embeddings - 1, self.attention_head_size
|
147 |
+
)
|
148 |
+
self.save_attention = False
|
149 |
+
|
150 |
+
def save_attn_gradients(self, attn_gradients):
|
151 |
+
self.attn_gradients = attn_gradients
|
152 |
+
|
153 |
+
def get_attn_gradients(self):
|
154 |
+
return self.attn_gradients
|
155 |
+
|
156 |
+
def save_attention_map(self, attention_map):
|
157 |
+
self.attention_map = attention_map
|
158 |
+
|
159 |
+
def get_attention_map(self):
|
160 |
+
return self.attention_map
|
161 |
+
|
162 |
+
def transpose_for_scores(self, x):
|
163 |
+
new_x_shape = x.size()[:-1] + (
|
164 |
+
self.num_attention_heads,
|
165 |
+
self.attention_head_size,
|
166 |
+
)
|
167 |
+
x = x.view(*new_x_shape)
|
168 |
+
return x.permute(0, 2, 1, 3)
|
169 |
+
|
170 |
+
def forward(
|
171 |
+
self,
|
172 |
+
hidden_states,
|
173 |
+
attention_mask=None,
|
174 |
+
head_mask=None,
|
175 |
+
encoder_hidden_states=None,
|
176 |
+
encoder_attention_mask=None,
|
177 |
+
past_key_value=None,
|
178 |
+
output_attentions=False,
|
179 |
+
):
|
180 |
+
|
181 |
+
# If this is instantiated as a cross-attention module, the keys
|
182 |
+
# and values come from an encoder; the attention mask needs to be
|
183 |
+
# such that the encoder's padding tokens are not attended to.
|
184 |
+
is_cross_attention = encoder_hidden_states is not None
|
185 |
+
|
186 |
+
if is_cross_attention:
|
187 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
188 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
189 |
+
attention_mask = encoder_attention_mask
|
190 |
+
elif past_key_value is not None:
|
191 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
192 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
193 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
194 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
195 |
+
else:
|
196 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
197 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
198 |
+
|
199 |
+
mixed_query_layer = self.query(hidden_states)
|
200 |
+
|
201 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
202 |
+
|
203 |
+
past_key_value = (key_layer, value_layer)
|
204 |
+
|
205 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
206 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
207 |
+
|
208 |
+
if (
|
209 |
+
self.position_embedding_type == "relative_key"
|
210 |
+
or self.position_embedding_type == "relative_key_query"
|
211 |
+
):
|
212 |
+
seq_length = hidden_states.size()[1]
|
213 |
+
position_ids_l = torch.arange(
|
214 |
+
seq_length, dtype=torch.long, device=hidden_states.device
|
215 |
+
).view(-1, 1)
|
216 |
+
position_ids_r = torch.arange(
|
217 |
+
seq_length, dtype=torch.long, device=hidden_states.device
|
218 |
+
).view(1, -1)
|
219 |
+
distance = position_ids_l - position_ids_r
|
220 |
+
positional_embedding = self.distance_embedding(
|
221 |
+
distance + self.max_position_embeddings - 1
|
222 |
+
)
|
223 |
+
positional_embedding = positional_embedding.to(
|
224 |
+
dtype=query_layer.dtype
|
225 |
+
) # fp16 compatibility
|
226 |
+
|
227 |
+
if self.position_embedding_type == "relative_key":
|
228 |
+
relative_position_scores = torch.einsum(
|
229 |
+
"bhld,lrd->bhlr", query_layer, positional_embedding
|
230 |
+
)
|
231 |
+
attention_scores = attention_scores + relative_position_scores
|
232 |
+
elif self.position_embedding_type == "relative_key_query":
|
233 |
+
relative_position_scores_query = torch.einsum(
|
234 |
+
"bhld,lrd->bhlr", query_layer, positional_embedding
|
235 |
+
)
|
236 |
+
relative_position_scores_key = torch.einsum(
|
237 |
+
"bhrd,lrd->bhlr", key_layer, positional_embedding
|
238 |
+
)
|
239 |
+
attention_scores = (
|
240 |
+
attention_scores
|
241 |
+
+ relative_position_scores_query
|
242 |
+
+ relative_position_scores_key
|
243 |
+
)
|
244 |
+
|
245 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
246 |
+
if attention_mask is not None:
|
247 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
248 |
+
attention_scores = attention_scores + attention_mask
|
249 |
+
|
250 |
+
# Normalize the attention scores to probabilities.
|
251 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
252 |
+
|
253 |
+
if is_cross_attention and self.save_attention:
|
254 |
+
self.save_attention_map(attention_probs)
|
255 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
256 |
+
|
257 |
+
# This is actually dropping out entire tokens to attend to, which might
|
258 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
259 |
+
attention_probs_dropped = self.dropout(attention_probs)
|
260 |
+
|
261 |
+
# Mask heads if we want to
|
262 |
+
if head_mask is not None:
|
263 |
+
attention_probs_dropped = attention_probs_dropped * head_mask
|
264 |
+
|
265 |
+
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
266 |
+
|
267 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
268 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
269 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
270 |
+
|
271 |
+
outputs = (
|
272 |
+
(context_layer, attention_probs) if output_attentions else (context_layer,)
|
273 |
+
)
|
274 |
+
|
275 |
+
outputs = outputs + (past_key_value,)
|
276 |
+
return outputs
|
277 |
+
|
278 |
+
|
279 |
+
class BertSelfOutput(nn.Module):
|
280 |
+
def __init__(self, config):
|
281 |
+
super().__init__()
|
282 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
283 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
284 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
285 |
+
|
286 |
+
def forward(self, hidden_states, input_tensor):
|
287 |
+
hidden_states = self.dense(hidden_states)
|
288 |
+
hidden_states = self.dropout(hidden_states)
|
289 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
290 |
+
return hidden_states
|
291 |
+
|
292 |
+
|
293 |
+
class BertAttention(nn.Module):
|
294 |
+
def __init__(self, config, is_cross_attention=False):
|
295 |
+
super().__init__()
|
296 |
+
self.self = BertSelfAttention(config, is_cross_attention)
|
297 |
+
self.output = BertSelfOutput(config)
|
298 |
+
self.pruned_heads = set()
|
299 |
+
|
300 |
+
def prune_heads(self, heads):
|
301 |
+
if len(heads) == 0:
|
302 |
+
return
|
303 |
+
heads, index = find_pruneable_heads_and_indices(
|
304 |
+
heads,
|
305 |
+
self.self.num_attention_heads,
|
306 |
+
self.self.attention_head_size,
|
307 |
+
self.pruned_heads,
|
308 |
+
)
|
309 |
+
|
310 |
+
# Prune linear layers
|
311 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
312 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
313 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
314 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
315 |
+
|
316 |
+
# Update hyper params and store pruned heads
|
317 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
318 |
+
self.self.all_head_size = (
|
319 |
+
self.self.attention_head_size * self.self.num_attention_heads
|
320 |
+
)
|
321 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
322 |
+
|
323 |
+
def forward(
|
324 |
+
self,
|
325 |
+
hidden_states,
|
326 |
+
attention_mask=None,
|
327 |
+
head_mask=None,
|
328 |
+
encoder_hidden_states=None,
|
329 |
+
encoder_attention_mask=None,
|
330 |
+
past_key_value=None,
|
331 |
+
output_attentions=False,
|
332 |
+
):
|
333 |
+
self_outputs = self.self(
|
334 |
+
hidden_states,
|
335 |
+
attention_mask,
|
336 |
+
head_mask,
|
337 |
+
encoder_hidden_states,
|
338 |
+
encoder_attention_mask,
|
339 |
+
past_key_value,
|
340 |
+
output_attentions,
|
341 |
+
)
|
342 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
343 |
+
|
344 |
+
outputs = (attention_output,) + self_outputs[
|
345 |
+
1:
|
346 |
+
] # add attentions if we output them
|
347 |
+
return outputs
|
348 |
+
|
349 |
+
|
350 |
+
class BertIntermediate(nn.Module):
|
351 |
+
def __init__(self, config):
|
352 |
+
super().__init__()
|
353 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
354 |
+
if isinstance(config.hidden_act, str):
|
355 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
356 |
+
else:
|
357 |
+
self.intermediate_act_fn = config.hidden_act
|
358 |
+
|
359 |
+
def forward(self, hidden_states):
|
360 |
+
hidden_states = self.dense(hidden_states)
|
361 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
362 |
+
return hidden_states
|
363 |
+
|
364 |
+
|
365 |
+
class BertOutput(nn.Module):
|
366 |
+
def __init__(self, config):
|
367 |
+
super().__init__()
|
368 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
369 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
370 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
371 |
+
|
372 |
+
def forward(self, hidden_states, input_tensor):
|
373 |
+
hidden_states = self.dense(hidden_states)
|
374 |
+
hidden_states = self.dropout(hidden_states)
|
375 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
376 |
+
return hidden_states
|
377 |
+
|
378 |
+
|
379 |
+
class BertLayer(nn.Module):
|
380 |
+
def __init__(self, config, layer_num):
|
381 |
+
super().__init__()
|
382 |
+
self.config = config
|
383 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
384 |
+
self.seq_len_dim = 1
|
385 |
+
self.attention = BertAttention(config)
|
386 |
+
self.layer_num = layer_num
|
387 |
+
if (
|
388 |
+
self.config.add_cross_attention
|
389 |
+
and layer_num % self.config.cross_attention_freq == 0
|
390 |
+
):
|
391 |
+
self.crossattention = BertAttention(
|
392 |
+
config, is_cross_attention=self.config.add_cross_attention
|
393 |
+
)
|
394 |
+
self.has_cross_attention = True
|
395 |
+
else:
|
396 |
+
self.has_cross_attention = False
|
397 |
+
self.intermediate = BertIntermediate(config)
|
398 |
+
self.output = BertOutput(config)
|
399 |
+
|
400 |
+
self.intermediate_query = BertIntermediate(config)
|
401 |
+
self.output_query = BertOutput(config)
|
402 |
+
|
403 |
+
def forward(
|
404 |
+
self,
|
405 |
+
hidden_states,
|
406 |
+
attention_mask=None,
|
407 |
+
head_mask=None,
|
408 |
+
encoder_hidden_states=None,
|
409 |
+
encoder_attention_mask=None,
|
410 |
+
past_key_value=None,
|
411 |
+
output_attentions=False,
|
412 |
+
query_length=0,
|
413 |
+
):
|
414 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
415 |
+
self_attn_past_key_value = (
|
416 |
+
past_key_value[:2] if past_key_value is not None else None
|
417 |
+
)
|
418 |
+
self_attention_outputs = self.attention(
|
419 |
+
hidden_states,
|
420 |
+
attention_mask,
|
421 |
+
head_mask,
|
422 |
+
output_attentions=output_attentions,
|
423 |
+
past_key_value=self_attn_past_key_value,
|
424 |
+
)
|
425 |
+
attention_output = self_attention_outputs[0]
|
426 |
+
outputs = self_attention_outputs[1:-1]
|
427 |
+
|
428 |
+
present_key_value = self_attention_outputs[-1]
|
429 |
+
|
430 |
+
if query_length > 0:
|
431 |
+
query_attention_output = attention_output[:, :query_length, :]
|
432 |
+
|
433 |
+
if self.has_cross_attention:
|
434 |
+
assert (
|
435 |
+
encoder_hidden_states is not None
|
436 |
+
), "encoder_hidden_states must be given for cross-attention layers"
|
437 |
+
cross_attention_outputs = self.crossattention(
|
438 |
+
query_attention_output,
|
439 |
+
attention_mask,
|
440 |
+
head_mask,
|
441 |
+
encoder_hidden_states,
|
442 |
+
encoder_attention_mask,
|
443 |
+
output_attentions=output_attentions,
|
444 |
+
)
|
445 |
+
query_attention_output = cross_attention_outputs[0]
|
446 |
+
outputs = (
|
447 |
+
outputs + cross_attention_outputs[1:-1]
|
448 |
+
) # add cross attentions if we output attention weights
|
449 |
+
|
450 |
+
layer_output = apply_chunking_to_forward(
|
451 |
+
self.feed_forward_chunk_query,
|
452 |
+
self.chunk_size_feed_forward,
|
453 |
+
self.seq_len_dim,
|
454 |
+
query_attention_output,
|
455 |
+
)
|
456 |
+
if attention_output.shape[1] > query_length:
|
457 |
+
layer_output_text = apply_chunking_to_forward(
|
458 |
+
self.feed_forward_chunk,
|
459 |
+
self.chunk_size_feed_forward,
|
460 |
+
self.seq_len_dim,
|
461 |
+
attention_output[:, query_length:, :],
|
462 |
+
)
|
463 |
+
layer_output = torch.cat([layer_output, layer_output_text], dim=1)
|
464 |
+
else:
|
465 |
+
layer_output = apply_chunking_to_forward(
|
466 |
+
self.feed_forward_chunk,
|
467 |
+
self.chunk_size_feed_forward,
|
468 |
+
self.seq_len_dim,
|
469 |
+
attention_output,
|
470 |
+
)
|
471 |
+
outputs = (layer_output,) + outputs
|
472 |
+
|
473 |
+
outputs = outputs + (present_key_value,)
|
474 |
+
|
475 |
+
return outputs
|
476 |
+
|
477 |
+
def feed_forward_chunk(self, attention_output):
|
478 |
+
intermediate_output = self.intermediate(attention_output)
|
479 |
+
layer_output = self.output(intermediate_output, attention_output)
|
480 |
+
return layer_output
|
481 |
+
|
482 |
+
def feed_forward_chunk_query(self, attention_output):
|
483 |
+
intermediate_output = self.intermediate_query(attention_output)
|
484 |
+
layer_output = self.output_query(intermediate_output, attention_output)
|
485 |
+
return layer_output
|
486 |
+
|
487 |
+
|
488 |
+
class BertEncoder(nn.Module):
|
489 |
+
def __init__(self, config):
|
490 |
+
super().__init__()
|
491 |
+
self.config = config
|
492 |
+
self.layer = nn.ModuleList(
|
493 |
+
[BertLayer(config, i) for i in range(config.num_hidden_layers)]
|
494 |
+
)
|
495 |
+
|
496 |
+
def forward(
|
497 |
+
self,
|
498 |
+
hidden_states,
|
499 |
+
attention_mask=None,
|
500 |
+
head_mask=None,
|
501 |
+
encoder_hidden_states=None,
|
502 |
+
encoder_attention_mask=None,
|
503 |
+
past_key_values=None,
|
504 |
+
use_cache=None,
|
505 |
+
output_attentions=False,
|
506 |
+
output_hidden_states=False,
|
507 |
+
return_dict=True,
|
508 |
+
query_length=0,
|
509 |
+
):
|
510 |
+
all_hidden_states = () if output_hidden_states else None
|
511 |
+
all_self_attentions = () if output_attentions else None
|
512 |
+
all_cross_attentions = (
|
513 |
+
() if output_attentions and self.config.add_cross_attention else None
|
514 |
+
)
|
515 |
+
|
516 |
+
next_decoder_cache = () if use_cache else None
|
517 |
+
|
518 |
+
for i in range(self.config.num_hidden_layers):
|
519 |
+
layer_module = self.layer[i]
|
520 |
+
if output_hidden_states:
|
521 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
522 |
+
|
523 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
524 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
525 |
+
|
526 |
+
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
527 |
+
|
528 |
+
if use_cache:
|
529 |
+
logger.warn(
|
530 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
531 |
+
)
|
532 |
+
use_cache = False
|
533 |
+
|
534 |
+
def create_custom_forward(module):
|
535 |
+
def custom_forward(*inputs):
|
536 |
+
return module(
|
537 |
+
*inputs, past_key_value, output_attentions, query_length
|
538 |
+
)
|
539 |
+
|
540 |
+
return custom_forward
|
541 |
+
|
542 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
543 |
+
create_custom_forward(layer_module),
|
544 |
+
hidden_states,
|
545 |
+
attention_mask,
|
546 |
+
layer_head_mask,
|
547 |
+
encoder_hidden_states,
|
548 |
+
encoder_attention_mask,
|
549 |
+
)
|
550 |
+
else:
|
551 |
+
layer_outputs = layer_module(
|
552 |
+
hidden_states,
|
553 |
+
attention_mask,
|
554 |
+
layer_head_mask,
|
555 |
+
encoder_hidden_states,
|
556 |
+
encoder_attention_mask,
|
557 |
+
past_key_value,
|
558 |
+
output_attentions,
|
559 |
+
query_length,
|
560 |
+
)
|
561 |
+
|
562 |
+
hidden_states = layer_outputs[0]
|
563 |
+
if use_cache:
|
564 |
+
next_decoder_cache += (layer_outputs[-1],)
|
565 |
+
if output_attentions:
|
566 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
567 |
+
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
568 |
+
|
569 |
+
if output_hidden_states:
|
570 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
571 |
+
|
572 |
+
if not return_dict:
|
573 |
+
return tuple(
|
574 |
+
v
|
575 |
+
for v in [
|
576 |
+
hidden_states,
|
577 |
+
next_decoder_cache,
|
578 |
+
all_hidden_states,
|
579 |
+
all_self_attentions,
|
580 |
+
all_cross_attentions,
|
581 |
+
]
|
582 |
+
if v is not None
|
583 |
+
)
|
584 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
585 |
+
last_hidden_state=hidden_states,
|
586 |
+
past_key_values=next_decoder_cache,
|
587 |
+
hidden_states=all_hidden_states,
|
588 |
+
attentions=all_self_attentions,
|
589 |
+
cross_attentions=all_cross_attentions,
|
590 |
+
)
|
591 |
+
|
592 |
+
|
593 |
+
class BertPooler(nn.Module):
|
594 |
+
def __init__(self, config):
|
595 |
+
super().__init__()
|
596 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
597 |
+
self.activation = nn.Tanh()
|
598 |
+
|
599 |
+
def forward(self, hidden_states):
|
600 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
601 |
+
# to the first token.
|
602 |
+
first_token_tensor = hidden_states[:, 0]
|
603 |
+
pooled_output = self.dense(first_token_tensor)
|
604 |
+
pooled_output = self.activation(pooled_output)
|
605 |
+
return pooled_output
|
606 |
+
|
607 |
+
|
608 |
+
class BertPredictionHeadTransform(nn.Module):
|
609 |
+
def __init__(self, config):
|
610 |
+
super().__init__()
|
611 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
612 |
+
if isinstance(config.hidden_act, str):
|
613 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
614 |
+
else:
|
615 |
+
self.transform_act_fn = config.hidden_act
|
616 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
617 |
+
|
618 |
+
def forward(self, hidden_states):
|
619 |
+
hidden_states = self.dense(hidden_states)
|
620 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
621 |
+
hidden_states = self.LayerNorm(hidden_states)
|
622 |
+
return hidden_states
|
623 |
+
|
624 |
+
|
625 |
+
class BertLMPredictionHead(nn.Module):
|
626 |
+
def __init__(self, config):
|
627 |
+
super().__init__()
|
628 |
+
self.transform = BertPredictionHeadTransform(config)
|
629 |
+
|
630 |
+
# The output weights are the same as the input embeddings, but there is
|
631 |
+
# an output-only bias for each token.
|
632 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
633 |
+
|
634 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
635 |
+
|
636 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
637 |
+
self.decoder.bias = self.bias
|
638 |
+
|
639 |
+
def forward(self, hidden_states):
|
640 |
+
hidden_states = self.transform(hidden_states)
|
641 |
+
hidden_states = self.decoder(hidden_states)
|
642 |
+
return hidden_states
|
643 |
+
|
644 |
+
|
645 |
+
class BertOnlyMLMHead(nn.Module):
|
646 |
+
def __init__(self, config):
|
647 |
+
super().__init__()
|
648 |
+
self.predictions = BertLMPredictionHead(config)
|
649 |
+
|
650 |
+
def forward(self, sequence_output):
|
651 |
+
prediction_scores = self.predictions(sequence_output)
|
652 |
+
return prediction_scores
|
653 |
+
|
654 |
+
|
655 |
+
class BertPreTrainedModel(PreTrainedModel):
|
656 |
+
"""
|
657 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
658 |
+
models.
|
659 |
+
"""
|
660 |
+
|
661 |
+
config_class = BertConfig
|
662 |
+
base_model_prefix = "bert"
|
663 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
664 |
+
|
665 |
+
def _init_weights(self, module):
|
666 |
+
"""Initialize the weights"""
|
667 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
668 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
669 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
670 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
671 |
+
elif isinstance(module, nn.LayerNorm):
|
672 |
+
module.bias.data.zero_()
|
673 |
+
module.weight.data.fill_(1.0)
|
674 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
675 |
+
module.bias.data.zero_()
|
676 |
+
|
677 |
+
|
678 |
+
class BertModel(BertPreTrainedModel):
|
679 |
+
"""
|
680 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
681 |
+
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
682 |
+
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
683 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
684 |
+
argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
685 |
+
input to the forward pass.
|
686 |
+
"""
|
687 |
+
|
688 |
+
def __init__(self, config, add_pooling_layer=False):
|
689 |
+
super().__init__(config)
|
690 |
+
self.config = config
|
691 |
+
|
692 |
+
self.embeddings = BertEmbeddings(config)
|
693 |
+
|
694 |
+
self.encoder = BertEncoder(config)
|
695 |
+
|
696 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
697 |
+
|
698 |
+
self.init_weights()
|
699 |
+
|
700 |
+
def get_input_embeddings(self):
|
701 |
+
return self.embeddings.word_embeddings
|
702 |
+
|
703 |
+
def set_input_embeddings(self, value):
|
704 |
+
self.embeddings.word_embeddings = value
|
705 |
+
|
706 |
+
def _prune_heads(self, heads_to_prune):
|
707 |
+
"""
|
708 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
709 |
+
class PreTrainedModel
|
710 |
+
"""
|
711 |
+
for layer, heads in heads_to_prune.items():
|
712 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
713 |
+
|
714 |
+
def get_extended_attention_mask(
|
715 |
+
self,
|
716 |
+
attention_mask: Tensor,
|
717 |
+
input_shape: Tuple[int],
|
718 |
+
device: device,
|
719 |
+
is_decoder: bool,
|
720 |
+
has_query: bool = False,
|
721 |
+
) -> Tensor:
|
722 |
+
"""
|
723 |
+
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
724 |
+
|
725 |
+
Arguments:
|
726 |
+
attention_mask (:obj:`torch.Tensor`):
|
727 |
+
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
728 |
+
input_shape (:obj:`Tuple[int]`):
|
729 |
+
The shape of the input to the model.
|
730 |
+
device: (:obj:`torch.device`):
|
731 |
+
The device of the input to the model.
|
732 |
+
|
733 |
+
Returns:
|
734 |
+
:obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
|
735 |
+
"""
|
736 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
737 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
738 |
+
if attention_mask.dim() == 3:
|
739 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
740 |
+
elif attention_mask.dim() == 2:
|
741 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
742 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
743 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
744 |
+
if is_decoder:
|
745 |
+
batch_size, seq_length = input_shape
|
746 |
+
|
747 |
+
seq_ids = torch.arange(seq_length, device=device)
|
748 |
+
causal_mask = (
|
749 |
+
seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
|
750 |
+
<= seq_ids[None, :, None]
|
751 |
+
)
|
752 |
+
|
753 |
+
# add a prefix ones mask to the causal mask
|
754 |
+
# causal and attention masks must have same type with pytorch version < 1.3
|
755 |
+
causal_mask = causal_mask.to(attention_mask.dtype)
|
756 |
+
|
757 |
+
if causal_mask.shape[1] < attention_mask.shape[1]:
|
758 |
+
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
|
759 |
+
if has_query: # UniLM style attention mask
|
760 |
+
causal_mask = torch.cat(
|
761 |
+
[
|
762 |
+
torch.zeros(
|
763 |
+
(batch_size, prefix_seq_len, seq_length),
|
764 |
+
device=device,
|
765 |
+
dtype=causal_mask.dtype,
|
766 |
+
),
|
767 |
+
causal_mask,
|
768 |
+
],
|
769 |
+
axis=1,
|
770 |
+
)
|
771 |
+
causal_mask = torch.cat(
|
772 |
+
[
|
773 |
+
torch.ones(
|
774 |
+
(batch_size, causal_mask.shape[1], prefix_seq_len),
|
775 |
+
device=device,
|
776 |
+
dtype=causal_mask.dtype,
|
777 |
+
),
|
778 |
+
causal_mask,
|
779 |
+
],
|
780 |
+
axis=-1,
|
781 |
+
)
|
782 |
+
extended_attention_mask = (
|
783 |
+
causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
784 |
+
)
|
785 |
+
else:
|
786 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
787 |
+
else:
|
788 |
+
raise ValueError(
|
789 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
790 |
+
input_shape, attention_mask.shape
|
791 |
+
)
|
792 |
+
)
|
793 |
+
|
794 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
795 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
796 |
+
# positions we want to attend and -10000.0 for masked positions.
|
797 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
798 |
+
# effectively the same as removing these entirely.
|
799 |
+
extended_attention_mask = extended_attention_mask.to(
|
800 |
+
dtype=self.dtype
|
801 |
+
) # fp16 compatibility
|
802 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
803 |
+
return extended_attention_mask
|
804 |
+
|
805 |
+
def forward(
|
806 |
+
self,
|
807 |
+
input_ids=None,
|
808 |
+
attention_mask=None,
|
809 |
+
position_ids=None,
|
810 |
+
head_mask=None,
|
811 |
+
query_embeds=None,
|
812 |
+
encoder_hidden_states=None,
|
813 |
+
encoder_attention_mask=None,
|
814 |
+
past_key_values=None,
|
815 |
+
use_cache=None,
|
816 |
+
output_attentions=None,
|
817 |
+
output_hidden_states=None,
|
818 |
+
return_dict=None,
|
819 |
+
is_decoder=False,
|
820 |
+
):
|
821 |
+
r"""
|
822 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
823 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
824 |
+
the model is configured as a decoder.
|
825 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
826 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
827 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
828 |
+
- 1 for tokens that are **not masked**,
|
829 |
+
- 0 for tokens that are **masked**.
|
830 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
831 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
832 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
833 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
834 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
835 |
+
use_cache (:obj:`bool`, `optional`):
|
836 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
837 |
+
decoding (see :obj:`past_key_values`).
|
838 |
+
"""
|
839 |
+
output_attentions = (
|
840 |
+
output_attentions
|
841 |
+
if output_attentions is not None
|
842 |
+
else self.config.output_attentions
|
843 |
+
)
|
844 |
+
output_hidden_states = (
|
845 |
+
output_hidden_states
|
846 |
+
if output_hidden_states is not None
|
847 |
+
else self.config.output_hidden_states
|
848 |
+
)
|
849 |
+
return_dict = (
|
850 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
851 |
+
)
|
852 |
+
|
853 |
+
# use_cache = use_cache if use_cache is not None else self.config.use_cache
|
854 |
+
|
855 |
+
if input_ids is None:
|
856 |
+
assert (
|
857 |
+
query_embeds is not None
|
858 |
+
), "You have to specify query_embeds when input_ids is None"
|
859 |
+
|
860 |
+
# past_key_values_length
|
861 |
+
past_key_values_length = (
|
862 |
+
past_key_values[0][0].shape[2] - self.config.query_length
|
863 |
+
if past_key_values is not None
|
864 |
+
else 0
|
865 |
+
)
|
866 |
+
|
867 |
+
query_length = query_embeds.shape[1] if query_embeds is not None else 0
|
868 |
+
|
869 |
+
embedding_output = self.embeddings(
|
870 |
+
input_ids=input_ids,
|
871 |
+
position_ids=position_ids,
|
872 |
+
query_embeds=query_embeds,
|
873 |
+
past_key_values_length=past_key_values_length,
|
874 |
+
)
|
875 |
+
|
876 |
+
input_shape = embedding_output.size()[:-1]
|
877 |
+
batch_size, seq_length = input_shape
|
878 |
+
device = embedding_output.device
|
879 |
+
|
880 |
+
if attention_mask is None:
|
881 |
+
attention_mask = torch.ones(
|
882 |
+
((batch_size, seq_length + past_key_values_length)), device=device
|
883 |
+
)
|
884 |
+
|
885 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
886 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
887 |
+
if is_decoder:
|
888 |
+
extended_attention_mask = self.get_extended_attention_mask(
|
889 |
+
attention_mask,
|
890 |
+
input_ids.shape,
|
891 |
+
device,
|
892 |
+
is_decoder,
|
893 |
+
has_query=(query_embeds is not None),
|
894 |
+
)
|
895 |
+
else:
|
896 |
+
extended_attention_mask = self.get_extended_attention_mask(
|
897 |
+
attention_mask, input_shape, device, is_decoder
|
898 |
+
)
|
899 |
+
|
900 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
901 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
902 |
+
if encoder_hidden_states is not None:
|
903 |
+
if type(encoder_hidden_states) == list:
|
904 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
|
905 |
+
0
|
906 |
+
].size()
|
907 |
+
else:
|
908 |
+
(
|
909 |
+
encoder_batch_size,
|
910 |
+
encoder_sequence_length,
|
911 |
+
_,
|
912 |
+
) = encoder_hidden_states.size()
|
913 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
914 |
+
|
915 |
+
if type(encoder_attention_mask) == list:
|
916 |
+
encoder_extended_attention_mask = [
|
917 |
+
self.invert_attention_mask(mask) for mask in encoder_attention_mask
|
918 |
+
]
|
919 |
+
elif encoder_attention_mask is None:
|
920 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
921 |
+
encoder_extended_attention_mask = self.invert_attention_mask(
|
922 |
+
encoder_attention_mask
|
923 |
+
)
|
924 |
+
else:
|
925 |
+
encoder_extended_attention_mask = self.invert_attention_mask(
|
926 |
+
encoder_attention_mask
|
927 |
+
)
|
928 |
+
else:
|
929 |
+
encoder_extended_attention_mask = None
|
930 |
+
|
931 |
+
# Prepare head mask if needed
|
932 |
+
# 1.0 in head_mask indicate we keep the head
|
933 |
+
# attention_probs has shape bsz x n_heads x N x N
|
934 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
935 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
936 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
937 |
+
|
938 |
+
encoder_outputs = self.encoder(
|
939 |
+
embedding_output,
|
940 |
+
attention_mask=extended_attention_mask,
|
941 |
+
head_mask=head_mask,
|
942 |
+
encoder_hidden_states=encoder_hidden_states,
|
943 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
944 |
+
past_key_values=past_key_values,
|
945 |
+
use_cache=use_cache,
|
946 |
+
output_attentions=output_attentions,
|
947 |
+
output_hidden_states=output_hidden_states,
|
948 |
+
return_dict=return_dict,
|
949 |
+
query_length=query_length,
|
950 |
+
)
|
951 |
+
sequence_output = encoder_outputs[0]
|
952 |
+
pooled_output = (
|
953 |
+
self.pooler(sequence_output) if self.pooler is not None else None
|
954 |
+
)
|
955 |
+
|
956 |
+
if not return_dict:
|
957 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
958 |
+
|
959 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
960 |
+
last_hidden_state=sequence_output,
|
961 |
+
pooler_output=pooled_output,
|
962 |
+
past_key_values=encoder_outputs.past_key_values,
|
963 |
+
hidden_states=encoder_outputs.hidden_states,
|
964 |
+
attentions=encoder_outputs.attentions,
|
965 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
966 |
+
)
|
967 |
+
|
968 |
+
|
969 |
+
class BertLMHeadModel(BertPreTrainedModel):
|
970 |
+
|
971 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
972 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
973 |
+
|
974 |
+
def __init__(self, config):
|
975 |
+
super().__init__(config)
|
976 |
+
|
977 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
978 |
+
self.cls = BertOnlyMLMHead(config)
|
979 |
+
|
980 |
+
self.init_weights()
|
981 |
+
|
982 |
+
def get_output_embeddings(self):
|
983 |
+
return self.cls.predictions.decoder
|
984 |
+
|
985 |
+
def set_output_embeddings(self, new_embeddings):
|
986 |
+
self.cls.predictions.decoder = new_embeddings
|
987 |
+
|
988 |
+
def forward(
|
989 |
+
self,
|
990 |
+
input_ids=None,
|
991 |
+
attention_mask=None,
|
992 |
+
position_ids=None,
|
993 |
+
head_mask=None,
|
994 |
+
query_embeds=None,
|
995 |
+
encoder_hidden_states=None,
|
996 |
+
encoder_attention_mask=None,
|
997 |
+
labels=None,
|
998 |
+
past_key_values=None,
|
999 |
+
use_cache=True,
|
1000 |
+
output_attentions=None,
|
1001 |
+
output_hidden_states=None,
|
1002 |
+
return_dict=None,
|
1003 |
+
return_logits=False,
|
1004 |
+
is_decoder=True,
|
1005 |
+
reduction="mean",
|
1006 |
+
):
|
1007 |
+
r"""
|
1008 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
1009 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
1010 |
+
the model is configured as a decoder.
|
1011 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1012 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
1013 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
1014 |
+
- 1 for tokens that are **not masked**,
|
1015 |
+
- 0 for tokens that are **masked**.
|
1016 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1017 |
+
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
1018 |
+
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
1019 |
+
ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
|
1020 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
1021 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
1022 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
1023 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
1024 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
1025 |
+
use_cache (:obj:`bool`, `optional`):
|
1026 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
1027 |
+
decoding (see :obj:`past_key_values`).
|
1028 |
+
Returns:
|
1029 |
+
Example::
|
1030 |
+
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
1031 |
+
>>> import torch
|
1032 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
1033 |
+
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
1034 |
+
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
1035 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
1036 |
+
>>> outputs = model(**inputs)
|
1037 |
+
>>> prediction_logits = outputs.logits
|
1038 |
+
"""
|
1039 |
+
return_dict = (
|
1040 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1041 |
+
)
|
1042 |
+
if labels is not None:
|
1043 |
+
use_cache = False
|
1044 |
+
if past_key_values is not None:
|
1045 |
+
query_embeds = None
|
1046 |
+
|
1047 |
+
outputs = self.bert(
|
1048 |
+
input_ids,
|
1049 |
+
attention_mask=attention_mask,
|
1050 |
+
position_ids=position_ids,
|
1051 |
+
head_mask=head_mask,
|
1052 |
+
query_embeds=query_embeds,
|
1053 |
+
encoder_hidden_states=encoder_hidden_states,
|
1054 |
+
encoder_attention_mask=encoder_attention_mask,
|
1055 |
+
past_key_values=past_key_values,
|
1056 |
+
use_cache=use_cache,
|
1057 |
+
output_attentions=output_attentions,
|
1058 |
+
output_hidden_states=output_hidden_states,
|
1059 |
+
return_dict=return_dict,
|
1060 |
+
is_decoder=is_decoder,
|
1061 |
+
)
|
1062 |
+
|
1063 |
+
sequence_output = outputs[0]
|
1064 |
+
if query_embeds is not None:
|
1065 |
+
sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
|
1066 |
+
|
1067 |
+
prediction_scores = self.cls(sequence_output)
|
1068 |
+
|
1069 |
+
if return_logits:
|
1070 |
+
return prediction_scores[:, :-1, :].contiguous()
|
1071 |
+
|
1072 |
+
lm_loss = None
|
1073 |
+
if labels is not None:
|
1074 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
1075 |
+
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
1076 |
+
labels = labels[:, 1:].contiguous()
|
1077 |
+
loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
|
1078 |
+
lm_loss = loss_fct(
|
1079 |
+
shifted_prediction_scores.view(-1, self.config.vocab_size),
|
1080 |
+
labels.view(-1),
|
1081 |
+
)
|
1082 |
+
if reduction == "none":
|
1083 |
+
lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
|
1084 |
+
|
1085 |
+
if not return_dict:
|
1086 |
+
output = (prediction_scores,) + outputs[2:]
|
1087 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
1088 |
+
|
1089 |
+
return CausalLMOutputWithCrossAttentions(
|
1090 |
+
loss=lm_loss,
|
1091 |
+
logits=prediction_scores,
|
1092 |
+
past_key_values=outputs.past_key_values,
|
1093 |
+
hidden_states=outputs.hidden_states,
|
1094 |
+
attentions=outputs.attentions,
|
1095 |
+
cross_attentions=outputs.cross_attentions,
|
1096 |
+
)
|
1097 |
+
|
1098 |
+
def prepare_inputs_for_generation(
|
1099 |
+
self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
|
1100 |
+
):
|
1101 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
1102 |
+
if attention_mask is None:
|
1103 |
+
attention_mask = input_ids.new_ones(input_ids.shape)
|
1104 |
+
query_mask = input_ids.new_ones(query_embeds.shape[:-1])
|
1105 |
+
attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
|
1106 |
+
|
1107 |
+
# cut decoder_input_ids if past is used
|
1108 |
+
if past is not None:
|
1109 |
+
input_ids = input_ids[:, -1:]
|
1110 |
+
|
1111 |
+
return {
|
1112 |
+
"input_ids": input_ids,
|
1113 |
+
"query_embeds": query_embeds,
|
1114 |
+
"attention_mask": attention_mask,
|
1115 |
+
"past_key_values": past,
|
1116 |
+
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
1117 |
+
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
1118 |
+
"is_decoder": True,
|
1119 |
+
}
|
1120 |
+
|
1121 |
+
def _reorder_cache(self, past, beam_idx):
|
1122 |
+
reordered_past = ()
|
1123 |
+
for layer_past in past:
|
1124 |
+
reordered_past += (
|
1125 |
+
tuple(
|
1126 |
+
past_state.index_select(0, beam_idx) for past_state in layer_past
|
1127 |
+
),
|
1128 |
+
)
|
1129 |
+
return reordered_past
|
1130 |
+
|
1131 |
+
|
1132 |
+
class BertForMaskedLM(BertPreTrainedModel):
|
1133 |
+
|
1134 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
1135 |
+
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
1136 |
+
|
1137 |
+
def __init__(self, config):
|
1138 |
+
super().__init__(config)
|
1139 |
+
|
1140 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
1141 |
+
self.cls = BertOnlyMLMHead(config)
|
1142 |
+
|
1143 |
+
self.init_weights()
|
1144 |
+
|
1145 |
+
def get_output_embeddings(self):
|
1146 |
+
return self.cls.predictions.decoder
|
1147 |
+
|
1148 |
+
def set_output_embeddings(self, new_embeddings):
|
1149 |
+
self.cls.predictions.decoder = new_embeddings
|
1150 |
+
|
1151 |
+
def forward(
|
1152 |
+
self,
|
1153 |
+
input_ids=None,
|
1154 |
+
attention_mask=None,
|
1155 |
+
position_ids=None,
|
1156 |
+
head_mask=None,
|
1157 |
+
query_embeds=None,
|
1158 |
+
encoder_hidden_states=None,
|
1159 |
+
encoder_attention_mask=None,
|
1160 |
+
labels=None,
|
1161 |
+
output_attentions=None,
|
1162 |
+
output_hidden_states=None,
|
1163 |
+
return_dict=None,
|
1164 |
+
return_logits=False,
|
1165 |
+
is_decoder=False,
|
1166 |
+
):
|
1167 |
+
r"""
|
1168 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1169 |
+
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
|
1170 |
+
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
1171 |
+
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
1172 |
+
"""
|
1173 |
+
|
1174 |
+
return_dict = (
|
1175 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
1176 |
+
)
|
1177 |
+
|
1178 |
+
outputs = self.bert(
|
1179 |
+
input_ids,
|
1180 |
+
attention_mask=attention_mask,
|
1181 |
+
position_ids=position_ids,
|
1182 |
+
head_mask=head_mask,
|
1183 |
+
query_embeds=query_embeds,
|
1184 |
+
encoder_hidden_states=encoder_hidden_states,
|
1185 |
+
encoder_attention_mask=encoder_attention_mask,
|
1186 |
+
output_attentions=output_attentions,
|
1187 |
+
output_hidden_states=output_hidden_states,
|
1188 |
+
return_dict=return_dict,
|
1189 |
+
is_decoder=is_decoder,
|
1190 |
+
)
|
1191 |
+
|
1192 |
+
if query_embeds is not None:
|
1193 |
+
sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
|
1194 |
+
prediction_scores = self.cls(sequence_output)
|
1195 |
+
|
1196 |
+
if return_logits:
|
1197 |
+
return prediction_scores
|
1198 |
+
|
1199 |
+
masked_lm_loss = None
|
1200 |
+
if labels is not None:
|
1201 |
+
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
1202 |
+
masked_lm_loss = loss_fct(
|
1203 |
+
prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
|
1204 |
+
)
|
1205 |
+
|
1206 |
+
if not return_dict:
|
1207 |
+
output = (prediction_scores,) + outputs[2:]
|
1208 |
+
return (
|
1209 |
+
((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
1210 |
+
)
|
1211 |
+
|
1212 |
+
return MaskedLMOutput(
|
1213 |
+
loss=masked_lm_loss,
|
1214 |
+
logits=prediction_scores,
|
1215 |
+
hidden_states=outputs.hidden_states,
|
1216 |
+
attentions=outputs.attentions,
|
1217 |
+
)
|
models/salmonn.py
ADDED
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (2023) Tsinghua University, Bytedance Ltd. and/or its affiliates
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# Modifications to allow device parallelism.
|
16 |
+
# Copyright (2024) William Held
|
17 |
+
# Same Terms as as above apply
|
18 |
+
|
19 |
+
|
20 |
+
import librosa
|
21 |
+
import soundfile as sf
|
22 |
+
import torch
|
23 |
+
import torch.nn as nn
|
24 |
+
import torch.nn.functional as F
|
25 |
+
from peft import LoraConfig, TaskType, get_peft_model
|
26 |
+
from transformers import (
|
27 |
+
LlamaForCausalLM,
|
28 |
+
LlamaTokenizer,
|
29 |
+
WhisperFeatureExtractor,
|
30 |
+
WhisperModel,
|
31 |
+
)
|
32 |
+
|
33 |
+
from models.beats.BEATs import BEATs, BEATsConfig
|
34 |
+
from models.qformer.Qformer import BertConfig, BertLMHeadModel
|
35 |
+
|
36 |
+
|
37 |
+
class SALMONN(nn.Module):
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
ckpt,
|
41 |
+
whisper_path,
|
42 |
+
beats_path,
|
43 |
+
vicuna_path,
|
44 |
+
speech_qformer_token_num=1,
|
45 |
+
speech_qformer_layer=2,
|
46 |
+
lora=True,
|
47 |
+
device="cuda:0",
|
48 |
+
lora_alpha=32,
|
49 |
+
lora_rank=8,
|
50 |
+
lora_dropout=0.1,
|
51 |
+
second_per_frame=0.333333,
|
52 |
+
second_stride=0.333333,
|
53 |
+
low_resource=False,
|
54 |
+
):
|
55 |
+
|
56 |
+
super().__init__()
|
57 |
+
|
58 |
+
# feature_extractor
|
59 |
+
self.feature_extractor = WhisperFeatureExtractor.from_pretrained(whisper_path)
|
60 |
+
|
61 |
+
# whisper
|
62 |
+
self.speech_encoder = WhisperModel.from_pretrained(whisper_path).encoder.to(
|
63 |
+
device
|
64 |
+
)
|
65 |
+
self.ln_speech = nn.LayerNorm(self.speech_encoder.config.d_model).to(device)
|
66 |
+
|
67 |
+
# beats
|
68 |
+
self.beats_ckpt = beats_path
|
69 |
+
beats_checkpoint = torch.load(self.beats_ckpt, map_location=device)
|
70 |
+
beats_cfg = BEATsConfig(beats_checkpoint["cfg"])
|
71 |
+
beats = BEATs(beats_cfg)
|
72 |
+
beats.load_state_dict(beats_checkpoint["model"])
|
73 |
+
self.beats = beats
|
74 |
+
self.beats.to(device)
|
75 |
+
self.ln_audio = nn.LayerNorm(self.beats.cfg.encoder_embed_dim).to(device)
|
76 |
+
for name, param in self.beats.named_parameters():
|
77 |
+
param.requires_grad = False
|
78 |
+
self.beats.eval()
|
79 |
+
|
80 |
+
# init speech Qformer
|
81 |
+
self.speech_Qformer, self.speech_query_tokens = self.init_speech_Qformer(
|
82 |
+
speech_qformer_token_num,
|
83 |
+
self.speech_encoder.config.d_model + self.beats.cfg.encoder_embed_dim,
|
84 |
+
speech_qformer_layer,
|
85 |
+
)
|
86 |
+
self.speech_Qformer.to(device)
|
87 |
+
self.speech_query_tokens.to(device)
|
88 |
+
self.second_per_frame = second_per_frame
|
89 |
+
self.second_stride = second_stride
|
90 |
+
|
91 |
+
# vicuna
|
92 |
+
if not low_resource:
|
93 |
+
self.llama_model = LlamaForCausalLM.from_pretrained(
|
94 |
+
vicuna_path,
|
95 |
+
torch_dtype=torch.float16,
|
96 |
+
device_map="auto",
|
97 |
+
)
|
98 |
+
else:
|
99 |
+
self.llama_model = LlamaForCausalLM.from_pretrained(
|
100 |
+
vicuna_path,
|
101 |
+
torch_dtype=torch.float16,
|
102 |
+
load_in_8bit=True,
|
103 |
+
device_map="auto",
|
104 |
+
)
|
105 |
+
|
106 |
+
# lora
|
107 |
+
self.lora = lora
|
108 |
+
if lora:
|
109 |
+
target_modules = None
|
110 |
+
self.peft_config = LoraConfig(
|
111 |
+
task_type=TaskType.CAUSAL_LM,
|
112 |
+
inference_mode=True,
|
113 |
+
r=lora_rank,
|
114 |
+
lora_alpha=lora_alpha,
|
115 |
+
lora_dropout=lora_dropout,
|
116 |
+
target_modules=target_modules,
|
117 |
+
)
|
118 |
+
self.llama_model = get_peft_model(self.llama_model, self.peft_config)
|
119 |
+
|
120 |
+
# tokenizer
|
121 |
+
self.llama_tokenizer = LlamaTokenizer.from_pretrained(
|
122 |
+
vicuna_path, use_fast=False
|
123 |
+
)
|
124 |
+
self.llama_tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
125 |
+
self.llama_tokenizer.padding_side = "right"
|
126 |
+
|
127 |
+
# proj
|
128 |
+
self.speech_llama_proj = nn.Linear(
|
129 |
+
self.speech_Qformer.config.hidden_size, self.llama_model.config.hidden_size
|
130 |
+
).to(device)
|
131 |
+
|
132 |
+
# load ckpt
|
133 |
+
ckpt_dict = torch.load(ckpt)["model"]
|
134 |
+
self.load_state_dict(ckpt_dict, strict=False)
|
135 |
+
|
136 |
+
def generate(
|
137 |
+
self,
|
138 |
+
wav_path,
|
139 |
+
prompt,
|
140 |
+
prompt_pattern="USER: <Speech><SpeechHere></Speech> {}\nASSISTANT:",
|
141 |
+
device="cuda:0",
|
142 |
+
max_length=200,
|
143 |
+
max_new_tokens=128,
|
144 |
+
num_beams=1,
|
145 |
+
do_sample=True,
|
146 |
+
min_length=1,
|
147 |
+
top_p=0.9,
|
148 |
+
repetition_penalty=1.0,
|
149 |
+
length_penalty=1.0,
|
150 |
+
temperature=1.0,
|
151 |
+
logits_processor=None,
|
152 |
+
streamer=None
|
153 |
+
):
|
154 |
+
# read wav
|
155 |
+
wav, sr = sf.read(wav_path)
|
156 |
+
if len(wav.shape) == 2:
|
157 |
+
wav = wav[:, 0]
|
158 |
+
if len(wav) > 30 * sr:
|
159 |
+
wav = wav[: 30 * sr]
|
160 |
+
if sr != 16000:
|
161 |
+
wav = librosa.resample(wav, orig_sr=sr, target_sr=16000, res_type="fft")
|
162 |
+
|
163 |
+
# whisper
|
164 |
+
spectrogram = self.feature_extractor(
|
165 |
+
wav, return_tensors="pt", sampling_rate=16000
|
166 |
+
).input_features.to(
|
167 |
+
device
|
168 |
+
) # [1, 80, 3000]
|
169 |
+
speech_embeds = self.speech_encoder(
|
170 |
+
spectrogram, return_dict=True
|
171 |
+
).last_hidden_state
|
172 |
+
|
173 |
+
# beats
|
174 |
+
raw_wav = torch.from_numpy(wav).to(device).unsqueeze(0)
|
175 |
+
audio_padding_mask = torch.zeros(raw_wav.shape, device=device).bool()
|
176 |
+
audio_embeds, _ = self.beats.extract_features(
|
177 |
+
raw_wav, padding_mask=audio_padding_mask, feature_only=True
|
178 |
+
)
|
179 |
+
|
180 |
+
# auditory embeds
|
181 |
+
speech_embeds = self.ln_speech(speech_embeds)
|
182 |
+
audio_embeds = self.ln_audio(audio_embeds)
|
183 |
+
audio_embeds = F.pad(
|
184 |
+
audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1))
|
185 |
+
)
|
186 |
+
speech_embeds = torch.cat([speech_embeds, audio_embeds], dim=-1)
|
187 |
+
|
188 |
+
# split frames
|
189 |
+
B, T, C = speech_embeds.shape
|
190 |
+
kernel = round(T * self.second_per_frame / 30.0)
|
191 |
+
stride = round(T * self.second_stride / 30.0)
|
192 |
+
kernel = (1, kernel)
|
193 |
+
stride = (1, stride)
|
194 |
+
speech_embeds_tr = speech_embeds.transpose(1, 2).unsqueeze(2)
|
195 |
+
speech_embeds_overlap = F.unfold(
|
196 |
+
speech_embeds_tr, kernel_size=kernel, dilation=1, padding=0, stride=stride
|
197 |
+
)
|
198 |
+
_, _, L = speech_embeds_overlap.shape
|
199 |
+
speech_embeds_overlap = speech_embeds_overlap.view(B, -1, kernel[1], L)
|
200 |
+
speech_embeds_overlap = torch.permute(speech_embeds_overlap, [0, 3, 2, 1])
|
201 |
+
speech_embeds = speech_embeds_overlap.reshape(-1, kernel[1], C)
|
202 |
+
speech_atts = torch.ones(
|
203 |
+
speech_embeds.size()[:-1], dtype=torch.long, device=speech_embeds.device
|
204 |
+
)
|
205 |
+
|
206 |
+
# Qformer
|
207 |
+
query_tokens = self.speech_query_tokens.expand(speech_embeds.shape[0], -1, -1)
|
208 |
+
query_output = self.speech_Qformer.bert(
|
209 |
+
query_embeds=query_tokens.to(device),
|
210 |
+
encoder_hidden_states=speech_embeds,
|
211 |
+
encoder_attention_mask=speech_atts,
|
212 |
+
return_dict=True,
|
213 |
+
)
|
214 |
+
speech_embeds = self.speech_llama_proj(query_output.last_hidden_state)
|
215 |
+
speech_embeds = speech_embeds.view(B, -1, speech_embeds.size(2)).contiguous()
|
216 |
+
speech_atts = torch.ones(speech_embeds.size()[:-1], dtype=torch.long).to(
|
217 |
+
speech_embeds.device
|
218 |
+
)
|
219 |
+
|
220 |
+
# USER: <Speech>speech_embeds<Speech> prompt\nASSISTANT:
|
221 |
+
embed_tokens = (
|
222 |
+
self.llama_model.model.model.embed_tokens
|
223 |
+
if self.lora
|
224 |
+
else self.llama_model.model.embed_tokens
|
225 |
+
)
|
226 |
+
prompt_left, prompts_right = prompt_pattern.format(prompt).split("<SpeechHere>")
|
227 |
+
prompt_left_ids = (
|
228 |
+
self.llama_tokenizer(
|
229 |
+
prompt_left, return_tensors="pt", add_special_tokens=False
|
230 |
+
)
|
231 |
+
.to(speech_embeds.device)
|
232 |
+
.input_ids
|
233 |
+
)
|
234 |
+
prompt_left_embeds = embed_tokens(prompt_left_ids)
|
235 |
+
prompt_right_ids = (
|
236 |
+
self.llama_tokenizer(
|
237 |
+
prompts_right, return_tensors="pt", add_special_tokens=False
|
238 |
+
)
|
239 |
+
.to(speech_embeds.device)
|
240 |
+
.input_ids
|
241 |
+
)
|
242 |
+
prompt_right_embeds = embed_tokens(prompt_right_ids)
|
243 |
+
|
244 |
+
bos_embeds = (
|
245 |
+
self.llama_model.model.embed_tokens(
|
246 |
+
torch.ones(
|
247 |
+
[1, 1],
|
248 |
+
dtype=torch.long,
|
249 |
+
device=device,
|
250 |
+
)
|
251 |
+
* self.llama_tokenizer.bos_token_id
|
252 |
+
)
|
253 |
+
if not self.lora
|
254 |
+
else self.llama_model.model.model.embed_tokens(
|
255 |
+
torch.ones(
|
256 |
+
[1, 1],
|
257 |
+
dtype=torch.long,
|
258 |
+
device=device,
|
259 |
+
)
|
260 |
+
* self.llama_tokenizer.bos_token_id
|
261 |
+
)
|
262 |
+
)
|
263 |
+
|
264 |
+
embed_list = [bos_embeds, prompt_left_embeds, speech_embeds, prompt_right_embeds]
|
265 |
+
embeds = torch.cat(
|
266 |
+
[embed.to(bos_embeds.device) for embed in embed_list], dim=1
|
267 |
+
)
|
268 |
+
atts = torch.ones(embeds.size()[:-1], dtype=torch.long).to(embeds.device)
|
269 |
+
# generate
|
270 |
+
output = self.llama_model.generate(
|
271 |
+
inputs_embeds=embeds,
|
272 |
+
max_length=max_length,
|
273 |
+
max_new_tokens=max_new_tokens,
|
274 |
+
num_beams=num_beams,
|
275 |
+
do_sample=do_sample,
|
276 |
+
min_length=min_length,
|
277 |
+
top_p=top_p,
|
278 |
+
repetition_penalty=repetition_penalty,
|
279 |
+
length_penalty=length_penalty,
|
280 |
+
temperature=temperature,
|
281 |
+
attention_mask=atts,
|
282 |
+
bos_token_id=self.llama_tokenizer.bos_token_id,
|
283 |
+
eos_token_id=self.llama_tokenizer.eos_token_id,
|
284 |
+
pad_token_id=self.llama_tokenizer.pad_token_id,
|
285 |
+
logits_processor=[logits_processor] if logits_processor != None else None,
|
286 |
+
streamer=streamer,
|
287 |
+
)
|
288 |
+
|
289 |
+
output_text = self.llama_tokenizer.batch_decode(
|
290 |
+
output, add_special_tokens=False, skip_special_tokens=True
|
291 |
+
)
|
292 |
+
|
293 |
+
return output_text
|
294 |
+
|
295 |
+
def init_speech_Qformer(self, num_query_token, speech_width, num_hidden_layers=2):
|
296 |
+
encoder_config = BertConfig()
|
297 |
+
encoder_config.num_hidden_layers = num_hidden_layers
|
298 |
+
encoder_config.encoder_width = speech_width
|
299 |
+
encoder_config.add_cross_attention = True
|
300 |
+
encoder_config.cross_attention_freq = 1
|
301 |
+
encoder_config.query_length = num_query_token
|
302 |
+
Qformer = BertLMHeadModel(config=encoder_config)
|
303 |
+
query_tokens = nn.Parameter(
|
304 |
+
torch.zeros(1, num_query_token, encoder_config.hidden_size)
|
305 |
+
)
|
306 |
+
query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
|
307 |
+
return Qformer, query_tokens
|