import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.auto import (
    AutoModel,
    AutoModelForAudioClassification,
    AutoModelForSequenceClassification,
)
from transformers.models.seamless_m4t.modeling_seamless_m4t import (
    _compute_new_attention_mask,
)
from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2 import (
    SeamlessM4Tv2PreTrainedModel,
    SeamlessM4Tv2SpeechEncoder as SeamlessM4Tv2SpeechEncoderBase,
)

from .configuration_seamless_m4t_v2_speech_encoder import (
    MODEL_TYPE,
    SeamlessM4Tv2EncoderConfig,
)

class SeamlessM4Tv2SpeechEncoder(SeamlessM4Tv2SpeechEncoderBase):
    """Standalone SeamlessM4T v2 speech encoder usable with the custom config."""

    model_type = MODEL_TYPE
    config_class = SeamlessM4Tv2EncoderConfig

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Length-adaptor geometry, needed below to map input-frame counts to
        # sub-sampled output lengths. Assumes the custom config keeps
        # SeamlessM4Tv2Config's `adaptor_kernel_size`/`adaptor_stride` fields.
        self.kernel_size = self.config.adaptor_kernel_size
        self.stride = self.config.adaptor_stride

    def _compute_sub_sample_lengths_from_attention_mask(self, attention_mask):
        pad = self.kernel_size // 2
        # Number of non-padded frames per batch element.
        seq_lens = attention_mask.size(1) - (1 - attention_mask.int()).sum(1)
        # Standard 1D-convolution output-length formula:
        # floor((length + 2 * pad - kernel_size) / stride) + 1.
        seq_lens = ((seq_lens + 2 * pad - self.kernel_size) / self.stride) + 1
        return seq_lens.floor()
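
    # Worked example (illustrative numbers, not from the original module): with
    # adaptor_kernel_size=8 and adaptor_stride=8, an input with 100 valid
    # frames gives pad = 4 and floor((100 + 2 * 4 - 8) / 8) + 1 = 13 positions.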

    @staticmethod
    def mean_pooling(
        hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Mean-pool hidden states over time, counting only unmasked frames.

        Padded positions (mask value 0) contribute neither to the sum nor to
        the divisor, so padding does not dilute the pooled representation.
        """
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        )
        sum_hidden_states = torch.sum(hidden_states * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        # Clamp guards against division by zero for fully masked sequences.
        return sum_hidden_states / torch.clamp(sum_mask, min=1e-9)
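
# Usage sketch for the pooling helper (illustrative values, not part of the
# original module):
#
#     hidden = torch.randn(2, 6, 8)                 # (batch, frames, hidden)
#     mask = torch.tensor([[1, 1, 1, 1, 0, 0],
#                          [1, 1, 1, 1, 1, 1]])     # 1 = valid frame
#     pooled = SeamlessM4Tv2SpeechEncoder.mean_pooling(hidden, mask)
#     pooled.shape                                  # torch.Size([2, 8])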


class SeamlessM4Tv2ForAudioClassification(SeamlessM4Tv2PreTrainedModel):
    model_type = MODEL_TYPE
    base_model_prefix = "model"
    config_class = SeamlessM4Tv2EncoderConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.model = SeamlessM4Tv2SpeechEncoder(config)
        # Linear classification head over the mean-pooled encoder output.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        # Standard Hugging Face weight initialization.
        self.post_init()

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        *args,
        **kwargs,
    ):
        output_hidden_states = kwargs.pop("output_hidden_states", False)
        outputs = self.model(
            input_features,
            attention_mask,
            *args,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state
        if attention_mask is not None:
            # The encoder's length adaptor sub-samples the time axis, so the
            # frame-level mask must be recomputed at the new resolution before
            # pooling. The helper lives on the encoder, not on this wrapper.
            sub_sampled_lengths = (
                self.model._compute_sub_sample_lengths_from_attention_mask(
                    attention_mask
                ).to(hidden_states.device)
            )
            attention_mask = _compute_new_attention_mask(
                hidden_states=hidden_states, seq_lens=sub_sampled_lengths
            )
            hidden_states = self.model.mean_pooling(hidden_states, attention_mask)
        else:
            # Without a mask, every frame is valid: plain mean over time.
            hidden_states = hidden_states.mean(dim=1)
        logits = self.score(hidden_states)
        loss = F.cross_entropy(logits, labels) if labels is not None else None
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
        )


# Expose the encoder and the classification wrapper through the Auto* factories
# so that `AutoModel.from_pretrained(...)` & co. resolve the custom config.
AutoModel.register(SeamlessM4Tv2EncoderConfig, SeamlessM4Tv2SpeechEncoder)
AutoModelForAudioClassification.register(
    SeamlessM4Tv2EncoderConfig, SeamlessM4Tv2ForAudioClassification
)
AutoModelForSequenceClassification.register(
    SeamlessM4Tv2EncoderConfig, SeamlessM4Tv2ForAudioClassification
)
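

# Minimal smoke test: a sketch, not part of the original module. It assumes
# `SeamlessM4Tv2EncoderConfig` exposes the same speech-encoder fields as
# SeamlessM4Tv2Config (`speech_encoder_layers`, `feature_projection_input_dim`)
# plus `num_labels`; adjust the keywords to whatever the real config defines.
if __name__ == "__main__":
    config = SeamlessM4Tv2EncoderConfig(
        speech_encoder_layers=2,  # assumed field; keeps the sanity check small
        num_labels=4,
    )
    model = SeamlessM4Tv2ForAudioClassification(config).eval()

    batch, frames = 2, 50
    # SeamlessM4T speech inputs are log-mel features with
    # `feature_projection_input_dim` channels per frame.
    input_features = torch.randn(batch, frames, config.feature_projection_input_dim)
    attention_mask = torch.ones(batch, frames, dtype=torch.long)
    attention_mask[0, 30:] = 0  # pretend the first clip is shorter

    with torch.no_grad():
        out = model(input_features, attention_mask, labels=torch.tensor([0, 3]))
    print(out.logits.shape)  # torch.Size([2, 4])
    print(out.loss)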