yangwang825 committed on
Commit 8d8de54 · verified · 1 Parent(s): 2084d86

Upload model

Files changed (3)
  1. config.json +4 -3
  2. model.safetensors +3 -0
  3. modeling_wavlm_spkreg.py +642 -0
config.json CHANGED
@@ -1,16 +1,17 @@
 {
-  "_name_or_path": "./wavlm-base",
+  "_name_or_path": "microsoft/wavlm-base",
   "activation_dropout": 0.0,
   "adapter_kernel_size": 3,
   "adapter_stride": 2,
   "add_adapter": false,
   "apply_spec_augment": true,
   "architectures": [
-    "WavLMModel"
+    "WavLMSpkRegModel"
   ],
   "attention_dropout": 0.1,
   "auto_map": {
-    "AutoConfig": "configuration_wavlm_spkreg.WavLMSpkRegConfig"
+    "AutoConfig": "configuration_wavlm_spkreg.WavLMSpkRegConfig",
+    "AutoModel": "modeling_wavlm_spkreg.WavLMSpkRegModel"
   },
   "bos_token_id": 1,
   "classifier_proj_size": 256,
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:de4d099ee3802e69d818b92dceb88dfa5f4980dd8726b3739a948c8859307cb9
+ size 377555872
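
The LFS pointer above records only the object's SHA-256 and byte size. A quick integrity-check sketch for a downloaded copy (the local path is an assumption):

import hashlib
import os

path = "model.safetensors"  # assumed local path of the downloaded file
expected_oid = "de4d099ee3802e69d818b92dceb88dfa5f4980dd8726b3739a948c8859307cb9"
expected_size = 377555872

assert os.path.getsize(path) == expected_size, "size mismatch"
digest = hashlib.sha256()
with open(path, "rb") as f:
    for block in iter(lambda: f.read(1 << 20), b""):
        digest.update(block)
assert digest.hexdigest() == expected_oid, "sha256 mismatch"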
modeling_wavlm_spkreg.py ADDED
@@ -0,0 +1,642 @@
+ import math
+ import warnings
+ from typing import Union, Tuple, Optional
+
+ import numpy as np
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.modeling_outputs import SequenceClassifierOutput, Wav2Vec2BaseModelOutput
+ from transformers.models.wavlm.modeling_wavlm import (
+     WavLMGumbelVectorQuantizer,
+     WavLMPositionalConvEmbedding,
+     WavLMFeatureProjection,
+     WavLMFeatureEncoder,
+     WavLMEncoderStableLayerNorm,
+     WavLMEncoder,
+     WavLMAdapter,
+     _HIDDEN_STATES_START_POSITION
+ )
+
+ from .configuration_wavlm_spkreg import WavLMSpkRegConfig
+
+
+ def _compute_mask_indices(
+     shape: Tuple[int, int],
+     mask_prob: float,
+     mask_length: int,
+     attention_mask: Optional[torch.LongTensor] = None,
+     min_masks: int = 0,
+ ) -> np.ndarray:
+     """
+     Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+     ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+     CPU as part of the preprocessing during training.
+
+     Args:
+         shape: The shape for which to compute masks. This should be of a tuple of size 2 where
+             the first element is the batch size and the second element is the length of the axis to span.
+         mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+             independently generated mask spans of length `mask_length` is computed by
+             `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+             actual percentage will be smaller.
+         mask_length: size of the mask
+         min_masks: minimum number of masked spans
+         attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+             each batch dimension.
+     """
+     batch_size, sequence_length = shape
+
+     if mask_length < 1:
+         raise ValueError("`mask_length` has to be bigger than 0.")
+
+     if mask_length > sequence_length:
+         raise ValueError(
+             f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+             f" and `sequence_length`: {sequence_length}"
+         )
+
+     # epsilon is used for probabilistic rounding
+     epsilon = np.random.rand(1).item()
+
+     def compute_num_masked_span(input_length):
+         """Given input length, compute how many spans should be masked"""
+         num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+         num_masked_span = max(num_masked_span, min_masks)
+
+         # make sure num masked span <= sequence_length
+         if num_masked_span * mask_length > sequence_length:
+             num_masked_span = sequence_length // mask_length
+
+         # make sure num_masked span is also <= input_length - (mask_length - 1)
+         if input_length - (mask_length - 1) < num_masked_span:
+             num_masked_span = max(input_length - (mask_length - 1), 0)
+
+         return num_masked_span
+
+     # compute number of masked spans in batch
+     input_lengths = (
+         attention_mask.sum(-1).detach().tolist()
+         if attention_mask is not None
+         else [sequence_length for _ in range(batch_size)]
+     )
+
+     # SpecAugment mask to fill
+     spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+     spec_aug_mask_idxs = []
+
+     max_num_masked_span = compute_num_masked_span(sequence_length)
+
+     if max_num_masked_span == 0:
+         return spec_aug_mask
+
+     for input_length in input_lengths:
+         # compute num of masked spans for this input
+         num_masked_span = compute_num_masked_span(input_length)
+
+         # get random indices to mask
+         spec_aug_mask_idx = np.random.choice(
+             np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+         )
+
+         # pick first sampled index that will serve as a dummy index to pad vector
+         # to ensure same dimension for all batches due to probabilistic rounding
+         # Picking first sample just pads those vectors twice.
+         if len(spec_aug_mask_idx) == 0:
+             # this case can only happen if `input_length` is strictly smaller than
+             # `sequence_length` in which case the last token has to be a padding
+             # token which we can use as a dummy mask id
+             dummy_mask_idx = sequence_length - 1
+         else:
+             dummy_mask_idx = spec_aug_mask_idx[0]
+
+         spec_aug_mask_idx = np.concatenate(
+             [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+         )
+         spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+     spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+     # expand masked indices to masked spans
+     spec_aug_mask_idxs = np.broadcast_to(
+         spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+     )
+     spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+     # add offset to the starting indexes so that indexes now create a span
+     offsets = np.arange(mask_length)[None, None, :]
+     offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+         batch_size, max_num_masked_span * mask_length
+     )
+     spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+     # ensure that we cannot have indices larger than sequence_length
+     if spec_aug_mask_idxs.max() > sequence_length - 1:
+         spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+     # scatter indices to mask
+     np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+     return spec_aug_mask
+
+
+ class WavLMSpkRegPreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+
+     config_class = WavLMSpkRegConfig
+     base_model_prefix = "wavlm"
+     main_input_name = "input_values"
+     supports_gradient_checkpointing = True
+
+     def _init_weights(self, module):
+         """Initialize the weights"""
+         # gumbel softmax requires special init
+         if isinstance(module, WavLMGumbelVectorQuantizer):
+             module.weight_proj.weight.data.normal_(mean=0.0, std=1)
+             module.weight_proj.bias.data.zero_()
+             nn.init.uniform_(module.codevectors)
+         elif isinstance(module, WavLMPositionalConvEmbedding):
+             nn.init.normal_(
+                 module.conv.weight,
+                 mean=0,
+                 std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
+             )
+             nn.init.constant_(module.conv.bias, 0)
+         elif isinstance(module, WavLMFeatureProjection):
+             k = math.sqrt(1 / module.projection.in_features)
+             nn.init.uniform_(module.projection.weight, a=-k, b=k)
+             nn.init.uniform_(module.projection.bias, a=-k, b=k)
+         elif isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+         elif isinstance(module, nn.Conv1d):
+             nn.init.kaiming_normal_(module.weight)
+
+             if module.bias is not None:
+                 k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
+                 nn.init.uniform_(module.bias, a=-k, b=k)
+
+     def _get_feat_extract_output_lengths(
+         self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
+     ):
+         """
+         Computes the output length of the convolutional layers
+         """
+
+         add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
+         def _conv_out_length(input_length, kernel_size, stride):
+             # 1D convolutional layer output length formula taken
+             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+             return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+         if add_adapter:
+             for _ in range(self.config.num_adapter_layers):
+                 input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
+         return input_lengths
+
+     def _get_feature_vector_attention_mask(
+         self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
+     ):
+         # Effectively attention_mask.sum(-1), but not inplace to be able to run
+         # on inference mode.
+         non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
+
+         output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
+         output_lengths = output_lengths.to(torch.long)
+
+         batch_size = attention_mask.shape[0]
+
+         attention_mask = torch.zeros(
+             (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+         )
+         # these two operations make sure that all values before the output lengths idxs are attended to
+         attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+         attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+         return attention_mask
+
+
+ class WavLMSpkRegModel(WavLMSpkRegPreTrainedModel):
+
+     def __init__(self, config: WavLMSpkRegConfig):
+         super().__init__(config)
+         self.config = config
+         self.feature_extractor = WavLMFeatureEncoder(config)
+         self.feature_projection = WavLMFeatureProjection(config)
+
+         # model only needs masking vector if mask prob is > 0.0
+         if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+             self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
+
+         if config.do_stable_layer_norm:
+             self.encoder = WavLMEncoderStableLayerNorm(config)
+         else:
+             self.encoder = WavLMEncoder(config)
+
+         self.adapter = WavLMAdapter(config) if config.add_adapter else None
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def freeze_feature_extractor(self):
+         """
+         Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+         not be updated during training.
+         """
+         warnings.warn(
+             "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+             "Please use the equivalent `freeze_feature_encoder` method instead.",
+             FutureWarning,
+         )
+         self.freeze_feature_encoder()
+
+     def freeze_feature_encoder(self):
+         """
+         Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+         not be updated during training.
+         """
+         self.feature_extractor._freeze_parameters()
+
+     def _mask_hidden_states(
+         self,
+         hidden_states: torch.FloatTensor,
+         mask_time_indices: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+     ):
+         """
+         Masks extracted features along time axis and/or along feature axis according to
+         [SpecAugment](https://arxiv.org/abs/1904.08779).
+         """
+
+         # `config.apply_spec_augment` can set masking to False
+         if not getattr(self.config, "apply_spec_augment", True):
+             return hidden_states
+
+         # generate indices & apply SpecAugment along time axis
+         batch_size, sequence_length, hidden_size = hidden_states.size()
+
+         if mask_time_indices is not None:
+             # apply SpecAugment along time axis with given mask_time_indices
+             hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+         elif self.config.mask_time_prob > 0 and self.training:
+             mask_time_indices = _compute_mask_indices(
+                 (batch_size, sequence_length),
+                 mask_prob=self.config.mask_time_prob,
+                 mask_length=self.config.mask_time_length,
+                 attention_mask=attention_mask,
+                 min_masks=self.config.mask_time_min_masks,
+             )
+             mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
+             hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+
+         if self.config.mask_feature_prob > 0 and self.training:
+             # generate indices & apply SpecAugment along feature axis
+             mask_feature_indices = _compute_mask_indices(
+                 (batch_size, hidden_size),
+                 mask_prob=self.config.mask_feature_prob,
+                 mask_length=self.config.mask_feature_length,
+                 min_masks=self.config.mask_feature_min_masks,
+             )
+             mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
+             mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
+             hidden_states[mask_feature_indices] = 0
+
+         return hidden_states
+
+     def forward(
+         self,
+         input_values: Optional[torch.Tensor],
+         attention_mask: Optional[torch.Tensor] = None,
+         mask_time_indices: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         extract_features = self.feature_extractor(input_values)
+         extract_features = extract_features.transpose(1, 2)
+
+         if attention_mask is not None:
+             # compute reduced attention_mask corresponding to feature vectors
+             attention_mask = self._get_feature_vector_attention_mask(
+                 extract_features.shape[1], attention_mask, add_adapter=False
+             )
+
+         hidden_states, extract_features = self.feature_projection(extract_features)
+         hidden_states = self._mask_hidden_states(
+             hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+         )
+
+         encoder_outputs = self.encoder(
+             hidden_states,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = encoder_outputs[0]
+
+         if self.adapter is not None:
+             hidden_states = self.adapter(hidden_states)
+
+         if not return_dict:
+             return (hidden_states, extract_features) + encoder_outputs[1:]
+
+         return Wav2Vec2BaseModelOutput(
+             last_hidden_state=hidden_states,
+             extract_features=extract_features,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+
+
+ class AngularLinear(nn.Module):
+
+     def __init__(self, in_features: int, out_features: int):
+         super(AngularLinear, self).__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.weight = torch.nn.Parameter(
+             torch.FloatTensor(out_features, in_features), requires_grad=True
+         )
+         nn.init.xavier_normal_(self.weight, gain=1)
+
+     def forward(
+         self,
+         inputs: torch.Tensor,
+     ):
+         # Calculation of cos(theta)
+         cosine = F.linear(F.normalize(inputs), F.normalize(self.weight))
+         return cosine
+
+     def extra_repr(self) -> str:
+         return 'in_features={}, out_features={}'.format(
+             self.in_features, self.out_features
+         )
+
+
+ class AMSoftmaxLoss(nn.Module):
+     """Additive Margin Softmax (CosFace).
+
+     Paper: Wang, Feng, et al. "Additive margin softmax for face verification."
+     IEEE Signal Processing Letters 25.7 (2018): 926-930.
+     """
+     def __init__(
+         self,
+         scale: float = 30.0,
+         margin: float = 0.35,
+         label_smoothing: float = 0.0,
+         reduction: str = "mean"
+     ):
+         """
+         Args:
+             scale: Scaling factor for logits (default: 30.0)
+             margin: Additive cosine margin (default: 0.35)
+             label_smoothing: Label smoothing factor (default: 0.0)
+             reduction: Loss reduction method (default: "mean")
+         """
+         super(AMSoftmaxLoss, self).__init__()
+         self.scale = scale
+         self.margin = margin
+         self.label_smoothing = label_smoothing
+         self.reduction = reduction
+
+     def forward(
+         self,
+         inputs: torch.Tensor,
+         targets: torch.Tensor,
+     ):
+         """
+         Args:
+             inputs: Cosine similarities from AngularLinear, of shape (batch_size, num_labels)
+             targets: Ground truth labels of shape (batch_size)
+         Returns:
+             Loss value
+         """
+         _, num_labels = inputs.shape
+         # `inputs` are the outputs from AngularLinear()
+         cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
+         psi = cos_theta - self.margin
+         one_hot = nn.functional.one_hot(targets, num_labels)
+         outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
+         loss = F.cross_entropy(
+             outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
+         )
+         return loss
+
+
+ class AAMSoftmaxLoss(nn.Module):
+     """Additive Angular Margin Softmax (ArcFace).
+
+     Paper: Deng, Jiankang, et al. "ArcFace: Additive angular margin loss for deep face recognition."
+     Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019.
+     """
+     def __init__(
+         self,
+         scale: float = 30.0,
+         margin: float = 0.35,
+         easy_margin: bool = False,
+         label_smoothing: float = 0.0,
+         reduction: str = "mean"
+     ):
+         """
+         Args:
+             scale: Scaling factor for logits (default: 30.0)
+             margin: Additive angular margin in radians (default: 0.35)
+             easy_margin: Use the easy-margin variant (default: False; stored but not applied in this forward pass)
+             label_smoothing: Label smoothing factor (default: 0.0)
+             reduction: Loss reduction method (default: "mean")
+         """
+         super(AAMSoftmaxLoss, self).__init__()
+         self.scale = scale
+         self.margin = margin
+         self.easy_margin = easy_margin
+         self.label_smoothing = label_smoothing
+         self.reduction = reduction
+
+     def forward(
+         self,
+         inputs: torch.Tensor,
+         targets: torch.Tensor,
+     ):
+         """
+         Args:
+             inputs: Cosine similarities from AngularLinear, of shape (batch_size, num_labels)
+             targets: Ground truth labels of shape (batch_size)
+         Returns:
+             Loss value
+         """
+         _, num_labels = inputs.shape
+         # `inputs` are the outputs from AngularLinear()
+         cos_theta = torch.clamp(inputs, -1.0 + 1e-7, 1.0 - 1e-7)
+         theta = torch.acos(cos_theta)
+         psi = torch.cos(theta + self.margin)
+         one_hot = nn.functional.one_hot(targets, num_labels)
+         outputs = self.scale * torch.where(one_hot.bool(), psi, cos_theta)
+         loss = F.cross_entropy(
+             outputs, targets, label_smoothing=self.label_smoothing, reduction=self.reduction
+         )
+         return loss
+
+
+ class WavLMSpkRegForSequenceClassification(WavLMSpkRegPreTrainedModel):
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         if hasattr(config, "add_adapter") and config.add_adapter:
+             raise ValueError(
+                 "Sequence classification does not support the use of WavLM adapters (config.add_adapter=True)"
+             )
+         self.wavlm = WavLMSpkRegModel(config)
+         num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+         if config.use_weighted_layer_sum:
+             self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+         self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
+
+         if self.config.loss_fct == 'cross_entropy':
+             self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+         elif self.config.loss_fct == 'additive_margin':
+             self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
+         elif self.config.loss_fct == 'additive_angular_margin':
+             self.classifier = AngularLinear(config.classifier_proj_size, config.num_labels)
+         else:
+             raise ValueError(f"Unsupported loss function: {self.config.loss_fct}")
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor
+     def freeze_feature_extractor(self):
+         """
+         Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+         not be updated during training.
+         """
+         warnings.warn(
+             "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
+             "Please use the equivalent `freeze_feature_encoder` method instead.",
+             FutureWarning,
+         )
+         self.freeze_feature_encoder()
+
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wavlm
+     def freeze_feature_encoder(self):
+         """
+         Calling this function will disable the gradient computation for the feature encoder so that its parameters will
+         not be updated during training.
+         """
+         self.wavlm.feature_extractor._freeze_parameters()
+
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->wavlm
+     def freeze_base_model(self):
+         """
+         Calling this function will disable the gradient computation for the base model so that its parameters will not
+         be updated during training. Only the classification head will be updated.
+         """
+         for param in self.wavlm.parameters():
+             param.requires_grad = False
+
+     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->WavLM, wav2vec2->wavlm
+     def forward(
+         self,
+         input_values: Optional[torch.Tensor],
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         labels: Optional[torch.Tensor] = None,
+     ) -> Union[Tuple, SequenceClassifierOutput]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+         outputs = self.wavlm(
+             input_values,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         if self.config.use_weighted_layer_sum:
+             hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+             hidden_states = torch.stack(hidden_states, dim=1)
+             norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+             hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+         else:
+             hidden_states = outputs[0]
+
+         hidden_states = self.projector(hidden_states)
+         if attention_mask is None:
+             pooled_output = hidden_states.mean(dim=1)
+         else:
+             padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+             hidden_states[~padding_mask] = 0.0
+             pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
+
+         logits = self.classifier(pooled_output)
+
+         loss = None
+         if labels is not None:
+             if self.config.loss_fct == 'cross_entropy':
+                 loss_fct = nn.CrossEntropyLoss(
+                     label_smoothing=self.config.label_smoothing,
+                     reduction=self.config.reduction
+                 )
+             elif self.config.loss_fct == 'additive_margin':
+                 loss_fct = AMSoftmaxLoss(
+                     scale=self.config.scale,
+                     margin=self.config.margin,
+                     label_smoothing=self.config.label_smoothing,
+                     reduction=self.config.reduction
+                 )
+             elif self.config.loss_fct == 'additive_angular_margin':
+                 loss_fct = AAMSoftmaxLoss(
+                     scale=self.config.scale,
+                     margin=self.config.margin,
+                     easy_margin=self.config.easy_margin,
+                     label_smoothing=self.config.label_smoothing,
+                     reduction=self.config.reduction
+                 )
+             loss = loss_fct(
+                 logits.view(-1, self.config.num_labels),
+                 labels.view(-1),
+             )
+
+         if not return_dict:
+             output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
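
For reference, a small sketch of how the margin heads above behave on dummy data. It assumes the uploaded file is importable as part of a local package (a hypothetical `wavlm_spkreg` package holding both the configuration and modeling modules, so the relative import in the file resolves); none of the names below beyond the classes defined above are part of the commit.

import torch
from wavlm_spkreg.modeling_wavlm_spkreg import AngularLinear, AMSoftmaxLoss, AAMSoftmaxLoss

torch.manual_seed(0)
embeddings = torch.randn(4, 256)        # pooled utterance embeddings (batch, classifier_proj_size)
labels = torch.tensor([0, 2, 1, 2])

head = AngularLinear(in_features=256, out_features=3)
cosine = head(embeddings)               # cos(theta) logits in [-1, 1], shape (4, 3)

am_loss = AMSoftmaxLoss(scale=30.0, margin=0.35)    # CosFace: target logit becomes cos(theta) - m
aam_loss = AAMSoftmaxLoss(scale=30.0, margin=0.35)  # ArcFace: target logit becomes cos(theta + m)
print(am_loss(cosine, labels).item(), aam_loss(cosine, labels).item())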