AlexHung29629 committed
Commit 699ea16
Parent(s): 584669c
Update mllama_audio_model.py
mllama_audio_model.py +3 -5
mllama_audio_model.py
CHANGED
@@ -29,18 +29,16 @@ class MllamaAudioModel(MllamaPreTrainedModel):
         input_ids: torch.LongTensor = None,
         return_dict: Optional[bool] = None,
     ) -> Union[BaseModelOutput, Tuple[torch.Tensor, ...]]:
+        bs, max_num_img, l, d = audio_features.shape
         input_embeddings = self.text_embedding(torch.clamp(input_ids, min=0))
-        audio_embeddings = self.audio_embedding(input_features=audio_features
-
-        bs, max_num_img, _, _ = audio_features.shape
+        audio_embeddings = self.audio_embedding(input_features=audio_features.view((bs*max_num_img, l, d)))['last_hidden_state']
+        audio_embeddings = audio_embeddings.view((bs, max_num_img, -1, start_of_audio.shape[-1]))
 
         for i in range(bs):
             for j in range(max_num_img):
                 audio_id = -1 - j
                 if torch.any(input_ids[i] == audio_id):
                     idx = input_ids[i] == audio_id
-                    print(f"{audio_features[i, j].shape=}")
-                    print(f"{self.start_of_audio.shape=}")
                     input_embeddings[i][idx] = torch.concat([self.start_of_audio, audio_embeddings[i, j], self.end_of_audio])
 
         if return_dict:
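For context, a minimal, self-contained sketch of the flow this commit moves to: flatten the (bs, max_num_img, l, d) audio features into a single encoder batch, reshape the hidden states back per clip, then splice each clip (wrapped in start/end-of-audio vectors) into the text embeddings at the positions of its negative placeholder id. All names, shapes, and the stub encoder below are illustrative assumptions, not the repo's actual MllamaAudioModel; note also that the committed line indexes a bare start_of_audio, whereas the sketch uses an explicit module-level tensor.

import torch
import torch.nn as nn

hidden = 16                          # assumed shared text/audio hidden size
vocab, seq = 32, 16                  # assumed vocab size and sequence length
bs, max_num_img, l, d = 2, 2, 10, 8  # assumed audio feature shape

text_embedding = nn.Embedding(vocab, hidden)
audio_encoder = nn.Linear(d, hidden)     # stand-in for self.audio_embedding; the real
                                         # encoder returns an output indexed by ['last_hidden_state']
start_of_audio = torch.zeros(1, hidden)  # stand-in for self.start_of_audio
end_of_audio = torch.zeros(1, hidden)    # stand-in for self.end_of_audio

input_ids = torch.randint(0, vocab, (bs, seq))
input_ids[:, 2:2 + l + 2] = -1           # clip 0 placeholder span: start + l frames + end
audio_features = torch.randn(bs, max_num_img, l, d)

# Negative placeholder ids are clamped to 0 before the text embedding lookup.
input_embeddings = text_embedding(torch.clamp(input_ids, min=0))

# One encoder pass over all clips, then reshape back to (bs, max_num_img, l, hidden).
audio_embeddings = audio_encoder(audio_features.view(bs * max_num_img, l, d))
audio_embeddings = audio_embeddings.view(bs, max_num_img, -1, start_of_audio.shape[-1])

for i in range(bs):
    for j in range(max_num_img):
        audio_id = -1 - j                 # clip j is marked by placeholder id -(j + 1)
        if torch.any(input_ids[i] == audio_id):
            idx = input_ids[i] == audio_id
            input_embeddings[i][idx] = torch.concat(
                [start_of_audio, audio_embeddings[i, j], end_of_audio]
            )

print(input_embeddings.shape)  # torch.Size([2, 16, 16])

The batching here is the point of the change: the previous version fed audio_features to the encoder without flattening and recovered bs and max_num_img afterwards, while the new version reshapes up front so every clip goes through the encoder in one call and the per-clip splice only has to index the reshaped result.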