AlexHT_Hung committed
Commit 2582d7c
1 Parent(s): 1192f59
audio_processing_mllama.py CHANGED
@@ -20,7 +20,7 @@ def build_audio_tokens(text: List[str], audio_features: Union[Dict, List[List[np
     return text
 
 def get_num_embeddings(num_framses, adapter_kernel_size=7, adapter_stride=4) -> int:
-    return math.ceil((num_framses - adapter_kernel_size) / adapter_stride) + 1 + 2 # 2 = <|begin_of_audio|>, <|end_of_audio|>
+    return math.ceil((num_framses - adapter_kernel_size + adapter_stride) / adapter_stride) + 1 + 2 # 2 = <|begin_of_audio|>, <|end_of_audio|>
 
 
 class MllamaAudioFeatureExtractor(SeamlessM4TFeatureExtractor):
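Since ceil((n - k + s) / s) = ceil((n - k) / s) + 1, the updated expression always reserves exactly one more embedding slot per audio clip than the old one. A quick standalone check of that off-by-one; both helpers are re-typed here from the diff for illustration (with the `num_framses` typo spelled out), not imported from the repo:

```python
import math

def old_num_embeddings(num_frames, kernel=7, stride=4):
    # pre-commit formula: +2 for <|begin_of_audio|> and <|end_of_audio|>
    return math.ceil((num_frames - kernel) / stride) + 1 + 2

def new_num_embeddings(num_frames, kernel=7, stride=4):
    # post-commit formula
    return math.ceil((num_frames - kernel + stride) / stride) + 1 + 2

for n in (7, 100, 257):
    # ceil((n - k + s) / s) == ceil((n - k) / s) + 1, so the new count is always old + 1
    assert new_num_embeddings(n) == old_num_embeddings(n) + 1
    print(n, old_num_embeddings(n), new_num_embeddings(n))
```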
config.json CHANGED
@@ -142,10 +142,10 @@
   },
   "audio_token_index": 128257,
   "auto_map": {
-    "AutoConfig": "AlexHung29629/test_mllama_11B_v3--configuration_llama3.Llama3Config",
-    "AutoModel": "AlexHung29629/test_mllama_11B_v3--modeling_llama3.Llama3ForConditionalGeneration",
-    "AutoProcessor": "AlexHung29629/test_mllama_11B_v3--processing_mllama.MllamaProcessor",
-    "AutoFeatureExtractor": "AlexHung29629/test_mllama_11B_v3--audio_processing_mllama.MllamaAudioFeatureExtractor"
+    "AutoConfig": "configuration_llama3.Llama3Config",
+    "AutoModel": "modeling_llama3.Llama3ForConditionalGeneration",
+    "AutoProcessor": "processing_mllama.MllamaProcessor",
+    "AutoFeatureExtractor": "audio_processing_mllama.MllamaAudioFeatureExtractor"
   },
   "image_token_index": 128256,
   "model_type": "llama3",
mllama_audio_model.py CHANGED
@@ -1,22 +1,62 @@
 from typing import Optional, Tuple, Union
-import math
 import torch
 from torch import nn
 from transformers.modeling_outputs import BaseModelOutput
-from transformers import Wav2Vec2BertModel, Wav2Vec2BertConfig, Wav2Vec2BertPreTrainedModel, MllamaPreTrainedModel
-from transformers.models.mllama.configuration_mllama import MllamaTextConfig
+from transformers import Wav2Vec2BertModel, Wav2Vec2BertConfig, MllamaPreTrainedModel
+from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import Wav2Vec2BertAdapterLayer
 from .configuration_llama3 import Llama3Config
 
+class AudioAdapter(nn.Module):
+    def __init__(self, config: Wav2Vec2BertConfig):
+        super().__init__()
+        # feature dim might need to be down-projected
+        if config.output_hidden_size != config.hidden_size:
+            self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
+        else:
+            self.proj = None
+        self.layers = nn.ModuleList(Wav2Vec2BertAdapterLayer(config) for _ in range(config.num_adapter_layers))
+        self.layerdrop = config.layerdrop
 
+        self.kernel_size = config.adapter_kernel_size
+        self.stride = config.adapter_stride
+
+    def _compute_sub_sample_lengths_from_attention_mask(self, seq_lens):
+        if seq_lens is None:
+            return seq_lens
+        pad = self.kernel_size // 2
+        seq_lens = ((seq_lens + 2 * pad - self.kernel_size) / self.stride) + 1
+        return seq_lens.floor()
+
+    def forward(self, hidden_states, attention_mask=None):
+        # down project hidden_states if necessary
+        if self.proj is not None:
+            hidden_states = self.proj(hidden_states)
+
+        sub_sampled_lengths = None
+        if attention_mask is not None:
+            sub_sampled_lengths = (attention_mask.size(1) - (1 - attention_mask.int()).sum(1)).to(hidden_states.device)
+
+        for layer in self.layers:
+            layerdrop_prob = torch.rand([])
+            sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(sub_sampled_lengths)
+            if not self.training or (layerdrop_prob > self.layerdrop):
+                hidden_states = layer(
+                    hidden_states, attention_mask=attention_mask, sub_sampled_lengths=sub_sampled_lengths
+                )
+
+        return hidden_states
+
+
 class Llama3Embedding(MllamaPreTrainedModel):
     config_class = Llama3Config
     base_model_prefix = "audio_model"
     def __init__(self, config: Llama3Config):
         super().__init__(config)
-        assert config.audio_config.add_adapter is True, f'{type(self).__name__} requires add adapter to be true.'
         assert config.audio_config.output_hidden_size == config.text_config.hidden_size
         self.text_embeddings = nn.Embedding(config.text_config.vocab_size, config.text_config.hidden_size, config.text_config.pad_token_id)
-        self.audio_embedding = Wav2Vec2BertModel(config.audio_config)
+        config.audio_config.add_adapter = False
+        self.audio_encoder = Wav2Vec2BertModel(config.audio_config)
+        self.audio_adapter = AudioAdapter(config.audio_config)
         self.start_of_audio = nn.Parameter(data=torch.zeros((1, config.audio_config.output_hidden_size)), requires_grad=True)
         self.end_of_audio = nn.Parameter(data=torch.zeros((1, config.audio_config.output_hidden_size)), requires_grad=True)
         self.text_config = config.text_config
@@ -30,7 +70,8 @@ class Llama3Embedding(MllamaPreTrainedModel):
         if audio_features is None:
             return input_embeddings
         bs, max_num_img, l, d = audio_features.shape
-        audio_embeddings = self.audio_embedding(input_features=audio_features.view((bs*max_num_img, l, d)))['last_hidden_state']
+        audio_embeddings = self.audio_encoder(input_features=audio_features.view((bs*max_num_img, l, d)))['last_hidden_state']
+        audio_embeddings = self.audio_adapter(audio_embeddings)
         audio_embeddings = audio_embeddings.view((bs, max_num_img, -1, self.start_of_audio.shape[-1]))
 
         for i in range(bs):
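The net effect is a two-stage audio path: `Wav2Vec2BertModel` now runs with `add_adapter=False`, and the new `AudioAdapter` performs the kernel-7 / stride-4 subsampling outside the encoder. A rough smoke test of that path, assuming `mllama_audio_model.py` from this repo is importable locally; the config values are deliberately tiny and illustrative, not the model's real settings:

```python
import torch
from transformers import Wav2Vec2BertConfig, Wav2Vec2BertModel

from mllama_audio_model import AudioAdapter  # assumes the repo file is on the Python path

cfg = Wav2Vec2BertConfig(
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=128,
    output_hidden_size=64,
    num_adapter_layers=1,
    adapter_kernel_size=7,
    adapter_stride=4,
    add_adapter=False,  # the adapter now lives outside the encoder
)

encoder = Wav2Vec2BertModel(cfg).eval()
adapter = AudioAdapter(cfg).eval()

# (batch, frames, mel-feature dim expected by the feature projection)
features = torch.randn(1, 200, cfg.feature_projection_input_dim)
with torch.no_grad():
    hidden = encoder(input_features=features).last_hidden_state  # (1, 200, 64): no subsampling in the encoder
    pooled = adapter(hidden)  # stride-4 subsampling, roughly 200 / 4 = 50 frames
print(hidden.shape, pooled.shape)
```

After this, `Llama3Embedding` reshapes the pooled frames per clip and brackets them with the learned `start_of_audio` / `end_of_audio` vectors, which is what the `get_num_embeddings` count in `audio_processing_mllama.py` has to match.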