AlexHung29629 committed
Commit 699ea16
Parent: 584669c

Update mllama_audio_model.py

Files changed (1):
  mllama_audio_model.py (+3, -5)
mllama_audio_model.py CHANGED
@@ -29,18 +29,16 @@ class MllamaAudioModel(MllamaPreTrainedModel):
         input_ids: torch.LongTensor = None,
         return_dict: Optional[bool] = None,
     ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
+        bs, max_num_img, l, d = audio_features.shape
         input_embeddings = self.text_embedding(torch.clamp(input_ids, min=0))
-        audio_embeddings = self.audio_embedding(input_features=audio_features[0])['last_hidden_state']
-        print(f"{audio_embeddings.shape=}")
-        bs, max_num_img, _, _ = audio_features.shape
+        audio_embeddings = self.audio_embedding(input_features=audio_features.view((bs*max_num_img, l, d)))['last_hidden_state']
+        audio_embeddings = audio_embeddings.view((bs, max_num_img, -1, start_of_audio.shape[-1]))
 
         for i in range(bs):
             for j in range(max_num_img):
                 audio_id = -1 - j
                 if torch.any(input_ids[i] == audio_id):
                     idx = input_ids[i] == audio_id
-                    print(f"{audio_features[i, j].shape=}")
-                    print(f"{self.start_of_audio.shape=}")
                     input_embeddings[i][idx] = torch.concat([self.start_of_audio, audio_embeddings[i, j], self.end_of_audio])
 
         if return_dict:
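In effect, the commit drops the debugging print calls and, instead of encoding only audio_features[0] (the first batch element), flattens all bs * max_num_img clips into a single batched call to the audio encoder, then reshapes the result back to (bs, max_num_img, seq, hidden). Below is a minimal sketch of the resulting forward body, assuming self.start_of_audio and self.end_of_audio are (1, hidden) marker embeddings and self.audio_embedding returns a dict-like output with 'last_hidden_state'; the bare start_of_audio in the committed line presumably refers to self.start_of_audio, since no local of that name is in scope.

import torch
from typing import Optional

# Hypothetical reconstruction of MllamaAudioModel.forward after this commit.
def forward(
    self,
    audio_features: torch.Tensor,            # (bs, max_num_img, l, d) mel features
    input_ids: torch.LongTensor = None,
    return_dict: Optional[bool] = None,
):
    bs, max_num_img, l, d = audio_features.shape
    # Negative ids mark audio slots; clamp so the text embedding lookup is valid.
    input_embeddings = self.text_embedding(torch.clamp(input_ids, min=0))
    # One batched encoder call over every clip, not just audio_features[0].
    audio_embeddings = self.audio_embedding(
        input_features=audio_features.view(bs * max_num_img, l, d)
    )["last_hidden_state"]
    # Restore the (batch, clip) axes; hidden size taken from the start marker
    # (reading the diff's bare start_of_audio as self.start_of_audio).
    audio_embeddings = audio_embeddings.view(
        bs, max_num_img, -1, self.start_of_audio.shape[-1]
    )

    for i in range(bs):
        for j in range(max_num_img):
            audio_id = -1 - j                # clip j is tagged with token id -1 - j
            if torch.any(input_ids[i] == audio_id):
                idx = input_ids[i] == audio_id
                # Splice start marker + clip embeddings + end marker into the
                # placeholder positions; the in-place assignment requires
                # idx.sum() to equal the clip's encoded length plus two.
                input_embeddings[i][idx] = torch.concat(
                    [self.start_of_audio, audio_embeddings[i, j], self.end_of_audio]
                )
    # ... remainder of forward (return_dict handling) unchanged

Note the implicit shape contract: each clip's placeholder run in input_ids must contain exactly audio_embeddings.shape[2] + 2 positions, or the masked assignment will fail at runtime.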