# wavlm-bert-fusion-s-emotion-russian-resd / audio_text_multimodal.py
from typing import Union, Type
import torch
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import (
PreTrainedModel,
PretrainedConfig,
WavLMConfig,
BertConfig,
WavLMModel,
BertModel,
Wav2Vec2Config,
Wav2Vec2Model
)
from transformers.models.wavlm.modeling_wavlm import (
WavLMEncoder,
WavLMEncoderStableLayerNorm,
WavLMFeatureEncoder
)
from transformers.models.bert.modeling_bert import BertEncoder
class MultiModalConfig(PretrainedConfig):
"""Base class for multimodal configs"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
class WavLMBertConfig(MultiModalConfig):
    """Config for the fused WavLM + BERT classifier.

    The rest of this file expects the following fields on the config: ``WavLMModel`` and
    ``BertModel`` (dicts with the underlying encoder configs), ``num_labels``, ``num_heads``,
    ``f_dropout``, ``kernel_size`` and ``pooling_mode``.
    """
class BaseClassificationModel(PreTrainedModel):
    config: Union[PretrainedConfig, None] = None
def compute_loss(self, logits, labels):
"""Compute loss
Args:
logits (torch.FloatTensor): logits
labels (torch.LongTensor): labels
Returns:
torch.FloatTensor: loss
Raises:
ValueError: Invalid number of labels
"""
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1:
self.config.problem_type = "single_label_classification"
else:
raise ValueError("Invalid number of labels: {}".format(self.num_labels))
if self.config.problem_type == "single_label_classification":
loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
            # hard-coded per-class weights (7 emotion classes), moved to the logits' device
            # rather than assuming CUDA is available
            loss_fct = torch.nn.BCEWithLogitsLoss(
                weight=torch.tensor(
                    [1.4411, 2.1129, 0.9927, 1.6995, 0.9038, 0.4126, 1.4150]
                ).to(logits.device)
            )
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
elif self.config.problem_type == "regression":
loss_fct = torch.nn.MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
raise ValueError("Problem_type {} not supported".format(self.config.problem_type))
return loss
@staticmethod
def merged_strategy(
hidden_states,
mode="mean"
):
"""Merged strategy for pooling
Args:
hidden_states (torch.FloatTensor): hidden states
mode (str, optional): pooling mode. Defaults to "mean".
Returns:
torch.FloatTensor: pooled hidden states
"""
if mode == "mean":
outputs = torch.mean(hidden_states, dim=1)
elif mode == "sum":
outputs = torch.sum(hidden_states, dim=1)
elif mode == "max":
outputs = torch.max(hidden_states, dim=1)[0]
        else:
            raise ValueError(
                "Unknown pooling mode '{}'; must be one of ['mean', 'sum', 'max']".format(mode)
            )
return outputs
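

# Illustrative sketch (not part of the original file): ``merged_strategy`` collapses the
# time axis of a (batch, time, hidden) tensor of frame embeddings into a single
# (batch, hidden) vector per sample. The shapes below are assumptions for illustration.
def _merged_strategy_example():
    hidden_states = torch.randn(4, 49, 768)  # e.g. 4 clips, 49 WavLM frames, hidden size 768
    pooled_mean = BaseClassificationModel.merged_strategy(hidden_states, mode="mean")
    pooled_max = BaseClassificationModel.merged_strategy(hidden_states, mode="max")
    assert pooled_mean.shape == pooled_max.shape == (4, 768)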
class AudioTextModelForSequenceBaseClassification(BaseClassificationModel):
config_class = WavLMBertConfig
def __init__(self, config):
"""
Args:
config (MultiModalConfig): config
Attributes:
config (MultiModalConfig): config
num_labels (int): number of labels
audio_config (Union[PretrainedConfig, None]): audio config
text_config (Union[PretrainedConfig, None]): text config
audio_model (Union[PreTrainedModel, None]): audio model
text_model (Union[PreTrainedModel, None]): text model
classifier (Union[torch.nn.Linear, None]): classifier
"""
super().__init__(config)
self.config = config
self.num_labels = self.config.num_labels
self.audio_config: Union[PretrainedConfig, None] = None
self.text_config: Union[PretrainedConfig, None] = None
self.audio_model: Union[PreTrainedModel, None] = None
self.text_model: Union[PreTrainedModel, None] = None
self.classifier: Union[torch.nn.Linear, None] = None
class FusionModuleQ(torch.nn.Module):
def __init__(self, audio_dim, text_dim, num_heads, dropout=0.1):
super().__init__()
self.dimension = min(audio_dim, text_dim)
        # cross-attention modules (queries from one modality, keys/values from the other);
        # dropout is applied inside the attention layers
        self.a_self_attention = torch.nn.MultiheadAttention(self.dimension, num_heads=num_heads, dropout=dropout)
        self.t_self_attention = torch.nn.MultiheadAttention(self.dimension, num_heads=num_heads, dropout=dropout)
# layer norm
self.audio_norm = torch.nn.LayerNorm(self.dimension)
self.text_norm = torch.nn.LayerNorm(self.dimension)
def forward(self, audio_output, text_output):
        # multi-head cross-attention: each modality uses the other as keys/values
audio_attn, _ = self.a_self_attention(audio_output, text_output, text_output)
text_attn, _ = self.t_self_attention(text_output, audio_output, audio_output)
        # residual connection followed by LayerNorm for each modality
audio_add = self.audio_norm(audio_output + audio_attn)
text_add = self.text_norm(text_output + text_attn)
return audio_add, text_add
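

# Illustrative sketch (not part of the original file): FusionModuleQ operates on the two
# pooled modality vectors. Each modality cross-attends to the other, then a residual
# connection and LayerNorm are applied per branch; the input shapes are preserved.
# The dimensions and head count below are assumptions for illustration.
def _fusion_module_example():
    fusion = FusionModuleQ(audio_dim=768, text_dim=768, num_heads=8)
    audio_vec = torch.randn(2, 768)  # pooled + projected WavLM output, one vector per sample
    text_vec = torch.randn(2, 768)   # BERT pooler_output
    audio_fused, text_fused = fusion(audio_vec, text_vec)
    assert audio_fused.shape == (2, 768) and text_fused.shape == (2, 768)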
class AudioTextFusionModelForSequenceClassification(AudioTextModelForSequenceBaseClassification):
def __init__(self, config):
"""
Args:
config (MultiModalConfig): config
Attributes:
fusion_module_1 (FusionModuleQ): Fusion Module Q 1
fusion_module_2 (FusionModuleQ): Fusion Module Q 2
audio_projector (Union[torch.nn.Linear, None]): Projection layer for audio embeds
text_projector (Union[torch.nn.Linear, None]): Projection layer for text embeds
audio_avg_pool (Union[torch.nn.AvgPool1d, None]): Audio average pool (out from fusion block)
text_avg_pool (Union[torch.nn.AvgPool1d, None]): Text average pool (out from fusion block)
"""
super().__init__(config)
self.fusion_module_1: Union[FusionModuleQ, None] = None
self.fusion_module_2: Union[FusionModuleQ, None] = None
self.audio_projector: Union[torch.nn.Linear, None] = None
self.text_projector: Union[torch.nn.Linear, None] = None
self.audio_avg_pool: Union[torch.nn.AvgPool1d, None] = None
self.text_avg_pool: Union[torch.nn.AvgPool1d, None] = None
class WavLMBertForSequenceClassification(AudioTextFusionModelForSequenceClassification):
    """
    WavLMBertForSequenceClassification is a multimodal (audio + text) model for sequence
    classification, e.g. speech emotion recognition, intended for fine-tuning.
Args:
config (WavLMBertConfig): config
Attributes:
config (WavLMBertConfig): config
audio_config (WavLMConfig): wavlm config
text_config (BertConfig): bert config
audio_model (WavLMModel): wavlm model
text_model (BertModel): bert model
fusion_module_1 (FusionModuleQ): Fusion Module Q 1
fusion_module_2 (FusionModuleQ): Fusion Module Q 2
audio_projector (Union[torch.nn.Linear, None]): Projection layer for audio embeds
text_projector (Union[torch.nn.Linear, None]): Projection layer for text embeds
audio_avg_pool (Union[torch.nn.AvgPool1d, None]): Audio average pool (out from fusion block)
text_avg_pool (Union[torch.nn.AvgPool1d, None]): Text average pool (out from fusion block)
classifier (torch.nn.Linear): classifier
"""
def __init__(self, config):
super().__init__(config)
self.supports_gradient_checkpointing = getattr(config, "gradient_checkpointing", True)
self.audio_config = WavLMConfig.from_dict(self.config.WavLMModel)
self.text_config = BertConfig.from_dict(self.config.BertModel)
self.audio_model = WavLMModel(self.audio_config)
self.text_model = BertModel(self.text_config)
        # fusion modules with the V3 strategy: a single projection at the entry, no further projections between the stacked fusion blocks
self.fusion_module_1 = FusionModuleQ(self.audio_config.hidden_size, self.text_config.hidden_size,
self.config.num_heads, self.config.f_dropout)
self.fusion_module_2 = FusionModuleQ(self.audio_config.hidden_size, self.text_config.hidden_size,
self.config.num_heads, self.config.f_dropout)
self.audio_projector = torch.nn.Linear(self.audio_config.hidden_size, self.text_config.hidden_size)
self.text_projector = torch.nn.Linear(self.text_config.hidden_size, self.text_config.hidden_size)
# Avg Pool
self.audio_avg_pool = torch.nn.AvgPool1d(self.config.kernel_size)
self.text_avg_pool = torch.nn.AvgPool1d(self.config.kernel_size)
        # WavLM and BERT hidden sizes may differ (e.g. 1024 vs. 768); the classifier operates on the smaller of the two
cls_dim = min(self.audio_config.hidden_size, self.text_config.hidden_size)
self.classifier = torch.nn.Linear(
(cls_dim * 2) // self.config.kernel_size, self.config.num_labels
)
self.init_weights()
@staticmethod
def _set_gradient_checkpointing(module, value=False):
if isinstance(module, (WavLMEncoder, WavLMEncoderStableLayerNorm, WavLMFeatureEncoder, BertEncoder)):
module.gradient_checkpointing = value
def forward(
self,
input_ids=None,
input_values=None,
text_attention_mask=None,
audio_attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=True,
):
"""Forward method for multimodal model for sequence classification task (e.g. text + audio)
Args:
input_ids (torch.LongTensor, optional): input ids. Defaults to None.
input_values (torch.FloatTensor, optional): input values. Defaults to None.
text_attention_mask (torch.LongTensor, optional): text attention mask. Defaults to None.
audio_attention_mask (torch.LongTensor, optional): audio attention mask. Defaults to None.
token_type_ids (torch.LongTensor, optional): token type ids. Defaults to None.
position_ids (torch.LongTensor, optional): position ids. Defaults to None.
head_mask (torch.FloatTensor, optional): head mask. Defaults to None.
inputs_embeds (torch.FloatTensor, optional): inputs embeds. Defaults to None.
labels (torch.LongTensor, optional): labels. Defaults to None.
output_attentions (bool, optional): output attentions. Defaults to None.
output_hidden_states (bool, optional): output hidden states. Defaults to None.
return_dict (bool, optional): return dict. Defaults to True.
        Returns:
            SequenceClassifierOutput: output with logits and, if labels are provided, the loss
"""
audio_output = self.audio_model(
input_values=input_values,
attention_mask=audio_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
text_output = self.text_model(
input_ids=input_ids,
attention_mask=text_attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
        # pool the audio hidden states over time using the configured pooling mode
audio_avg = self.merged_strategy(audio_output.last_hidden_state, mode=self.config.pooling_mode)
# Projection
audio_proj = self.audio_projector(audio_avg)
text_proj = self.text_projector(text_output.pooler_output)
audio_mha, text_mha = self.fusion_module_1(audio_proj, text_proj)
audio_mha, text_mha = self.fusion_module_2(audio_mha, text_mha)
audio_avg = self.audio_avg_pool(audio_mha)
text_avg = self.text_avg_pool(text_mha)
fusion_output = torch.concat((audio_avg, text_avg), dim=1)
logits = self.classifier(fusion_output)
loss = None
if labels is not None:
loss = self.compute_loss(logits, labels)
return SequenceClassifierOutput(
loss=loss,
logits=logits
)
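

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the released model code): build a small
    # config from default WavLM/BERT base configs and run a forward pass on random inputs.
    # All hyperparameter values below are assumptions, not the released checkpoint's values.
    audio_cfg = WavLMConfig()
    text_cfg = BertConfig()
    cfg = WavLMBertConfig(
        WavLMModel=audio_cfg.to_dict(),
        BertModel=text_cfg.to_dict(),
        num_labels=7,          # e.g. the 7 emotion classes of RESD
        num_heads=8,
        f_dropout=0.1,
        kernel_size=2,
        pooling_mode="mean",
    )
    model = WavLMBertForSequenceClassification(cfg)
    model.eval()
    with torch.no_grad():
        outputs = model(
            input_ids=torch.randint(0, text_cfg.vocab_size, (2, 16)),  # dummy token ids
            input_values=torch.randn(2, 16000),                        # ~1 s of 16 kHz audio
        )
    print(outputs.logits.shape)  # torch.Size([2, 7])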