import math
import warnings
from typing import Union, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput, Wav2Vec2BaseModelOutput
from transformers.models.wavlm.modeling_wavlm import (
    WavLMGumbelVectorQuantizer,
    WavLMPositionalConvEmbedding,
    WavLMFeatureProjection,
    WavLMFeatureEncoder,
    WavLMEncoderStableLayerNorm,
    WavLMEncoder,
    WavLMAdapter,
    _HIDDEN_STATES_START_POSITION,
)

from .configuration_wavlm_spkreg import WavLMSpkRegConfig


def _compute_mask_indices(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[torch.LongTensor] = None,
    min_masks: int = 0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                    independently generated mask spans of length `mask_length` is computed by
                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                    actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.

    Returns:
        A boolean numpy array of shape `(batch_size, sequence_length)` in which `True` marks masked positions.

    Raises:
        ValueError: if `mask_length < 1` or `mask_length > sequence_length`.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}`"
        )

    # epsilon is used for probabilistic rounding
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
        num_masked_span = max(num_masked_span, min_masks)

        # make sure num masked span <= sequence_length
        if num_masked_span * mask_length > sequence_length:
            num_masked_span = sequence_length // mask_length

        # make sure num_masked span is also <= input_length - (mask_length - 1)
        if input_length - (mask_length - 1) < num_masked_span:
            num_masked_span = max(input_length - (mask_length - 1), 0)

        return num_masked_span

    # compute number of masked spans in batch
    input_lengths = (
        attention_mask.sum(-1).detach().tolist()
        if attention_mask is not None
        else [sequence_length for _ in range(batch_size)]
    )

    # SpecAugment mask to fill
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
    spec_aug_mask_idxs = []

    max_num_masked_span = compute_num_masked_span(sequence_length)

    if max_num_masked_span == 0:
        return spec_aug_mask

    for input_length in input_lengths:
        # compute num of masked spans for this input
        num_masked_span = compute_num_masked_span(input_length)

        # get random indices to mask
        spec_aug_mask_idx = np.random.choice(
            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
        )

        # pick first sampled index that will serve as a dummy index to pad vector
        # to ensure same dimension for all batches due to probabilistic rounding
        # Picking first sample just pads those vectors twice.
        if len(spec_aug_mask_idx) == 0:
            # this case can only happen if `input_length` is strictly smaller than
            # `sequence_length` in which case the last token has to be a padding
            # token which we can use as a dummy mask id
            dummy_mask_idx = sequence_length - 1
        else:
            dummy_mask_idx = spec_aug_mask_idx[0]

        spec_aug_mask_idx = np.concatenate(
            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
        )
        spec_aug_mask_idxs.append(spec_aug_mask_idx)

    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

    # expand masked indices to masked spans
    spec_aug_mask_idxs = np.broadcast_to(
        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)

    # add offset to the starting indexes so that indexes now create a span
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
        batch_size, max_num_masked_span * mask_length
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # ensure that we cannot have indices larger than sequence_length
    if spec_aug_mask_idxs.max() > sequence_length - 1:
        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1

    # scatter indices to mask
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    return spec_aug_mask


class WavLMSpkRegPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = WavLMSpkRegConfig
    base_model_prefix = "wavlm"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights of a single submodule (called by `post_init` / `from_pretrained`)."""
        # gumbel softmax requires special init
        if isinstance(module, WavLMGumbelVectorQuantizer):
            module.weight_proj.weight.data.normal_(mean=0.0, std=1)
            module.weight_proj.bias.data.zero_()
            nn.init.uniform_(module.codevectors)
        elif isinstance(module, WavLMPositionalConvEmbedding):
            nn.init.normal_(
                module.conv.weight,
                mean=0,
                std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
            )
            nn.init.constant_(module.conv.bias, 0)
        elif isinstance(module, WavLMFeatureProjection):
            k = math.sqrt(1 / module.projection.in_features)
            nn.init.uniform_(module.projection.weight, a=-k, b=k)
            nn.init.uniform_(module.projection.bias, a=-k, b=k)
        elif isinstance(module, AngularLinear):
            # AngularLinear allocates its weight uninitialized and initializes it in __init__.
            # `from_pretrained` may construct modules with torch.nn.init calls disabled and then rely
            # on this hook to initialize weights missing from the checkpoint (e.g. a fresh classifier),
            # so the same xavier init must be available here.
            nn.init.xavier_normal_(module.weight, gain=1)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)

            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers (and, if `add_adapter`, of the adapter layers).
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
    ):
        """
        Downsample a sample-level attention mask to a boolean frame-level mask of length `feature_vector_length`.

        Every frame before each sequence's computed output length is marked `True`.
        """
        # Effectively attention_mask.sum(-1), but not inplace to be able to run
        # on inference mode.
        non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
        output_lengths = output_lengths.to(torch.long)

        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )
        # these two operations makes sure that all values before the output lengths idxs are attended to
        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask


class WavLMSpkRegModel(WavLMSpkRegPreTrainedModel):
    """
    WavLM backbone: convolutional feature encoder, feature projection, optional SpecAugment masking,
    transformer encoder and optional adapter.
    """

    def __init__(self, config: WavLMSpkRegConfig):
        super().__init__(config)
        self.config = config
        self.feature_extractor = WavLMFeatureEncoder(config)
        self.feature_projection = WavLMFeatureProjection(config)

        # model only needs masking vector if mask prob is > 0.0
        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        if config.do_stable_layer_norm:
            self.encoder = WavLMEncoderStableLayerNorm(config)
        else:
            self.encoder = WavLMEncoder(config)

        self.adapter = WavLMAdapter(config) if config.add_adapter else None

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.feature_extractor._freeze_parameters()

    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).

        `hidden_states` is modified in place and also returned.
        """

        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        # generate indices & apply SpecAugment along time axis
        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given mask_time_indices
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
        """
        Encode raw input values.

        Returns a `Wav2Vec2BaseModelOutput` (or, if `return_dict` is false, the tuple
        `(last_hidden_state, extract_features, *encoder_outputs[1:])`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, extract_features) + encoder_outputs[1:]

        return Wav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class AngularLinear(nn.Module):
    """
    Bias-free linear layer whose outputs are cosine similarities.

    Inputs and weight rows are both L2-normalized (along dim 1) before the matrix product, so each
    output entry is cos(theta) between an input vector and one class weight vector.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Allocated uninitialized, then filled by xavier init below (and by
        # WavLMSpkRegPreTrainedModel._init_weights when loaded via transformers).
        self.weight = torch.nn.Parameter(
            torch.FloatTensor(out_features, in_features), requires_grad=True
        )
        nn.init.xavier_normal_(self.weight, gain=1)

    def forward(
        self,
        inputs: torch.Tensor,
    ):
        # Calculation of cos(theta)
        cosine = F.linear(F.normalize(inputs), F.normalize(self.weight))
        return cosine

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}"


class AMSoftmaxLoss(nn.Module):
    """Additive Margin Softmax (CosFace).

    Paper: Wang, Feng, et al. "Additive margin softmax for face verification."
    IEEE Signal Processing Letters 25.7 (2018): 926-930.
    """

    def __init__(
        self,
        scale: float = 30.0,
        margin: float = 0.35,
        label_smoothing: float = 0.0,
        reduction: str = "mean"
    ):
        """
        Args:
            scale: Scaling factor applied to the cosine logits (default: 30.0)
            margin: Additive cosine margin subtracted from the target-class cosine (default: 0.35)
            label_smoothing: Label smoothing passed to `F.cross_entropy` (default: 0.0)
            reduction: Reduction method passed to `F.cross_entropy` (default: "mean")
        """
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        Args:
            inputs: Cosine similarities of shape (batch_size, num_labels), e.g. the output of `AngularLinear`
            targets: Integer ground-truth labels of shape (batch_size,)
        Returns:
            Cross-entropy of `scale * logits`, where the target-class cosine has `margin` subtracted
        """
        _, num_labels = inputs.shape
        # `inputs` are the outputs from AngularLinear()
        cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
        psi = cos_theta - self.margin
        one_hot = nn.functional.one_hot(targets, num_labels)
        outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
        loss = F.cross_entropy(
            outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
        )
        return loss


class AAMSoftmaxLoss(nn.Module):
    """Additive Angular Margin Softmax (ArcFace).

    Paper: Deng, Jiankang, et al. "Arcface: Additive angular margin loss for deep face recognition."
    Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.
    """

    def __init__(
        self,
        scale: float = 30.0,
        margin: float = 0.2,
        easy_margin: bool = False,
        label_smoothing: float = 0.0,
        reduction: str = "mean"
    ):
        """
        Args:
            scale: Scaling factor applied to the cosine logits (default: 30.0)
            margin: Additive angular margin in radians, applied as cos(theta + margin) (default: 0.2)
            easy_margin: Use the easy margin loss (default: False)
            label_smoothing: Label smoothing passed to `F.cross_entropy` (default: 0.0)
            reduction: Reduction method passed to `F.cross_entropy` (default: "mean")
        """
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.easy_margin = easy_margin
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(
        self,
        inputs: torch.Tensor,
        targets: torch.Tensor,
    ):
        """
        Args:
            inputs: Cosine similarities of shape (batch_size, num_labels), e.g. the output of `AngularLinear`
            targets: Integer ground-truth labels of shape (batch_size,)
        Returns:
            Cross-entropy of `scale * logits`, where the target-class logit is cos(theta + margin)
            (with the easy/hard-margin fallback described below)
        """
        _, num_labels = inputs.shape
        # `inputs` are the outputs from AngularLinear()
        epsilon = 1e-6
        # theta = torch.acos(cos_theta)
        # psi = torch.cos(theta + self.margin)
        cos_theta = torch.clamp(inputs, -1.0 + epsilon, 1.0 - epsilon)
        sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2))
        sin_theta = torch.clamp(sin_theta, 0.0 + epsilon, 1.0 - epsilon)

        cos_m = math.cos(self.margin)
        sin_m = math.sin(self.margin)
        psi = cos_theta * cos_m - sin_theta * sin_m  # cos(theta + m)

        if self.easy_margin:
            psi = torch.where(cos_theta > 0, psi, cos_theta)
        else:
            # Make the function cos(theta+m) monotonic decreasing while theta in [0°, 180°]
            # NOTE(review): common ArcFace implementations (e.g. InsightFace) fall back to
            # `cos_theta - sin(pi - m) * m`; this uses `cos_theta - m`, a larger penalty. Confirm intended.
            psi = torch.where((cos_theta - math.cos(math.pi - self.margin)) > 0, psi, cos_theta - self.margin)

        one_hot = nn.functional.one_hot(targets, num_labels)
        outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
        loss = F.cross_entropy(
            outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
        )
        return loss


class WavLMSpkRegForSequenceClassification(WavLMSpkRegPreTrainedModel):
    """
    Utterance-level classifier: WavLM backbone, optional learned weighted sum of layer outputs,
    linear projection, masked mean pooling over time, then a linear (`cross_entropy`) or cosine
    (`additive_margin` / `additive_angular_margin`) classification head.
    """

    def __init__(self, config):
        super().__init__(config)

        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Sequence classification does not support the use of WavLM adapters (config.add_adapter=True)"
            )
        self.wavlm = WavLMSpkRegModel(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)

        if self.config.loss_fct == "cross_entropy":
            self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
        elif self.config.loss_fct in ("additive_margin", "additive_angular_margin"):
            # Margin-based losses expect cosine similarities, not unconstrained logits.
            self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
        else:
            raise ValueError(f"Unsupported loss function: {self.config.loss_fct}")

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor
    def freeze_feature_extractor(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters will
        not be updated during training.
        """
        warnings.warn(
            "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
            "Please use the equivalent `freeze_feature_encoder` method instead.",
            FutureWarning,
        )
        self.freeze_feature_encoder()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wavlm
    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.wavlm.feature_extractor._freeze_parameters()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->wavlm
    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.wavlm.parameters():
            param.requires_grad = False

    # Adapted from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward
    # (loss selection differs: see `config.loss_fct`).
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification loss. Indices should be in
            `[0, ..., config.num_labels - 1]`. The loss is chosen by `config.loss_fct`:
            `"cross_entropy"` (standard cross-entropy on linear logits), `"additive_margin"`
            (AM-Softmax / CosFace on cosine logits) or `"additive_angular_margin"`
            (AAM-Softmax / ArcFace on cosine logits). No regression loss is computed,
            even when `config.num_labels == 1`.

        Returns:
            `SequenceClassifierOutput` (or a tuple if `return_dict` is false). With a margin loss,
            `logits` are the unscaled cosine similarities produced by `AngularLinear`.

        Raises:
            ValueError: if `labels` is given and `config.loss_fct` is not one of the supported values.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)
        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            # Mean over non-padded frames only: zero padded frames, divide by the valid frame count.
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            hidden_states[~padding_mask] = 0.0
            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.loss_fct == "cross_entropy":
                loss_fct = nn.CrossEntropyLoss(
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            elif self.config.loss_fct == "additive_margin":
                loss_fct = AMSoftmaxLoss(
                    scale=self.config.scale,
                    margin=self.config.margin,
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            elif self.config.loss_fct == "additive_angular_margin":
                loss_fct = AAMSoftmaxLoss(
                    scale=self.config.scale,
                    margin=self.config.margin,
                    easy_margin=self.config.easy_margin,
                    label_smoothing=self.config.label_smoothing,
                    reduction=self.config.reduction
                )
            else:
                # __init__ validates loss_fct, but the config may have been mutated since;
                # fail clearly instead of with an UnboundLocalError below.
                raise ValueError(f"Unsupported loss function: {self.config.loss_fct}")
            loss = loss_fct(
                logits.view(-1, self.config.num_labels),
                labels.view(-1),
            )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )