import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.auto import (
    AutoModel,
    AutoModelForAudioClassification,
    AutoModelForSequenceClassification,
)
from transformers.models.seamless_m4t.modeling_seamless_m4t import (
    _compute_new_attention_mask,
)
from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2 import (
    SeamlessM4Tv2PreTrainedModel,
    SeamlessM4Tv2SpeechEncoder,
)

from .configuration_seamless_m4t_v2_speech_encoder import (
    MODEL_TYPE,
    SeamlessM4Tv2EncoderConfig,
)


class SeamlessM4Tv2SpeechEncoder(SeamlessM4Tv2SpeechEncoder):
    """Standalone speech encoder bound to a custom config class.

    The class deliberately shadows the imported base class name: the base is
    resolved before the name is rebound, so the inheritance is valid.
    """

    model_type = MODEL_TYPE
    config_class = SeamlessM4Tv2EncoderConfig

    def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask):
        # Mirror the upstream implementation: the adaptor's kernel size and
        # stride live on the config, not on the module itself.
        kernel_size = self.config.adaptor_kernel_size
        stride = self.config.adaptor_stride
        pad = kernel_size // 2
        seq_lens = attention_mask.size(1) - (1 - attention_mask.int()).sum(1)
        seq_lens = ((seq_lens + 2 * pad - kernel_size) / stride) + 1
        return seq_lens.floor()

    @staticmethod
    def mean_pooling(
        hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        # hidden_states: (batch_size, sequence_length, hidden_size)
        # attention_mask: (batch_size, sequence_length)
        # Zero out padded positions before averaging; clamp the denominator
        # to avoid division by zero for fully padded rows.
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        )
        sum_hidden_states = torch.sum(hidden_states * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        return sum_hidden_states / torch.clamp(sum_mask, min=1e-9)


class SeamlessM4Tv2ForAudioClassification(SeamlessM4Tv2PreTrainedModel):
    model_type = MODEL_TYPE
    base_model_prefix = "model"
    config_class = SeamlessM4Tv2EncoderConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = SeamlessM4Tv2SpeechEncoder(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        # Initialize weights and apply final processing (standard HF hook).
        self.post_init()

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        *args,
        **kwargs,
    ) -> SequenceClassifierOutput:
        output_hidden_states = kwargs.pop("output_hidden_states", False)
        outputs = self.model(
            input_features,
            attention_mask,
            *args,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state

        if attention_mask is not None:
            # The encoder's adaptor sub-samples the time axis, so the original
            # mask no longer lines up with the hidden states; recompute it.
            sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(
                attention_mask
            ).to(hidden_states.device)
            attention_mask = _compute_new_attention_mask(
                hidden_states=hidden_states, seq_lens=sub_sampled_lengths
            )
            hidden_states = self.model.mean_pooling(hidden_states, attention_mask)
        else:
            # No mask: every frame is valid, so a plain mean over time suffices.
            hidden_states = hidden_states.mean(dim=1)

        logits = self.score(hidden_states)

        loss = F.cross_entropy(logits, labels) if labels is not None else None

        return SequenceClassifierOutput(
            loss=loss,  # type: ignore
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
        )


AutoModel.register(SeamlessM4Tv2EncoderConfig, SeamlessM4Tv2SpeechEncoder)
AutoModelForAudioClassification.register(
    SeamlessM4Tv2EncoderConfig, SeamlessM4Tv2ForAudioClassification
)
AutoModelForSequenceClassification.register(
    SeamlessM4Tv2EncoderConfig, SeamlessM4Tv2ForAudioClassification
)
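

# Minimal smoke-test sketch, not part of the module's public surface. It
# assumes SeamlessM4Tv2EncoderConfig subclasses the upstream SeamlessM4Tv2
# config and therefore exposes `feature_projection_input_dim` and the adaptor
# fields; the shapes and label values below are placeholders. Note this builds
# a full-size, randomly initialized encoder, so it is slow but self-contained.
if __name__ == "__main__":
    config = SeamlessM4Tv2EncoderConfig(num_labels=2)
    model = SeamlessM4Tv2ForAudioClassification(config)
    model.eval()

    # Dummy batch: 200 feature frames, each of the config's projection input dim.
    input_features = torch.randn(1, 200, config.feature_projection_input_dim)
    attention_mask = torch.ones(1, 200, dtype=torch.long)

    with torch.no_grad():
        out = model(
            input_features,
            attention_mask=attention_mask,
            labels=torch.tensor([0]),
        )
    print(out.logits.shape)  # torch.Size([1, 2])
    print(out.loss)          # scalar cross-entropy loss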