# seamless-m4t-v2-large-speech-encoder / modeling_seamless_m4t_v2_speech_encoder.py
# Author: fdschmidt93
# Last commit: ee8c043 ("chore: formatting")
# Size: 4.73 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.seamless_m4t.modeling_seamless_m4t import (
_compute_new_attention_mask,
)
from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2 import (
SeamlessM4Tv2SpeechEncoder,
SeamlessM4Tv2PreTrainedModel,
)
from .configuration_seamless_m4t_v2_speech_encoder import (
MODEL_TYPE,
SeamlessM4Tv2EncoderConfig,
)
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.auto import (
AutoModel,
AutoModelForAudioClassification,
AutoModelForSequenceClassification,
)
class SeamlessM4Tv2SpeechEncoder(SeamlessM4Tv2SpeechEncoder):
    """Upstream SeamlessM4Tv2 speech encoder bound to this package's config.

    The class deliberately reuses (and, at module level, shadows) the name of
    the ``transformers`` class it extends; the base is resolved before the new
    name is bound, so the inheritance is well-defined. It only adds
    ``model_type`` / ``config_class`` metadata and a ``mean_pooling`` helper;
    construction is inherited unchanged from the base class.
    """

    model_type = MODEL_TYPE
    config_class = SeamlessM4Tv2EncoderConfig

    @staticmethod
    def mean_pooling(
        hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Average ``hidden_states`` over the sequence axis, skipping masked frames.

        Args:
            hidden_states: Tensor of shape ``(batch_size, sequence_length, hidden_size)``.
            attention_mask: Tensor of shape ``(batch_size, sequence_length)``;
                non-zero entries mark frames to include. ``None`` averages
                over every frame.

        Returns:
            Tensor of shape ``(batch_size, hidden_size)`` with the same dtype as
            ``hidden_states``. Rows whose mask is entirely zero yield zeros
            (the denominator is clamped instead of dividing by zero).
        """
        if attention_mask is None:
            return hidden_states.mean(dim=1)
        # Accumulate in at least float32: summing many half-precision values is
        # lossy, and 1e-9 underflows to 0 in float16, which would make the
        # clamp useless. The result is cast back so that downstream layers
        # stored in the same dtype as ``hidden_states`` still accept it.
        compute_dtype = torch.promote_types(hidden_states.dtype, torch.float32)
        mask = attention_mask.unsqueeze(-1).to(
            device=hidden_states.device, dtype=compute_dtype
        )
        summed = (hidden_states.to(compute_dtype) * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return (summed / counts).to(hidden_states.dtype)
class SeamlessM4Tv2ForAudioClassification(SeamlessM4Tv2PreTrainedModel):
    """Utterance-level classification / regression head on the speech encoder.

    The encoder's last hidden state is mean-pooled over valid (non-padded)
    frames and projected to ``config.num_labels`` logits by a bias-free linear
    layer. The loss follows ``config.problem_type``, which is inferred from
    ``num_labels`` and the label dtype on first use when it is unset.
    """

    model_type = MODEL_TYPE
    base_model_prefix = "model"
    config_class = SeamlessM4Tv2EncoderConfig

    def __init__(self, config, *args, **kwargs):
        # Extra positional/keyword arguments are accepted but ignored.
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = SeamlessM4Tv2SpeechEncoder(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        # Standard Transformers finalisation step (initialises the newly added
        # ``score`` layer and runs the post-construction hooks).
        self.post_init()

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        *args,
        **kwargs,
    ) -> SequenceClassifierOutput:
        """Encode, pool and score a batch of speech features.

        Args:
            input_features: Speech features, forwarded unchanged to the encoder.
            attention_mask: Optional input-frame mask (1 = valid, 0 = padding).
                When ``None`` every encoder output frame is averaged.
            labels: Optional targets. When given, ``loss`` is computed.
            *args, **kwargs: Forwarded to the encoder. ``output_hidden_states``
                is popped and honoured; ``return_dict`` is forced to ``True``
                because attributes of the encoder output are read here.

        Returns:
            ``SequenceClassifierOutput`` holding the following fields:
            ``logits`` of shape ``(batch_size, num_labels)``; ``loss``, which
            is ``None`` when no labels are given; ``hidden_states``, set only
            if requested; and the encoder ``attentions``, when it produced any.
        """
        output_hidden_states = kwargs.pop("output_hidden_states", False)
        kwargs["return_dict"] = True
        outputs = self.model(
            input_features,
            attention_mask,
            output_hidden_states=output_hidden_states,
            *args,
            **kwargs,
        )
        pooled = self._pool(outputs.last_hidden_state, attention_mask)
        logits = self.score(pooled)
        loss = None if labels is None else self._compute_loss(logits, labels)
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=getattr(outputs, "attentions", None),
        )

    def _pool(
        self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None
    ) -> torch.Tensor:
        """Mean-pool encoder output frames that do not correspond to padding."""
        if attention_mask is None:
            # No padding information: treat every frame as valid.
            return hidden_states.mean(dim=1)
        # NOTE(review): assumes, as in upstream SeamlessM4Tv2SpeechEncoder, that
        # only the adapter (enabled by ``config.add_adapter``) shortens the
        # sequence. Defaults to sub-sampling if the config lacks the flag.
        if getattr(self.config, "add_adapter", True):
            # Re-derive the mask for the sub-sampled output frames.
            sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(
                attention_mask
            ).to(hidden_states.device)
            attention_mask = _compute_new_attention_mask(
                hidden_states=hidden_states, seq_lens=sub_sampled_lengths
            )
        else:
            # Output length equals input length, so the mask already lines up.
            attention_mask = attention_mask.to(hidden_states.device)
        return self.model.mean_pooling(hidden_states, attention_mask)

    def _compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the loss for ``config.problem_type``, inferring it if unset.

        Raises:
            ValueError: if ``config.problem_type`` is not one of the supported values.
        """
        # move labels to correct device to enable model parallelism
        labels = labels.to(logits.device)
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"
        problem_type = self.config.problem_type
        if problem_type == "regression":
            if self.num_labels == 1:
                return F.mse_loss(logits.squeeze(), labels.squeeze())
            return F.mse_loss(logits, labels)
        if problem_type == "single_label_classification":
            return F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))
        if problem_type == "multi_label_classification":
            return F.binary_cross_entropy_with_logits(logits, labels)
        raise ValueError(f"Unsupported problem_type: {problem_type!r}")
# Make the custom classes loadable through the Auto* factories for
# ``SeamlessM4Tv2EncoderConfig``. The audio classifier is also registered for
# AutoModelForSequenceClassification, presumably so that callers using that
# auto class can load it too.
for _auto_cls, _model_cls in (
    (AutoModel, SeamlessM4Tv2SpeechEncoder),
    (AutoModelForAudioClassification, SeamlessM4Tv2ForAudioClassification),
    (AutoModelForSequenceClassification, SeamlessM4Tv2ForAudioClassification),
):
    _auto_cls.register(SeamlessM4Tv2EncoderConfig, _model_cls)
del _auto_cls, _model_cls