AlexHung29629
committed on
Commit
•
e8cfffd
1
Parent(s):
d1affd5
Update mllama_audio_model.py
Browse files- mllama_audio_model.py +4 -5
mllama_audio_model.py
CHANGED
@@ -35,12 +35,11 @@ class MllamaAudioModel(MllamaPreTrainedModel):
|
|
35 |
for i in range(bs):
|
36 |
for j in range(max_num_img):
|
37 |
audio_id = -1 - j
|
38 |
-
|
39 |
-
|
40 |
-
input_embeddings[i][idx] = torch.concat([self.start_of_audio, audio_features[i, j]
|
41 |
|
42 |
-
|
43 |
-
input_ids[idx].fill_(self.filler_token_id)
|
44 |
|
45 |
if return_dict:
|
46 |
return dict(input_embeddings=input_embeddings)
|
|
|
35 |
for i in range(bs):
|
36 |
for j in range(max_num_img):
|
37 |
audio_id = -1 - j
|
38 |
+
if torch.any(input_ids[i] == audio_id):
|
39 |
+
idx = input_ids[i] == audio_id
|
40 |
+
input_embeddings[i][idx] = torch.concat([self.start_of_audio, audio_features[i, j], self.end_of_audio])
|
41 |
|
42 |
+
input_ids[input_ids < 0].fill_(self.filler_token_id)
|
|
|
43 |
|
44 |
if return_dict:
|
45 |
return dict(input_embeddings=input_embeddings)
|