amiriparian Filip-Packan commited on
Commit
5bc6f15
·
verified ·
1 Parent(s): 07dfeed

Updated deprecated module (#4)

Browse files

- Updated deprecated module (04872c2473b68df24c5b0057236014881813f6ea)


Co-authored-by: Filip Packań <Filip-Packan@users.noreply.huggingface.co>

Files changed (1) hide show
  1. ExHuBERT_model.py +451 -451
ExHuBERT_model.py CHANGED
@@ -1,451 +1,451 @@
1
- from dataclasses import dataclass
2
- from typing import Optional, Tuple, Union
3
-
4
- import torch
5
- import torch.nn as nn
6
- from transformers import HubertForSequenceClassification
7
- from transformers.activations import ACT2FN
8
- from transformers.deepspeed import is_deepspeed_zero3_enabled
9
- from transformers.file_utils import ModelOutput
10
- from transformers.modeling_outputs import BaseModelOutput
11
- from transformers.models.hubert import HubertConfig
12
- from transformers.models.hubert.modeling_hubert import HubertPreTrainedModel, HubertFeatureEncoder, \
13
- HubertFeatureProjection, _compute_mask_indices, \
14
- HubertPositionalConvEmbedding, HubertAttention
15
- import torch.nn.functional as F
16
- from huggingface_hub import PyTorchModelHubMixin
17
-
18
- ######
19
- #
20
- #######
21
-
22
-
23
-
24
- _HIDDEN_STATES_START_POSITION = 1
25
-
26
- # General docstring
27
- _CONFIG_FOR_DOC = "HubertConfig"
28
-
29
- # Base docstring
30
- _CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
31
- _EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
32
-
33
- # CTC docstring
34
- _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
35
- _CTC_EXPECTED_LOSS = 22.68
36
-
37
- # Audio class docstring
38
- _SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
39
- _SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
40
- _SEQ_CLASS_EXPECTED_LOSS = 8.53
41
-
42
- HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
43
- "facebook/hubert-base-ls960",
44
- # See all Hubert models at https://huggingface.co/models?filter=hubert
45
- ]
46
-
47
-
48
- # SwiGLU function
49
- # From """GLU Variants Improve Transformer """
50
- # https://doi.org/10.48550/arXiv.2002.05202
51
- class SwiGLU(nn.Module):
52
- def forward(self, x):
53
- x, gate = x.chunk(2, dim=-1)
54
- return F.silu(gate) * x
55
-
56
-
57
- @dataclass
58
- class SpeechClassifierOutput(ModelOutput):
59
- """
60
- Speech Classifier Output dataclass
61
- """
62
- loss: Optional[torch.FloatTensor] = None
63
- logits: torch.FloatTensor = None
64
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
65
- attentions: Optional[Tuple[torch.FloatTensor]] = None
66
-
67
-
68
- class ExHuBERTFeedForward(nn.Module):
69
- def __init__(self, config):
70
- super().__init__()
71
- self.intermediate_dropout = nn.Dropout(config.activation_dropout)
72
-
73
- self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
74
- if isinstance(config.hidden_act, str):
75
- self.intermediate_act_fn = ACT2FN[config.hidden_act]
76
- else:
77
- self.intermediate_act_fn = config.hidden_act
78
-
79
- self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
80
- self.output_dropout = nn.Dropout(config.hidden_dropout)
81
-
82
- def forward(self, hidden_states):
83
- hidden_states = self.intermediate_dense(hidden_states)
84
- hidden_states = self.intermediate_act_fn(hidden_states)
85
- hidden_states = self.intermediate_dropout(hidden_states)
86
-
87
- hidden_states = self.output_dense(hidden_states)
88
- hidden_states = self.output_dropout(hidden_states)
89
- return hidden_states
90
-
91
-
92
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->Hubert
93
- class ExHuBERTEncoderLayer(nn.Module):
94
- def __init__(self, config):
95
- super().__init__()
96
- self.attention = HubertAttention(
97
- embed_dim=config.hidden_size,
98
- num_heads=config.num_attention_heads,
99
- dropout=config.attention_dropout,
100
- is_decoder=False,
101
- )
102
- self.dropout = nn.Dropout(config.hidden_dropout)
103
- self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
104
- self.feed_forward = ExHuBERTFeedForward(config)
105
- self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
106
- self.gate_bb_linear = nn.Linear(config.hidden_size, config.hidden_size)
107
-
108
- def forward(
109
- self,
110
- hidden_states: torch.Tensor,
111
- attention_mask: Optional[torch.Tensor] = None,
112
- output_attentions: bool = False,
113
- ):
114
- attn_residual = hidden_states
115
- hidden_states = self.layer_norm(hidden_states)
116
- hidden_states, attn_weights, _ = self.attention(
117
- hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
118
- )
119
- hidden_states = self.dropout(hidden_states)
120
- hidden_states = attn_residual + hidden_states
121
- hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
122
-
123
- hidden_states = self.gate_bb_linear(hidden_states)
124
- outputs = (hidden_states,)
125
-
126
- if output_attentions:
127
- outputs += (attn_weights,)
128
-
129
- return outputs
130
-
131
-
132
- class ExHuBERTEncoder(nn.Module):
133
- def __init__(self, config):
134
- super().__init__()
135
- self.config = config
136
- self.pos_conv_embed = HubertPositionalConvEmbedding(config)
137
- self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
138
- self.dropout = nn.Dropout(config.hidden_dropout)
139
- self.layers = nn.ModuleList(
140
- [ExHuBERTEncoderLayer(config) for _ in range(config.num_hidden_layers)]
141
- )
142
- self.gradient_checkpointing = False
143
-
144
- def forward(
145
- self,
146
- hidden_states,
147
- attention_mask=None,
148
- output_attentions=False,
149
- output_hidden_states=False,
150
- return_dict=True,
151
- ):
152
- all_hidden_states = () if output_hidden_states else None
153
- all_self_attentions = () if output_attentions else None
154
-
155
- if attention_mask is not None:
156
- # make sure padded tokens are not attended to
157
- expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
158
- hidden_states[~expand_attention_mask] = 0
159
-
160
- # extend attention_mask
161
- attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
162
- attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
163
- attention_mask = attention_mask.expand(
164
- attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
165
- )
166
-
167
- position_embeddings = self.pos_conv_embed(hidden_states)
168
- hidden_states = hidden_states + position_embeddings
169
- hidden_states = self.dropout(hidden_states)
170
-
171
- deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
172
-
173
- skip = torch.zeros_like(hidden_states)
174
- skip_bool = False
175
- for layer in self.layers:
176
-
177
- if output_hidden_states:
178
- all_hidden_states = all_hidden_states + (hidden_states,)
179
-
180
- # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
181
- dropout_probability = torch.rand([])
182
-
183
- # skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
184
- skip_the_layer = False
185
- if not skip_the_layer or deepspeed_zero3_is_enabled:
186
- # under deepspeed zero3 all gpus must run in sync
187
- # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
188
- if self.gradient_checkpointing and self.training:
189
- # create gradient checkpointing function
190
- def create_custom_forward(module):
191
- def custom_forward(*inputs):
192
- return module(*inputs, output_attentions)
193
-
194
- return custom_forward
195
-
196
- layer_outputs = torch.utils.checkpoint.checkpoint(
197
- create_custom_forward(layer),
198
- hidden_states,
199
- attention_mask,
200
- )
201
- else:
202
- layer_outputs = layer(
203
- hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
204
- )
205
- hidden_states = layer_outputs[0]
206
-
207
- if skip_the_layer:
208
- layer_outputs = (None, None)
209
-
210
- if output_attentions:
211
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
212
- if skip_bool is True:
213
- hidden_states = hidden_states + skip
214
-
215
- skip_bool = False
216
- else:
217
- skip = hidden_states
218
- skip_bool = True
219
-
220
- hidden_states = self.layer_norm(hidden_states)
221
-
222
- if output_hidden_states:
223
- all_hidden_states = all_hidden_states + (hidden_states,)
224
-
225
- if not return_dict:
226
- return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
227
- return BaseModelOutput(
228
- last_hidden_state=hidden_states,
229
- hidden_states=all_hidden_states,
230
- attentions=all_self_attentions,
231
- )
232
-
233
-
234
- class ExHuBERT_model_(HubertPreTrainedModel):
235
- def __init__(self, config: HubertConfig):
236
- super().__init__(config)
237
- setattr(config, 'num_hidden_layers', 48)
238
- self.config = config
239
- self.feature_extractor = HubertFeatureEncoder(config)
240
- self.feature_projection = HubertFeatureProjection(config)
241
-
242
- if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
243
- self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
244
-
245
- self.encoder = ExHuBERTEncoder(config)
246
-
247
- # Initialize weights and apply final processing
248
- self.post_init()
249
-
250
- # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
251
- def _mask_hidden_states(
252
- self,
253
- hidden_states: torch.FloatTensor,
254
- mask_time_indices: Optional[torch.FloatTensor] = None,
255
- attention_mask: Optional[torch.LongTensor] = None,
256
- ):
257
- """
258
- Masks extracted features along time axis and/or along feature axis according to
259
- [SpecAugment](https://arxiv.org/abs/1904.08779).
260
- """
261
-
262
- # `config.apply_spec_augment` can set masking to False
263
- if not getattr(self.config, "apply_spec_augment", True):
264
- return hidden_states
265
-
266
- # generate indices & apply SpecAugment along time axis
267
- batch_size, sequence_length, hidden_size = hidden_states.size()
268
-
269
- if mask_time_indices is not None:
270
- # apply SpecAugment along time axis with given mask_time_indices
271
- hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
272
- elif self.config.mask_time_prob > 0 and self.training:
273
- mask_time_indices = _compute_mask_indices(
274
- (batch_size, sequence_length),
275
- mask_prob=self.config.mask_time_prob,
276
- mask_length=self.config.mask_time_length,
277
- attention_mask=attention_mask,
278
- min_masks=self.config.mask_time_min_masks,
279
- )
280
- mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
281
- hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
282
-
283
- if self.config.mask_feature_prob > 0 and self.training:
284
- # generate indices & apply SpecAugment along feature axis
285
- mask_feature_indices = _compute_mask_indices(
286
- (batch_size, hidden_size),
287
- mask_prob=self.config.mask_feature_prob,
288
- mask_length=self.config.mask_feature_length,
289
- min_masks=self.config.mask_feature_min_masks,
290
- )
291
- mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
292
- mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
293
- hidden_states[mask_feature_indices] = 0
294
-
295
- return hidden_states
296
-
297
- def forward(
298
- self,
299
- input_values: Optional[torch.Tensor],
300
- attention_mask: Optional[torch.Tensor] = None,
301
- mask_time_indices: Optional[torch.FloatTensor] = None,
302
- output_attentions: Optional[bool] = None,
303
- output_hidden_states: Optional[bool] = None,
304
- return_dict: Optional[bool] = None,
305
- ) -> Union[Tuple, BaseModelOutput]:
306
-
307
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
308
- output_hidden_states = (
309
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
310
- )
311
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
312
-
313
- extract_features = self.feature_extractor(input_values)
314
- extract_features = extract_features.transpose(1, 2)
315
-
316
- if attention_mask is not None:
317
- # compute reduced attention_mask corresponding to feature vectors
318
- attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
319
-
320
- hidden_states = self.feature_projection(extract_features)
321
- hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
322
-
323
- encoder_outputs = self.encoder(
324
- hidden_states,
325
- attention_mask=attention_mask,
326
- output_attentions=output_attentions,
327
- output_hidden_states=output_hidden_states,
328
- return_dict=return_dict,
329
- )
330
-
331
- hidden_states = encoder_outputs[0]
332
-
333
- if not return_dict:
334
- return (hidden_states,) + encoder_outputs[1:]
335
-
336
- return BaseModelOutput(
337
- last_hidden_state=hidden_states,
338
- hidden_states=encoder_outputs.hidden_states,
339
- attentions=encoder_outputs.attentions,
340
- )
341
-
342
-
343
- class ExHuBERT(HubertPreTrainedModel,PyTorchModelHubMixin):
344
- def __init__(self, config):
345
- super().__init__(config)
346
- setattr(config, "num_labels", 6)
347
- if hasattr(config, "add_adapter") and config.add_adapter:
348
- raise ValueError(
349
- "Sequence classification does not support the use of Hubert adapters (config.add_adapter=True)"
350
- )
351
- self.hubert = ExHuBERT_model_(config)
352
- num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
353
- if config.use_weighted_layer_sum:
354
- self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
355
- self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
356
- self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
357
-
358
- # Initialize weights and apply final processing
359
- self.post_init()
360
-
361
- def freeze_feature_encoder(self):
362
- """
363
- Calling this function will disable the gradient computation for the feature encoder so that its parameter will
364
- not be updated during training.
365
- """
366
- self.hubert.feature_extractor._freeze_parameters()
367
-
368
- def freeze_base_model(self):
369
- """
370
- Calling this function will disable the gradient computation for the base model so that its parameters will not
371
- be updated during training. Only the classification head will be updated.
372
- """
373
- for param in self.hubert.parameters():
374
- param.requires_grad = False
375
-
376
- def forward(
377
- self,
378
- input_values: Optional[torch.Tensor],
379
- attention_mask: Optional[torch.Tensor] = None,
380
- output_attentions: Optional[bool] = None,
381
- output_hidden_states: Optional[bool] = None,
382
- return_dict: Optional[bool] = None,
383
- labels: Optional[torch.Tensor] = None,
384
- ) -> Union[Tuple, SpeechClassifierOutput]:
385
- r"""
386
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
387
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
388
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
389
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
390
- """
391
-
392
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
393
- output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
394
-
395
- outputs = self.hubert(
396
- input_values,
397
- attention_mask=attention_mask,
398
- output_attentions=output_attentions,
399
- output_hidden_states=output_hidden_states,
400
- return_dict=return_dict,
401
- )
402
-
403
- if self.config.use_weighted_layer_sum:
404
- hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
405
- hidden_states = torch.stack(hidden_states, dim=1)
406
- norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
407
- hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
408
- else:
409
- hidden_states = outputs[0]
410
-
411
- hidden_states = self.projector(hidden_states)
412
- if attention_mask is None:
413
- pooled_output = hidden_states.mean(dim=1)
414
- else:
415
- padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
416
- hidden_states[~padding_mask] = 0.0
417
- pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
418
-
419
- logits = self.classifier(pooled_output)
420
-
421
- loss = None
422
-
423
- if not return_dict:
424
- output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
425
- return ((loss,) + output) if loss is not None else output
426
-
427
- return SpeechClassifierOutput(
428
- loss=loss,
429
- logits=logits,
430
- hidden_states=outputs.hidden_states,
431
- attentions=outputs.attentions,
432
- )
433
-
434
- def freeze_og_encoder(self):
435
- for param in self.hubert.encoder.layers[::2].parameters():
436
- param.requires_grad = False
437
-
438
- def print_trainable_parameters(model):
439
- '''
440
- prints all trainable parameters of a model
441
- '''
442
- trainable_params = 0
443
- all_param = 0
444
- for _, param in model.named_parameters():
445
- all_param += param.numel()
446
- if param.requires_grad:
447
- trainable_params += param.numel()
448
- print(
449
- f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.2f}"
450
- )
451
-
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import HubertForSequenceClassification
7
+ from transformers.activations import ACT2FN
8
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
9
+ from transformers.file_utils import ModelOutput
10
+ from transformers.modeling_outputs import BaseModelOutput
11
+ from transformers.models.hubert import HubertConfig
12
+ from transformers.models.hubert.modeling_hubert import HubertPreTrainedModel, HubertFeatureEncoder, \
13
+ HubertFeatureProjection, _compute_mask_indices, \
14
+ HubertPositionalConvEmbedding, HubertAttention
15
+ import torch.nn.functional as F
16
+ from huggingface_hub import PyTorchModelHubMixin
17
+
18
+ ######
19
+ #
20
+ #######
21
+
22
+
23
+
24
+ _HIDDEN_STATES_START_POSITION = 1
25
+
26
+ # General docstring
27
+ _CONFIG_FOR_DOC = "HubertConfig"
28
+
29
+ # Base docstring
30
+ _CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
31
+ _EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
32
+
33
+ # CTC docstring
34
+ _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
35
+ _CTC_EXPECTED_LOSS = 22.68
36
+
37
+ # Audio class docstring
38
+ _SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
39
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
40
+ _SEQ_CLASS_EXPECTED_LOSS = 8.53
41
+
42
+ HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
43
+ "facebook/hubert-base-ls960",
44
+ # See all Hubert models at https://huggingface.co/models?filter=hubert
45
+ ]
46
+
47
+
48
+ # SwiGLU function
49
+ # From """GLU Variants Improve Transformer """
50
+ # https://doi.org/10.48550/arXiv.2002.05202
51
+ class SwiGLU(nn.Module):
52
+ def forward(self, x):
53
+ x, gate = x.chunk(2, dim=-1)
54
+ return F.silu(gate) * x
55
+
56
+
57
+ @dataclass
58
+ class SpeechClassifierOutput(ModelOutput):
59
+ """
60
+ Speech Classifier Output dataclass
61
+ """
62
+ loss: Optional[torch.FloatTensor] = None
63
+ logits: torch.FloatTensor = None
64
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
65
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
66
+
67
+
68
+ class ExHuBERTFeedForward(nn.Module):
69
+ def __init__(self, config):
70
+ super().__init__()
71
+ self.intermediate_dropout = nn.Dropout(config.activation_dropout)
72
+
73
+ self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
74
+ if isinstance(config.hidden_act, str):
75
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
76
+ else:
77
+ self.intermediate_act_fn = config.hidden_act
78
+
79
+ self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
80
+ self.output_dropout = nn.Dropout(config.hidden_dropout)
81
+
82
+ def forward(self, hidden_states):
83
+ hidden_states = self.intermediate_dense(hidden_states)
84
+ hidden_states = self.intermediate_act_fn(hidden_states)
85
+ hidden_states = self.intermediate_dropout(hidden_states)
86
+
87
+ hidden_states = self.output_dense(hidden_states)
88
+ hidden_states = self.output_dropout(hidden_states)
89
+ return hidden_states
90
+
91
+
92
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->Hubert
93
+ class ExHuBERTEncoderLayer(nn.Module):
94
+ def __init__(self, config):
95
+ super().__init__()
96
+ self.attention = HubertAttention(
97
+ embed_dim=config.hidden_size,
98
+ num_heads=config.num_attention_heads,
99
+ dropout=config.attention_dropout,
100
+ is_decoder=False,
101
+ )
102
+ self.dropout = nn.Dropout(config.hidden_dropout)
103
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
104
+ self.feed_forward = ExHuBERTFeedForward(config)
105
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
106
+ self.gate_bb_linear = nn.Linear(config.hidden_size, config.hidden_size)
107
+
108
+ def forward(
109
+ self,
110
+ hidden_states: torch.Tensor,
111
+ attention_mask: Optional[torch.Tensor] = None,
112
+ output_attentions: bool = False,
113
+ ):
114
+ attn_residual = hidden_states
115
+ hidden_states = self.layer_norm(hidden_states)
116
+ hidden_states, attn_weights, _ = self.attention(
117
+ hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
118
+ )
119
+ hidden_states = self.dropout(hidden_states)
120
+ hidden_states = attn_residual + hidden_states
121
+ hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
122
+
123
+ hidden_states = self.gate_bb_linear(hidden_states)
124
+ outputs = (hidden_states,)
125
+
126
+ if output_attentions:
127
+ outputs += (attn_weights,)
128
+
129
+ return outputs
130
+
131
+
132
+ class ExHuBERTEncoder(nn.Module):
133
+ def __init__(self, config):
134
+ super().__init__()
135
+ self.config = config
136
+ self.pos_conv_embed = HubertPositionalConvEmbedding(config)
137
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
138
+ self.dropout = nn.Dropout(config.hidden_dropout)
139
+ self.layers = nn.ModuleList(
140
+ [ExHuBERTEncoderLayer(config) for _ in range(config.num_hidden_layers)]
141
+ )
142
+ self.gradient_checkpointing = False
143
+
144
+ def forward(
145
+ self,
146
+ hidden_states,
147
+ attention_mask=None,
148
+ output_attentions=False,
149
+ output_hidden_states=False,
150
+ return_dict=True,
151
+ ):
152
+ all_hidden_states = () if output_hidden_states else None
153
+ all_self_attentions = () if output_attentions else None
154
+
155
+ if attention_mask is not None:
156
+ # make sure padded tokens are not attended to
157
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
158
+ hidden_states[~expand_attention_mask] = 0
159
+
160
+ # extend attention_mask
161
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
162
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
163
+ attention_mask = attention_mask.expand(
164
+ attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
165
+ )
166
+
167
+ position_embeddings = self.pos_conv_embed(hidden_states)
168
+ hidden_states = hidden_states + position_embeddings
169
+ hidden_states = self.dropout(hidden_states)
170
+
171
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
172
+
173
+ skip = torch.zeros_like(hidden_states)
174
+ skip_bool = False
175
+ for layer in self.layers:
176
+
177
+ if output_hidden_states:
178
+ all_hidden_states = all_hidden_states + (hidden_states,)
179
+
180
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
181
+ dropout_probability = torch.rand([])
182
+
183
+ # skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
184
+ skip_the_layer = False
185
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
186
+ # under deepspeed zero3 all gpus must run in sync
187
+ # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
188
+ if self.gradient_checkpointing and self.training:
189
+ # create gradient checkpointing function
190
+ def create_custom_forward(module):
191
+ def custom_forward(*inputs):
192
+ return module(*inputs, output_attentions)
193
+
194
+ return custom_forward
195
+
196
+ layer_outputs = torch.utils.checkpoint.checkpoint(
197
+ create_custom_forward(layer),
198
+ hidden_states,
199
+ attention_mask,
200
+ )
201
+ else:
202
+ layer_outputs = layer(
203
+ hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
204
+ )
205
+ hidden_states = layer_outputs[0]
206
+
207
+ if skip_the_layer:
208
+ layer_outputs = (None, None)
209
+
210
+ if output_attentions:
211
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
212
+ if skip_bool is True:
213
+ hidden_states = hidden_states + skip
214
+
215
+ skip_bool = False
216
+ else:
217
+ skip = hidden_states
218
+ skip_bool = True
219
+
220
+ hidden_states = self.layer_norm(hidden_states)
221
+
222
+ if output_hidden_states:
223
+ all_hidden_states = all_hidden_states + (hidden_states,)
224
+
225
+ if not return_dict:
226
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
227
+ return BaseModelOutput(
228
+ last_hidden_state=hidden_states,
229
+ hidden_states=all_hidden_states,
230
+ attentions=all_self_attentions,
231
+ )
232
+
233
+
234
+ class ExHuBERT_model_(HubertPreTrainedModel):
235
+ def __init__(self, config: HubertConfig):
236
+ super().__init__(config)
237
+ setattr(config, 'num_hidden_layers', 48)
238
+ self.config = config
239
+ self.feature_extractor = HubertFeatureEncoder(config)
240
+ self.feature_projection = HubertFeatureProjection(config)
241
+
242
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
243
+ self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
244
+
245
+ self.encoder = ExHuBERTEncoder(config)
246
+
247
+ # Initialize weights and apply final processing
248
+ self.post_init()
249
+
250
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
251
+ def _mask_hidden_states(
252
+ self,
253
+ hidden_states: torch.FloatTensor,
254
+ mask_time_indices: Optional[torch.FloatTensor] = None,
255
+ attention_mask: Optional[torch.LongTensor] = None,
256
+ ):
257
+ """
258
+ Masks extracted features along time axis and/or along feature axis according to
259
+ [SpecAugment](https://arxiv.org/abs/1904.08779).
260
+ """
261
+
262
+ # `config.apply_spec_augment` can set masking to False
263
+ if not getattr(self.config, "apply_spec_augment", True):
264
+ return hidden_states
265
+
266
+ # generate indices & apply SpecAugment along time axis
267
+ batch_size, sequence_length, hidden_size = hidden_states.size()
268
+
269
+ if mask_time_indices is not None:
270
+ # apply SpecAugment along time axis with given mask_time_indices
271
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
272
+ elif self.config.mask_time_prob > 0 and self.training:
273
+ mask_time_indices = _compute_mask_indices(
274
+ (batch_size, sequence_length),
275
+ mask_prob=self.config.mask_time_prob,
276
+ mask_length=self.config.mask_time_length,
277
+ attention_mask=attention_mask,
278
+ min_masks=self.config.mask_time_min_masks,
279
+ )
280
+ mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
281
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
282
+
283
+ if self.config.mask_feature_prob > 0 and self.training:
284
+ # generate indices & apply SpecAugment along feature axis
285
+ mask_feature_indices = _compute_mask_indices(
286
+ (batch_size, hidden_size),
287
+ mask_prob=self.config.mask_feature_prob,
288
+ mask_length=self.config.mask_feature_length,
289
+ min_masks=self.config.mask_feature_min_masks,
290
+ )
291
+ mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
292
+ mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
293
+ hidden_states[mask_feature_indices] = 0
294
+
295
+ return hidden_states
296
+
297
+ def forward(
298
+ self,
299
+ input_values: Optional[torch.Tensor],
300
+ attention_mask: Optional[torch.Tensor] = None,
301
+ mask_time_indices: Optional[torch.FloatTensor] = None,
302
+ output_attentions: Optional[bool] = None,
303
+ output_hidden_states: Optional[bool] = None,
304
+ return_dict: Optional[bool] = None,
305
+ ) -> Union[Tuple, BaseModelOutput]:
306
+
307
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
308
+ output_hidden_states = (
309
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
310
+ )
311
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
312
+
313
+ extract_features = self.feature_extractor(input_values)
314
+ extract_features = extract_features.transpose(1, 2)
315
+
316
+ if attention_mask is not None:
317
+ # compute reduced attention_mask corresponding to feature vectors
318
+ attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
319
+
320
+ hidden_states = self.feature_projection(extract_features)
321
+ hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
322
+
323
+ encoder_outputs = self.encoder(
324
+ hidden_states,
325
+ attention_mask=attention_mask,
326
+ output_attentions=output_attentions,
327
+ output_hidden_states=output_hidden_states,
328
+ return_dict=return_dict,
329
+ )
330
+
331
+ hidden_states = encoder_outputs[0]
332
+
333
+ if not return_dict:
334
+ return (hidden_states,) + encoder_outputs[1:]
335
+
336
+ return BaseModelOutput(
337
+ last_hidden_state=hidden_states,
338
+ hidden_states=encoder_outputs.hidden_states,
339
+ attentions=encoder_outputs.attentions,
340
+ )
341
+
342
+
343
+ class ExHuBERT(HubertPreTrainedModel,PyTorchModelHubMixin):
344
+ def __init__(self, config):
345
+ super().__init__(config)
346
+ setattr(config, "num_labels", 6)
347
+ if hasattr(config, "add_adapter") and config.add_adapter:
348
+ raise ValueError(
349
+ "Sequence classification does not support the use of Hubert adapters (config.add_adapter=True)"
350
+ )
351
+ self.hubert = ExHuBERT_model_(config)
352
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
353
+ if config.use_weighted_layer_sum:
354
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
355
+ self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
356
+ self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
357
+
358
+ # Initialize weights and apply final processing
359
+ self.post_init()
360
+
361
+ def freeze_feature_encoder(self):
362
+ """
363
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
364
+ not be updated during training.
365
+ """
366
+ self.hubert.feature_extractor._freeze_parameters()
367
+
368
+ def freeze_base_model(self):
369
+ """
370
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
371
+ be updated during training. Only the classification head will be updated.
372
+ """
373
+ for param in self.hubert.parameters():
374
+ param.requires_grad = False
375
+
376
+ def forward(
377
+ self,
378
+ input_values: Optional[torch.Tensor],
379
+ attention_mask: Optional[torch.Tensor] = None,
380
+ output_attentions: Optional[bool] = None,
381
+ output_hidden_states: Optional[bool] = None,
382
+ return_dict: Optional[bool] = None,
383
+ labels: Optional[torch.Tensor] = None,
384
+ ) -> Union[Tuple, SpeechClassifierOutput]:
385
+ r"""
386
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
387
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
388
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
389
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
390
+ """
391
+
392
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
393
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
394
+
395
+ outputs = self.hubert(
396
+ input_values,
397
+ attention_mask=attention_mask,
398
+ output_attentions=output_attentions,
399
+ output_hidden_states=output_hidden_states,
400
+ return_dict=return_dict,
401
+ )
402
+
403
+ if self.config.use_weighted_layer_sum:
404
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
405
+ hidden_states = torch.stack(hidden_states, dim=1)
406
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
407
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
408
+ else:
409
+ hidden_states = outputs[0]
410
+
411
+ hidden_states = self.projector(hidden_states)
412
+ if attention_mask is None:
413
+ pooled_output = hidden_states.mean(dim=1)
414
+ else:
415
+ padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
416
+ hidden_states[~padding_mask] = 0.0
417
+ pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
418
+
419
+ logits = self.classifier(pooled_output)
420
+
421
+ loss = None
422
+
423
+ if not return_dict:
424
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
425
+ return ((loss,) + output) if loss is not None else output
426
+
427
+ return SpeechClassifierOutput(
428
+ loss=loss,
429
+ logits=logits,
430
+ hidden_states=outputs.hidden_states,
431
+ attentions=outputs.attentions,
432
+ )
433
+
434
+ def freeze_og_encoder(self):
435
+ for param in self.hubert.encoder.layers[::2].parameters():
436
+ param.requires_grad = False
437
+
438
+ def print_trainable_parameters(model):
439
+ '''
440
+ prints all trainable parameters of a model
441
+ '''
442
+ trainable_params = 0
443
+ all_param = 0
444
+ for _, param in model.named_parameters():
445
+ all_param += param.numel()
446
+ if param.requires_grad:
447
+ trainable_params += param.numel()
448
+ print(
449
+ f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.2f}"
450
+ )
451
+