# seamless-m4t-v2-large-speech-encoder / modeling_seamless_m4t_v2_speech_encoder.py
# Commit 65a0eff by fdschmidt93: "initial commit" (3.85 kB).
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.seamless_m4t.modeling_seamless_m4t import (
_compute_new_attention_mask,
)
from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2 import (
SeamlessM4Tv2SpeechEncoder,
SeamlessM4Tv2PreTrainedModel,
)
from .configuration_seamless_m4t_v2_speech_encoder import (
MODEL_TYPE,
SeamlessM4Tv2EncoderConfig,
)
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.auto import AutoModel, AutoModelForAudioClassification, AutoModelForSequenceClassification
class SeamlessM4Tv2SpeechEncoder(SeamlessM4Tv2SpeechEncoder):
    """SeamlessM4T v2 speech encoder bound to ``SeamlessM4Tv2EncoderConfig``.

    Subclasses (and intentionally shadows) the upstream ``transformers`` class of
    the same name so it can be registered with the Auto* factories under this
    file's ``MODEL_TYPE``, and adds helpers used by the classification head.
    """

    model_type = MODEL_TYPE
    config_class = SeamlessM4Tv2EncoderConfig

    def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask):
        """Return the per-example sequence length after convolutional subsampling.

        The number of valid (non-zero) positions in ``attention_mask`` is passed
        through the standard conv output-length formula with padding
        ``kernel_size // 2``.

        Args:
            attention_mask: ``(batch_size, sequence_length)`` mask; non-zero
                entries mark valid input positions.

        Returns:
            A float tensor of shape ``(batch_size,)`` holding floored lengths.
        """
        # NOTE(review): upstream, ``kernel_size``/``stride`` are attributes of the
        # adapter layers, not of the speech encoder, so ``self.kernel_size`` is
        # normally missing here. Prefer the attributes if something sets them;
        # otherwise use the adaptor settings from the config.
        kernel_size = getattr(self, "kernel_size", None)
        if kernel_size is None:
            kernel_size = self.config.adaptor_kernel_size
        stride = getattr(self, "stride", None)
        if stride is None:
            stride = self.config.adaptor_stride

        pad = kernel_size // 2
        # Count valid positions: total length minus the number of zero entries.
        seq_lens = attention_mask.size(1) - (1 - attention_mask.int()).sum(1)
        seq_lens = ((seq_lens + 2 * pad - kernel_size) / stride) + 1
        return seq_lens.floor()

    @staticmethod
    def mean_pooling(
        hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Average ``hidden_states`` over the sequence axis, weighted by ``attention_mask``.

        Args:
            hidden_states: ``(batch_size, sequence_length, hidden_size)``.
            attention_mask: ``(batch_size, sequence_length)``; presumably 1 for
                positions to keep and 0 for padding.

        Returns:
            ``(batch_size, hidden_size)``. For floating-point inputs the result
            has the same dtype as ``hidden_states``.
        """
        # Accumulate in at least float32. In half precision a 1e-9 floor would
        # round to 0, and long sums could overflow.
        compute_dtype = torch.promote_types(hidden_states.dtype, torch.float32)
        weights = (
            attention_mask.unsqueeze(-1).expand(hidden_states.size()).to(compute_dtype)
        )
        summed = torch.sum(hidden_states.to(compute_dtype) * weights, dim=1)
        # Clamp so fully-masked rows yield 0 instead of dividing by zero.
        counts = torch.clamp(weights.sum(dim=1), min=1e-9)
        pooled = summed / counts
        # Cast back so fp16/bf16 encoders feed a matching-dtype tensor to the head.
        if hidden_states.is_floating_point():
            pooled = pooled.to(hidden_states.dtype)
        return pooled
class SeamlessM4Tv2ForAudioClassification(SeamlessM4Tv2PreTrainedModel):
    """SeamlessM4T v2 speech encoder with masked mean pooling and a linear head.

    The encoder is stored as ``self.model`` (matching ``base_model_prefix``). Its
    last hidden state is mean-pooled over valid frames and projected to
    ``config.num_labels`` logits by a bias-free linear layer.
    """

    model_type = MODEL_TYPE
    base_model_prefix = "model"
    config_class = SeamlessM4Tv2EncoderConfig

    def __init__(self, config, *args, **kwargs):
        # ``*args``/``**kwargs`` are accepted for signature compatibility and ignored.
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = SeamlessM4Tv2SpeechEncoder(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        # Hugging Face convention for PreTrainedModel subclasses: run weight
        # initialisation / final setup once all submodules exist.
        self.post_init()

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        *args,
        **kwargs,
    ) -> SequenceClassifierOutput:
        """Classify a batch of speech features.

        Args:
            input_features: Encoder input features, passed straight to ``self.model``.
            attention_mask: Optional input-resolution padding mask. When omitted,
                all encoder frames are averaged uniformly.
            labels: Optional class indices; when given, cross-entropy loss is returned.
            *args, **kwargs: Forwarded to the encoder. ``output_hidden_states``
                is honoured; ``return_dict`` is overridden to ``True``.

        Returns:
            ``SequenceClassifierOutput`` with ``loss`` (or ``None``), ``logits``,
            and optionally ``hidden_states`` / ``attentions``.
        """
        output_hidden_states = kwargs.pop("output_hidden_states", False)
        # Pooling reads ``outputs.last_hidden_state``, so always request a
        # ModelOutput from the encoder even if the caller passed return_dict=False.
        kwargs.pop("return_dict", None)
        outputs = self.model(
            input_features,
            attention_mask,
            *args,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            **kwargs,
        )
        last_hidden_state = outputs.last_hidden_state

        if attention_mask is None:
            # No padding information: every encoder frame counts equally.
            pooled = last_hidden_state.mean(dim=1)
        else:
            # The encoder may subsample the time axis, so rebuild the mask at the
            # output resolution. This helper is not defined on this class;
            # presumably it is inherited from SeamlessM4Tv2PreTrainedModel.
            # NOTE(review): if the config disables the adapter, no subsampling
            # happens and the original mask may already be correct; confirm.
            sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(
                attention_mask
            ).to(last_hidden_state.device)
            pooled_mask = _compute_new_attention_mask(
                hidden_states=last_hidden_state, seq_lens=sub_sampled_lengths
            )
            pooled = self.model.mean_pooling(last_hidden_state, pooled_mask)

        # Pooling may upcast (e.g. to float32); match the head's weight dtype so
        # half-precision models do not hit a dtype mismatch in the linear layer.
        logits = self.score(pooled.to(self.score.weight.dtype))

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels.to(logits.device))

        return SequenceClassifierOutput(
            loss=loss,  # type: ignore
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=getattr(outputs, "attentions", None),
        )
# Make the custom config resolvable through the Auto* factories: the bare
# encoder for AutoModel, and the classification head for both task-specific
# auto classes.
AutoModel.register(SeamlessM4Tv2EncoderConfig, SeamlessM4Tv2SpeechEncoder)
for _task_auto_class in (
    AutoModelForAudioClassification,
    AutoModelForSequenceClassification,
):
    _task_auto_class.register(
        SeamlessM4Tv2EncoderConfig, SeamlessM4Tv2ForAudioClassification
    )
# Do not leak the loop variable into the module namespace.
del _task_auto_class