AlexHT_Hung committed
Commit: 2582d7c
Parent(s): 1192f59
update

Files changed:
- audio_processing_mllama.py +1 -1
- config.json +4 -4
- mllama_audio_model.py +47 -6
audio_processing_mllama.py CHANGED
@@ -20,7 +20,7 @@ def build_audio_tokens(text: List[str], audio_features: Union[Dict, List[List[np
     return text


 def get_num_embeddings(num_framses, adapter_kernel_size=7, adapter_stride=4) -> int:
-    return math.ceil((num_framses - adapter_kernel_size) / adapter_stride) + 1 + 2 # 2 = <|begin_of_audio|>, <|end_of_audio|>
+    return math.ceil((num_framses - adapter_kernel_size + adapter_stride) / adapter_stride) + 1 + 2 # 2 = <|begin_of_audio|>, <|end_of_audio|>


 class MllamaAudioFeatureExtractor(SeamlessM4TFeatureExtractor):
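The one-line change above shifts the numerator by one stride, so the new formula always returns exactly one more embedding slot per audio clip than the old one. A quick arithmetic check, using a hypothetical frame count (not taken from the repo):

# Quick arithmetic check of the revised count; the frame count is hypothetical.
import math

num_frames = 100                     # illustrative value only
kernel_size, stride = 7, 4           # defaults from get_num_embeddings
old_count = math.ceil((num_frames - kernel_size) / stride) + 1 + 2           # 27
new_count = math.ceil((num_frames - kernel_size + stride) / stride) + 1 + 2  # 28
print(old_count, new_count)          # 27 28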
config.json CHANGED
@@ -142,10 +142,10 @@
   },
   "audio_token_index": 128257,
   "auto_map": {
-    "AutoConfig": "
-    "AutoModel": "
-    "AutoProcessor": "
-    "AutoFeatureExtractor": "
+    "AutoConfig": "configuration_llama3.Llama3Config",
+    "AutoModel": "modeling_llama3.Llama3ForConditionalGeneration",
+    "AutoProcessor": "processing_mllama.MllamaProcessor",
+    "AutoFeatureExtractor": "audio_processing_mllama.MllamaAudioFeatureExtractor"
   },
   "image_token_index": 128256,
   "model_type": "llama3",
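With auto_map now pointing at the repo's own modules, the stock Auto classes can resolve these custom implementations once remote code is trusted. A minimal loading sketch; the repo id below is a placeholder, not this model's actual id:

# Hypothetical loading sketch; "user/llama3-audio" is a placeholder repo id.
from transformers import AutoConfig, AutoModel, AutoProcessor

repo_id = "user/llama3-audio"  # placeholder, not the actual model id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)        # -> configuration_llama3.Llama3Config
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)          # -> modeling_llama3.Llama3ForConditionalGeneration
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)  # -> processing_mllama.MllamaProcessor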
mllama_audio_model.py CHANGED
@@ -1,22 +1,62 @@
 from typing import Optional, Tuple, Union
-import math
 import torch
 from torch import nn
 from transformers.modeling_outputs import BaseModelOutput
-from transformers import Wav2Vec2BertModel, Wav2Vec2BertConfig,
-from transformers.models.
+from transformers import Wav2Vec2BertModel, Wav2Vec2BertConfig, MllamaPreTrainedModel
+from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import Wav2Vec2BertAdapterLayer
 from .configuration_llama3 import Llama3Config

+class AudioAdapter(nn.Module):
+    def __init__(self, config: Wav2Vec2BertConfig):
+        super().__init__()
+        # feature dim might need to be down-projected
+        if config.output_hidden_size != config.hidden_size:
+            self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
+        else:
+            self.proj = None
+        self.layers = nn.ModuleList(Wav2Vec2BertAdapterLayer(config) for _ in range(config.num_adapter_layers))
+        self.layerdrop = config.layerdrop

+        self.kernel_size = config.adapter_kernel_size
+        self.stride = config.adapter_stride
+
+    def _compute_sub_sample_lengths_from_attention_mask(self, seq_lens):
+        if seq_lens is None:
+            return seq_lens
+        pad = self.kernel_size // 2
+        seq_lens = ((seq_lens + 2 * pad - self.kernel_size) / self.stride) + 1
+        return seq_lens.floor()
+
+    def forward(self, hidden_states, attention_mask=None):
+        # down project hidden_states if necessary
+        if self.proj is not None:
+            hidden_states = self.proj(hidden_states)
+
+        sub_sampled_lengths = None
+        if attention_mask is not None:
+            sub_sampled_lengths = (attention_mask.size(1) - (1 - attention_mask.int()).sum(1)).to(hidden_states.device)
+
+        for layer in self.layers:
+            layerdrop_prob = torch.rand([])
+            sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(sub_sampled_lengths)
+            if not self.training or (layerdrop_prob > self.layerdrop):
+                hidden_states = layer(
+                    hidden_states, attention_mask=attention_mask, sub_sampled_lengths=sub_sampled_lengths
+                )
+
+        return hidden_states
+
+
 class Llama3Embedding(MllamaPreTrainedModel):
     config_class = Llama3Config
     base_model_prefix = "audio_model"
     def __init__(self, config: Llama3Config):
         super().__init__(config)
-        assert config.audio_config.add_adapter is True, f'{type(self).__name__} requires add adapter to be true.'
         assert config.audio_config.output_hidden_size == config.text_config.hidden_size
         self.text_embeddings = nn.Embedding(config.text_config.vocab_size, config.text_config.hidden_size, config.text_config.pad_token_id)
-
+        config.audio_config.add_adapter = False
+        self.audio_encoder = Wav2Vec2BertModel(config.audio_config)
+        self.audio_adapter = AudioAdapter(config.audio_config)
         self.start_of_audio = nn.Parameter(data=torch.zeros((1, config.audio_config.output_hidden_size)), requires_grad=True)
         self.end_of_audio = nn.Parameter(data=torch.zeros((1, config.audio_config.output_hidden_size)), requires_grad=True)
         self.text_config = config.text_config
@@ -30,7 +70,8 @@ class Llama3Embedding(MllamaPreTrainedModel):
         if audio_features is None:
             return input_embeddings
         bs, max_num_img, l, d = audio_features.shape
-        audio_embeddings = self.
+        audio_embeddings = self.audio_encoder(input_features=audio_features.view((bs*max_num_img, l, d)))['last_hidden_state']
+        audio_embeddings = self.audio_adapter(audio_embeddings)
         audio_embeddings = audio_embeddings.view((bs, max_num_img, -1, self.start_of_audio.shape[-1]))

         for i in range(bs):
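The refactor disables the encoder's built-in adapter (add_adapter is forced to False) and routes its output through the new standalone AudioAdapter, which closely mirrors the stock Wav2Vec2BertAdapter. A minimal smoke-test sketch of the adapter on random features; the package path, config values, and shapes are illustrative assumptions, not taken from this repo's actual config:

# Hypothetical smoke test for AudioAdapter. Assumes the repo files live in a local
# package "llama3_audio" so the module's relative imports resolve; all values are
# illustrative.
import torch
from transformers import Wav2Vec2BertConfig
from llama3_audio.mllama_audio_model import AudioAdapter  # hypothetical package path

cfg = Wav2Vec2BertConfig(
    hidden_size=1024,           # encoder width (illustrative)
    output_hidden_size=4096,    # target width, i.e. the text model's hidden size
    num_adapter_layers=1,
    adapter_kernel_size=7,
    adapter_stride=4,
)
adapter = AudioAdapter(cfg)
adapter.eval()                  # bypass layerdrop so the adapter layer always runs

hidden_states = torch.randn(2, 100, cfg.hidden_size)  # (batch, frames, hidden_size)
with torch.no_grad():
    out = adapter(hidden_states)
print(out.shape)                # roughly (2, 25, 4096): frames down-sampled ~4x by the stride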