Update audio_text_multimodal.py
audio_text_multimodal.py (+17 -2)
@@ -14,6 +14,14 @@ from transformers import (
     Wav2Vec2Model
 )

+from transformers.models.wav2vec2.modeling_wav2vec2 import (
+    Wav2Vec2Encoder,
+    Wav2Vec2EncoderStableLayerNorm,
+    Wav2Vec2FeatureEncoder
+)
+
+from transformers.models.bert.modeling_bert import BertEncoder
+

 class MultiModalConfig(PretrainedConfig):
     """Base class for multimodal configs"""

@@ -170,7 +178,7 @@ class AudioTextModelForSequenceBaseClassification(BaseClassificationModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        audio_mean = self.merged_strategy(audio_output.last_hidden_state, mode=
+        audio_mean = self.merged_strategy(audio_output.last_hidden_state, mode=self.config.pooling_mode)

         pooled_output = torch.cat(
             (audio_mean, text_output.pooler_output), dim=1
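The pooling change above routes the configured pooling_mode into merged_strategy before the audio features are fused with the text features. merged_strategy itself is defined elsewhere in audio_text_multimodal.py and is not part of this diff; the following is only a minimal sketch, assuming a mode string that selects mean/sum/max pooling over the time axis of the Wav2Vec2 hidden states:

import torch

# Hypothetical re-implementation for illustration only; the repository's actual
# merged_strategy is not shown in this diff and may differ.
def merged_strategy(hidden_states: torch.Tensor, mode: str = "mean") -> torch.Tensor:
    # hidden_states: (batch, audio_frames, hidden_size) from Wav2Vec2Model.last_hidden_state
    if mode == "mean":
        return hidden_states.mean(dim=1)        # (batch, hidden_size)
    if mode == "sum":
        return hidden_states.sum(dim=1)
    if mode == "max":
        return hidden_states.max(dim=1).values
    raise ValueError(f"Unsupported pooling mode: {mode}")

Whatever the exact implementation, the result is one pooled vector per example, which the forward pass above concatenates with BERT's pooler_output so the classifier receives audio_config.hidden_size + text_config.hidden_size features.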
@@ -205,6 +213,8 @@ class Wav2Vec2BertForSequenceClassification(AudioTextModelForSequenceBaseClassification):
     """
     def __init__(self, config):
         super().__init__(config)
+        self.supports_gradient_checkpointing = getattr(config, "gradient_checkpointing", True)
+
         self.audio_config = Wav2Vec2Config.from_dict(self.config.Wav2Vec2Model)
         self.text_config = BertConfig.from_dict(self.config.BertModel)
         self.audio_model = Wav2Vec2Model(self.audio_config)

@@ -212,4 +222,9 @@ class Wav2Vec2BertForSequenceClassification(AudioTextModelForSequenceBaseClassification):
         self.classifier = torch.nn.Linear(
             self.audio_config.hidden_size + self.text_config.hidden_size, self.num_labels
         )
-        self.init_weights()
+        self.init_weights()
+
+    @staticmethod
+    def _set_gradient_checkpointing(module, value=False):
+        if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder, BertEncoder)):
+            module.gradient_checkpointing = value
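The new supports_gradient_checkpointing attribute and _set_gradient_checkpointing hook follow the pattern used by older transformers releases, where PreTrainedModel.gradient_checkpointing_enable() applies the static hook to every submodule. A minimal usage sketch, assuming the module is importable as audio_text_multimodal and that the config is built from nested Wav2Vec2/BERT config dicts as the __init__ above expects (the field values here are illustrative, not taken from this repository):

from transformers import BertConfig, Wav2Vec2Config

from audio_text_multimodal import MultiModalConfig, Wav2Vec2BertForSequenceClassification  # assumed import path

# Illustrative config; the nested dicts are restored in __init__ via
# Wav2Vec2Config.from_dict / BertConfig.from_dict as shown in the diff.
config = MultiModalConfig(
    Wav2Vec2Model=Wav2Vec2Config().to_dict(),
    BertModel=BertConfig().to_dict(),
    pooling_mode="mean",      # consumed by merged_strategy in the forward pass
    num_labels=2,
)

model = Wav2Vec2BertForSequenceClassification(config)

# On transformers versions where gradient_checkpointing_enable() still calls the
# static _set_gradient_checkpointing(module, value) hook for each submodule, this
# enables activation recomputation in the Wav2Vec2 and BERT encoders listed above.
model.gradient_checkpointing_enable()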