File size: 6,949 Bytes
a1c1315 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
from typing import Optional, Union
import jax
import jax.numpy as jnp
import flax.linen as nn
from transformers.modeling_flax_outputs import FlaxCausalLMOutput
from transformers.models.wav2vec2.configuration_wav2vec2 import Wav2Vec2Config
from transformers.models.wav2vec2.modeling_flax_wav2vec2 import (
FlaxWav2Vec2FeatureEncoder,
FlaxWav2Vec2FeatureProjection,
FlaxWav2Vec2StableLayerNormEncoder,
FlaxWav2Vec2Adapter,
FlaxWav2Vec2PreTrainedModel,
FlaxWav2Vec2BaseModelOutput,
)
class FlaxWav2Vec2Module(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)
if self.config.mask_time_prob > 0.0 or self.config.mask_feature_prob > 0.0:
self.masked_spec_embed = self.param(
"masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
)
if self.config.do_stable_layer_norm:
self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
else:
raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")
self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None
def __call__(
self,
input_values,
attention_mask=None,
mask_time_indices=None,
deterministic=True,
output_attentions=None,
output_hidden_states=None,
freeze_feature_encoder=False,
return_dict=None,
):
extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)
# make sure that no loss is computed on padded inputs
if attention_mask is not None:
# compute reduced attention_mask corresponding to feature vectors
attention_mask = self._get_feature_vector_attention_mask(
extract_features.shape[1], attention_mask, add_adapter=False
)
hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
hidden_states = jnp.where(
jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
hidden_states,
)
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = encoder_outputs[0]
if self.adapter is not None:
hidden_states = self.adapter(hidden_states)
if not return_dict:
return (hidden_states, extract_features) + encoder_outputs[1:]
return FlaxWav2Vec2BaseModelOutput(
last_hidden_state=hidden_states,
extract_features=extract_features,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def _get_feat_extract_output_lengths(
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
):
"""
Computes the output length of the convolutional layers
"""
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return (input_length - kernel_size) // stride + 1
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
if add_adapter:
for _ in range(self.config.num_adapter_layers):
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
return input_lengths
def _get_feature_vector_attention_mask(
self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
):
# Effectively attention_mask.sum(-1), but not inplace to be able to run
# on inference mode.
non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
batch_size = attention_mask.shape[0]
attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
# these two operations makes sure that all values
# before the output lengths indices are attended to
attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
return attention_mask
class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
module_class = FlaxWav2Vec2Module
class FlaxWav2Vec2ForAudioFrameClassificationModule(nn.Module):
config: Wav2Vec2Config
dtype: jnp.dtype = jnp.float32
def setup(self):
self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
self.classifier = nn.Dense(
self.config.num_labels,
kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
dtype=self.dtype,
)
def __call__(
self,
input_values,
attention_mask=None,
mask_time_indices=None,
deterministic=True,
output_attentions=None,
output_hidden_states=None,
freeze_feature_encoder=False,
return_dict=None,
):
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
mask_time_indices=mask_time_indices,
deterministic=deterministic,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
freeze_feature_encoder=freeze_feature_encoder,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.classifier(hidden_states)
if not return_dict:
return (logits,) + outputs[2:]
return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
class FlaxWav2Vec2ForAudioFrameClassification(FlaxWav2Vec2PreTrainedModel):
module_class = FlaxWav2Vec2ForAudioFrameClassificationModule
|