from transformers import Wav2Vec2Config, Wav2Vec2Model
from transformers.modeling_outputs import BaseModelOutput

from .torch_utils import linear_interpolation
# The implementation of Wav2Vec2Model is borrowed from
# https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py
# so that the encoder can be initialized with pre-trained wav2vec 2.0 weights.


class Wav2Vec2Model(Wav2Vec2Model):
    def __init__(self, config: Wav2Vec2Config):
        super().__init__(config)
    def forward(
        self,
        input_values,
        seq_len,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode a raw waveform into hidden states resampled to `seq_len` frames."""
        # NOTE: this mutates the instance config so attention outputs stay enabled.
        self.config.output_attentions = True

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Extract convolutional features, move the time axis to dim 1
        # (batch, time, channels), and resample the time axis to `seq_len` steps.
        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        extract_features = linear_interpolation(extract_features, seq_len=seq_len)

        if attention_mask is not None:
            # Compute a reduced attention mask that matches the feature frames.
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
    def feature_extract(
        self,
        input_values,
        seq_len,
    ):
        """Run only the convolutional front end and resample to `seq_len` frames."""
        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        extract_features = linear_interpolation(extract_features, seq_len=seq_len)

        return extract_features
    def encode(
        self,
        extract_features,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode features produced by `feature_extract` with the transformer encoder."""
        # NOTE: this mutates the instance config so attention outputs stay enabled.
        self.config.output_attentions = True

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if attention_mask is not None:
            # Compute a reduced attention mask that matches the feature frames.
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
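

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how this wrapper might be used: load pre-trained
# wav2vec 2.0 weights via `from_pretrained` (inherited from the Hugging Face
# base class) and resample audio features to a target frame count. The
# checkpoint name, shapes, and frame rate below are assumptions made for
# this sketch, not values taken from this repository. Because of the
# relative import above, run it as a module (python -m <package>.<module>).
if __name__ == "__main__":
    import torch

    model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
    model.eval()

    waveform = torch.randn(1, 16000)  # dummy 1 s mono clip at 16 kHz
    num_frames = 25                   # e.g. 1 s of video at 25 fps

    with torch.no_grad():
        # One-shot path: extract features, resample to num_frames, encode.
        out = model(waveform, seq_len=num_frames)
        print(out.last_hidden_state.shape)  # -> (1, num_frames, hidden_size)

        # Two-step path: the same pipeline via feature_extract + encode.
        feats = model.feature_extract(waveform, seq_len=num_frames)
        out2 = model.encode(feats)
        print(out2.last_hidden_state.shape)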