AlexHung29629 commited on
Commit
e8cfffd
1 Parent(s): d1affd5

Update mllama_audio_model.py

Browse files
Files changed (1) hide show
  1. mllama_audio_model.py +4 -5
mllama_audio_model.py CHANGED
@@ -35,12 +35,11 @@ class MllamaAudioModel(MllamaPreTrainedModel):
35
  for i in range(bs):
36
  for j in range(max_num_img):
37
  audio_id = -1 - j
38
- idx = torch.where(input_ids[i] == audio_id)
39
- if idx.numel() > 0:
40
- input_embeddings[i][idx] = torch.concat([self.start_of_audio, audio_features[i, j][idx], self.end_of_audio])
41
 
42
- idx = torch.where(input_ids < 0 and input_ids >= -max_num_img)
43
- input_ids[idx].fill_(self.filler_token_id)
44
 
45
  if return_dict:
46
  return dict(input_embeddings=input_embeddings)
 
35
  for i in range(bs):
36
  for j in range(max_num_img):
37
  audio_id = -1 - j
38
+ if torch.any(input_ids[i] == audio_id):
39
+ idx = input_ids[i] == audio_id
40
+ input_embeddings[i][idx] = torch.concat([self.start_of_audio, audio_features[i, j], self.end_of_audio])
41
 
42
+ input_ids[input_ids < 0].fill_(self.filler_token_id)
 
43
 
44
  if return_dict:
45
  return dict(input_embeddings=input_embeddings)