import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.auto import (
    AutoModel,
    AutoModelForAudioClassification,
    AutoModelForSequenceClassification,
)
from transformers.models.seamless_m4t.modeling_seamless_m4t import (
    _compute_new_attention_mask,
)
from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2 import (
    SeamlessM4Tv2PreTrainedModel,
    SeamlessM4Tv2SpeechEncoder,
)

from .configuration_seamless_m4t_v2_speech_encoder import (
    MODEL_TYPE,
    SeamlessM4Tv2EncoderConfig,
)


class SeamlessM4Tv2SpeechEncoder(SeamlessM4Tv2SpeechEncoder):
    # Intentionally shadows the imported encoder class so it can be
    # registered with the auto classes under the custom config below.
    model_type = MODEL_TYPE
    config_class = SeamlessM4Tv2EncoderConfig

    @staticmethod
    def mean_pooling(
        hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        # hidden_states: (batch_size, sequence_length, hidden_size)
        # attention_mask: (batch_size, sequence_length)
        # Zero out padding positions, then average over the time axis,
        # clamping the denominator to avoid division by zero.
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        )
        sum_hidden_states = torch.sum(hidden_states * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        return sum_hidden_states / torch.clamp(sum_mask, min=1e-9)


class SeamlessM4Tv2ForAudioClassification(SeamlessM4Tv2PreTrainedModel):
    model_type = MODEL_TYPE
    base_model_prefix = "model"
    config_class = SeamlessM4Tv2EncoderConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = SeamlessM4Tv2SpeechEncoder(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor | None = None,
        *args,
        **kwargs,
    ) -> SequenceClassifierOutput:
        output_hidden_states = kwargs.pop("output_hidden_states", False)
        outputs = self.model(
            input_features,
            attention_mask,
            *args,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state

        if attention_mask is not None:
            # The encoder's adapter sub-samples the time axis, so the input
            # attention mask no longer lines up with the hidden states;
            # recompute it for the sub-sampled sequence lengths.
            sub_sampled_lengths = self._compute_sub_sample_lengths_from_attention_mask(
                attention_mask
            ).to(hidden_states.device)
            attention_mask = _compute_new_attention_mask(
                hidden_states=hidden_states, seq_lens=sub_sampled_lengths
            )

        hidden_states = self.model.mean_pooling(hidden_states, attention_mask)
        logits = self.score(hidden_states)

        loss = None
        if labels is not None:
            # Move labels to the logits device to enable model parallelism.
            labels = labels.to(logits.device)

            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = F.mse_loss(logits.squeeze(), labels.squeeze())
                else:
                    loss = F.mse_loss(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = F.cross_entropy(
                    logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss = F.binary_cross_entropy_with_logits(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
        )
AutoModel.register(SeamlessM4Tv2EncoderConfig, SeamlessM4Tv2SpeechEncoder)
AutoModelForAudioClassification.register(
    SeamlessM4Tv2EncoderConfig, SeamlessM4Tv2ForAudioClassification
)
AutoModelForSequenceClassification.register(
    SeamlessM4Tv2EncoderConfig, SeamlessM4Tv2ForAudioClassification
)
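
# Usage sketch: once this module is imported, the registrations above let the
# auto classes resolve SeamlessM4Tv2EncoderConfig to the custom classes. The
# checkpoint path below is a hypothetical placeholder, and this assumes the
# checkpoint was saved with a SeamlessM4T feature extractor alongside it.
if __name__ == "__main__":
    from transformers import AutoFeatureExtractor

    checkpoint = "path/to/finetuned-checkpoint"  # hypothetical path
    extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = AutoModelForAudioClassification.from_pretrained(checkpoint)
    model.eval()

    # SeamlessM4Tv2 expects 16 kHz mono audio; a second of random noise
    # stands in for a real waveform here.
    waveform = torch.randn(16000)
    inputs = extractor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class = logits.argmax(-1).item()
    print(predicted_class)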