AlexHung29629
commited on
Commit
•
66a79e5
1
Parent(s):
a7d35c3
Update mllama_audio_model.py
Browse files- mllama_audio_model.py +1 -1
mllama_audio_model.py
CHANGED
@@ -15,7 +15,7 @@ class MllamaAudioModel(MllamaPreTrainedModel):
|
|
15 |
super().__init__(config)
|
16 |
assert config.add_adapter is True, f'{type(self).__name__} requires add adapter to be true.'
|
17 |
#assert config.output_hidden_size == text_embedding.weight.shape[1], f'Output hidden size({config.output_hidden_size}) of audio model and text embedding({text_embedding.weight.shape[1]}) must match!'
|
18 |
-
|
19 |
self.text_embedding = nn.Embedding(text_config.vocab_size + 8, text_config.hidden_size, text_config.pad_token_id)
|
20 |
self.audio_embedding = Wav2Vec2BertModel(config)
|
21 |
self.start_of_audio = nn.Parameter(data=torch.mean(text_embedding.weight, dim=0).unsqueeze(0), requires_grad=True)
|
|
|
15 |
super().__init__(config)
|
16 |
assert config.add_adapter is True, f'{type(self).__name__} requires add adapter to be true.'
|
17 |
#assert config.output_hidden_size == text_embedding.weight.shape[1], f'Output hidden size({config.output_hidden_size}) of audio model and text embedding({text_embedding.weight.shape[1]}) must match!'
|
18 |
+
assert config.output_hidden_size == text_config.hidden_size
|
19 |
self.text_embedding = nn.Embedding(text_config.vocab_size + 8, text_config.hidden_size, text_config.pad_token_id)
|
20 |
self.audio_embedding = Wav2Vec2BertModel(config)
|
21 |
self.start_of_audio = nn.Parameter(data=torch.mean(text_embedding.weight, dim=0).unsqueeze(0), requires_grad=True)
|